├── loss
│   ├── __init__.py
│   ├── __pycache__
│   └── loss_unlabel.py
├── model
│   ├── __init__.py
│   ├── __pycache__
│   ├── model_para.py
│   ├── monitor.py
│   └── squeeze_excite.py
├── dataloader
│   ├── dataloader.sh
│   ├── __pycache__
│   ├── unzip_tar.py
│   └── test.py
├── mamba_local
│   ├── mamba_ssm_local
│   │   ├── ops
│   │   │   ├── __init__.py
│   │   │   ├── triton
│   │   │   │   ├── __init__.py
│   │   │   │   └── __pycache__
│   │   │   └── __pycache__
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   └── __pycache__
│   │   ├── modules
│   │   │   ├── __init__.py
│   │   │   └── __pycache__
│   │   ├── utils
│   │   │   ├── __init__.py
│   │   │   └── hf.py
│   │   ├── causal_conv1d_local
│   │   │   ├── AUTHORS
│   │   │   ├── README.md
│   │   │   ├── causal_conv1d
│   │   │   │   ├── __init__.py
│   │   │   │   ├── __pycache__
│   │   │   │   └── causal_conv1d_interface.py
│   │   │   ├── csrc
│   │   │   │   ├── static_switch.h
│   │   │   │   ├── causal_conv1d.h
│   │   │   │   ├── causal_conv1d_common.h
│   │   │   │   └── causal_conv1d_update.cu
│   │   │   └── LICENSE
│   │   ├── __pycache__
│   │   └── __init__.py
│   ├── mamba_ssm.egg-info
│   │   ├── dependency_links.txt
│   │   ├── top_level.txt
│   │   ├── requires.txt
│   │   └── SOURCES.txt
│   ├── build
│   │   └── lib.linux-x86_64-3.8
│   │       └── mamba_ssm
│   │           ├── models
│   │           │   └── __init__.py
│   │           ├── modules
│   │           │   └── __init__.py
│   │           ├── ops
│   │           │   ├── __init__.py
│   │           │   └── triton
│   │           │       └── __init__.py
│   │           ├── utils
│   │           │   ├── __init__.py
│   │           │   └── hf.py
│   │           └── __init__.py
│   ├── AUTHORS
│   ├── assets
│   │   └── selection.png
│   ├── csrc
│   │   └── selective_scan
│   │       ├── selective_scan_bwd_fp16_real.cu
│   │       ├── selective_scan_bwd_fp32_real.cu
│   │       ├── selective_scan_bwd_bf16_real.cu
│   │       ├── selective_scan_bwd_fp32_complex.cu
│   │       ├── selective_scan_bwd_bf16_complex.cu
│   │       ├── selective_scan_bwd_fp16_complex.cu
│   │       ├── selective_scan_fwd_fp32.cu
│   │       ├── selective_scan_fwd_fp16.cu
│   │       ├── selective_scan_fwd_bf16.cu
│   │       ├── static_switch.h
│   │       ├── uninitialized_copy.cuh
│   │       └── selective_scan.h
│   ├── test_mamba_module.py
│   ├── evals
│   │   └── lm_harness_eval.py
│   ├── tests
│   │   └── ops
│   │       └── triton
│   │           └── test_selective_state_update.py
│   └── benchmarks
│       └── benchmark_generation_mamba_simple.py
├── results1.png
├── results2.png
├── framework1.png
├── framework2.png
├── visual_results.png
├── src
│   ├── test_aff_img.png
│   ├── test_raw_img.png
│   ├── data_check.py
│   ├── test_segmamba.py
│   ├── run_mamba_mae.sh
│   ├── run_MEC_seg.sh
│   ├── run_mamba_mae_AR.sh
│   ├── data_visual.py
│   ├── run_mamba_seg.sh
│   ├── launch.json
│   ├── run_parallel_gaussian.sh
│   ├── run_parallel.sh
│   ├── run_parallel_superhuman.sh
│   ├── eval_signle.py
│   └── eval_single_25.py
├── data
│   ├── __pycache__
│   ├── data_misc.py
│   ├── total_data_provider.py
│   └── data_affinity.py
├── util_mamba
│   ├── __pycache__
│   ├── lr_sched.py
│   ├── crop.py
│   ├── lars.py
│   ├── datasets.py
│   ├── lr_decay.py
│   └── MultiScaleAttention.py
├── augmentation
│   ├── __pycache__
│   ├── missing_section.py
│   ├── motion_blur.py
│   ├── rotation.py
│   ├── augmentor.py
│   ├── cutnoise.py
│   ├── mixup.py
│   ├── flip.py
│   ├── cutblur.py
│   ├── warp.py
│   ├── grayscale.py
│   ├── test_augmentor.py
│   ├── rescale.py
│   ├── __init__.py
│   ├── misalign.py
│   └── missing_parts.py
├── utils
│   ├── torch_utils.py
│   ├── malis_loss.py
│   ├── optim_weight_ema.py
│   ├── utils.py
│   ├── lmc.py
│   ├── shift_channels.py
│   ├── gen_pseudo.py
│   ├── fragment.py
│   ├── coordinate.py
│   └── affinity_ours.py
└── config
    ├── seg_3d_cremiA_data100.yaml
    ├── seg_3d_cremiB_data100.yaml
    ├── seg_3d_wafer26_data100.yaml
    ├── seg_3d_cremiC_data100.yaml
    ├── seg_3d_ac4_data80.yaml
    └── seg_3d_wafer4_data100.yaml

--------------------------------------------------------------------------------
/loss/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/dataloader/dataloader.sh:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/mamba_local/mamba_ssm_local/ops/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/mamba_local/mamba_ssm_local/models/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/mamba_local/mamba_ssm_local/modules/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/mamba_local/mamba_ssm_local/utils/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/mamba_local/mamba_ssm_local/ops/triton/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/mamba_local/mamba_ssm.egg-info/dependency_links.txt:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/mamba_local/build/lib.linux-x86_64-3.8/mamba_ssm/models/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/mamba_local/build/lib.linux-x86_64-3.8/mamba_ssm/modules/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/mamba_local/build/lib.linux-x86_64-3.8/mamba_ssm/ops/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/mamba_local/build/lib.linux-x86_64-3.8/mamba_ssm/utils/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/mamba_local/build/lib.linux-x86_64-3.8/mamba_ssm/ops/triton/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/mamba_local/AUTHORS:
--------------------------------------------------------------------------------
Tri Dao, tri@tridao.me
Albert Gu, agu@andrew.cmu.edu
--------------------------------------------------------------------------------
/mamba_local/mamba_ssm.egg-info/top_level.txt:
--------------------------------------------------------------------------------
mamba_ssm
selective_scan_cuda
--------------------------------------------------------------------------------
/mamba_local/mamba_ssm_local/causal_conv1d_local/AUTHORS:
--------------------------------------------------------------------------------
Tri Dao, tri@tridao.me
--------------------------------------------------------------------------------
/results1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ydchen0806/TokenUnify/HEAD/results1.png
--------------------------------------------------------------------------------
/results2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ydchen0806/TokenUnify/HEAD/results2.png
--------------------------------------------------------------------------------
/framework1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ydchen0806/TokenUnify/HEAD/framework1.png
--------------------------------------------------------------------------------
/framework2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ydchen0806/TokenUnify/HEAD/framework2.png
--------------------------------------------------------------------------------
/visual_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ydchen0806/TokenUnify/HEAD/visual_results.png
--------------------------------------------------------------------------------
/src/test_aff_img.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ydchen0806/TokenUnify/HEAD/src/test_aff_img.png
--------------------------------------------------------------------------------
/src/test_raw_img.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ydchen0806/TokenUnify/HEAD/src/test_raw_img.png
--------------------------------------------------------------------------------
/mamba_local/assets/selection.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ydchen0806/TokenUnify/HEAD/mamba_local/assets/selection.png
--------------------------------------------------------------------------------
/mamba_local/mamba_ssm_local/causal_conv1d_local/README.md:
--------------------------------------------------------------------------------
# Causal depthwise conv1d in CUDA with a PyTorch interface
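The bundled causal_conv1d/__init__.py (dumped further below) exports causal_conv1d_fn and causal_conv1d_update. A minimal smoke test of the depthwise causal convolution — a sketch assuming the upstream (batch, dim, seqlen) layout and the optional "silu" fused activation of the vendored 1.0.0 interface:

    import torch
    from causal_conv1d import causal_conv1d_fn

    batch, dim, seqlen, width = 2, 64, 128, 4
    x = torch.randn(batch, dim, seqlen, device="cuda")   # channels-first sequence
    weight = torch.randn(dim, width, device="cuda")      # one short filter per channel (depthwise)
    bias = torch.randn(dim, device="cuda")

    # Causal: y[..., t] depends only on x[..., :t + 1]
    y = causal_conv1d_fn(x, weight, bias, activation="silu")
    assert y.shape == x.shape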
--------------------------------------------------------------------------------
/mamba_local/mamba_ssm.egg-info/requires.txt:
--------------------------------------------------------------------------------
causal_conv1d
einops
ninja
packaging
torch
transformers
triton
--------------------------------------------------------------------------------
/mamba_local/mamba_ssm_local/causal_conv1d_local/causal_conv1d/__init__.py:
--------------------------------------------------------------------------------
__version__ = "1.0.0"

from causal_conv1d.causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
--------------------------------------------------------------------------------
/src/data_check.py:
--------------------------------------------------------------------------------
import h5py

f_raw = h5py.File('/h3cstore_ns/Backbones/data/wafer/wafer36_inputs.h5', 'r')
data = f_raw['main'][:]
print(data.shape)
f_raw.close()
--------------------------------------------------------------------------------
/mamba_local/build/lib.linux-x86_64-3.8/mamba_ssm/__init__.py:
--------------------------------------------------------------------------------
__version__ = "1.0.1"

from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn, bimamba_inner_fn
from mamba_ssm.modules.mamba_simple import Mamba
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
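These exports are the public surface of the CUDA selective scan. A hedged sketch of calling selective_scan_fn directly, with shapes taken from the interface's docstring (u, delta: (B, D, L); A: (D, N); B, C: (B, N, L); D: (D,)); the sizes are illustrative, and this needs the compiled extension and a GPU:

    import torch
    from mamba_ssm.ops.selective_scan_interface import selective_scan_fn

    b, d, l, n = 2, 4, 32, 16
    u = torch.randn(b, d, l, device="cuda")
    delta = torch.rand(b, d, l, device="cuda")
    A = -torch.rand(d, n, device="cuda")   # negative entries keep the state recurrence stable
    B = torch.randn(b, n, l, device="cuda")
    C = torch.randn(b, n, l, device="cuda")
    D = torch.randn(d, device="cuda")

    y = selective_scan_fn(u, delta, A, B, C, D=D, delta_softplus=True)
    assert y.shape == u.shape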
--------------------------------------------------------------------------------
/src/test_segmamba.py:
--------------------------------------------------------------------------------
import torch
try:
    from tensorboardX import SummaryWriter
except ImportError:
    from torch.utils.tensorboard import SummaryWriter

from segmamba import SegMamba

device = torch.device('cuda')
model = SegMamba(in_chans=1, out_chans=3).to(device)

x = torch.randn(1, 1, 16, 160, 160).to(device)
print(x.shape)
y = model(x)
print(y.shape)
--------------------------------------------------------------------------------
/mamba_local/csrc/selective_scan/selective_scan_bwd_fp16_real.cu:
--------------------------------------------------------------------------------
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

// Split into multiple files to compile in parallel

#include "selective_scan_bwd_kernel.cuh"

template void selective_scan_bwd_cuda<at::Half, float>(SSMParamsBwd &params, cudaStream_t stream);
--------------------------------------------------------------------------------
/mamba_local/csrc/selective_scan/selective_scan_bwd_fp32_real.cu:
--------------------------------------------------------------------------------
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

// Split into multiple files to compile in parallel

#include "selective_scan_bwd_kernel.cuh"

template void selective_scan_bwd_cuda<float, float>(SSMParamsBwd &params, cudaStream_t stream);
--------------------------------------------------------------------------------
/mamba_local/csrc/selective_scan/selective_scan_bwd_bf16_real.cu:
--------------------------------------------------------------------------------
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

// Split into multiple files to compile in parallel

#include "selective_scan_bwd_kernel.cuh"

template void selective_scan_bwd_cuda<at::BFloat16, float>(SSMParamsBwd &params, cudaStream_t stream);
--------------------------------------------------------------------------------
/mamba_local/csrc/selective_scan/selective_scan_bwd_fp32_complex.cu:
--------------------------------------------------------------------------------
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

// Split into multiple files to compile in parallel

#include "selective_scan_bwd_kernel.cuh"

template void selective_scan_bwd_cuda<float, complex_t>(SSMParamsBwd &params, cudaStream_t stream);
--------------------------------------------------------------------------------
/mamba_local/csrc/selective_scan/selective_scan_bwd_bf16_complex.cu:
--------------------------------------------------------------------------------
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

// Split into multiple files to compile in parallel

#include "selective_scan_bwd_kernel.cuh"

template void selective_scan_bwd_cuda<at::BFloat16, complex_t>(SSMParamsBwd &params, cudaStream_t stream);
--------------------------------------------------------------------------------
/mamba_local/csrc/selective_scan/selective_scan_bwd_fp16_complex.cu:
--------------------------------------------------------------------------------
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

// Split into multiple files to compile in parallel

#include "selective_scan_bwd_kernel.cuh"

template void selective_scan_bwd_cuda<at::Half, complex_t>(SSMParamsBwd &params, cudaStream_t stream);
--------------------------------------------------------------------------------
/mamba_local/test_mamba_module.py:
--------------------------------------------------------------------------------
import torch
from mamba_ssm import Mamba

batch, length, dim = 2, 64, 768
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim,  # Model dimension d_model
    d_state=16,   # SSM state expansion factor  # 64
    d_conv=4,     # Local convolution width
    expand=2,     # Block expansion factor
    use_fast_path=False,
).to("cuda")
y = model(x)
assert y.shape == x.shape
--------------------------------------------------------------------------------
/mamba_local/mamba_ssm_local/__init__.py:
--------------------------------------------------------------------------------
__version__ = "1.0.1"
try:
    from mamba_ssm_local.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn, bimamba_inner_fn
    from mamba_ssm_local.modules.mamba_simple import Mamba
    from mamba_ssm_local.models.mixer_seq_simple import MambaLMHeadModel
except ImportError:
    from .ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn, bimamba_inner_fn
    from .modules.mamba_simple import Mamba
    from .models.mixer_seq_simple import MambaLMHeadModel
--------------------------------------------------------------------------------
/mamba_local/csrc/selective_scan/selective_scan_fwd_fp32.cu:
--------------------------------------------------------------------------------
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

// Split into multiple files to compile in parallel

#include "selective_scan_fwd_kernel.cuh"

template void selective_scan_fwd_cuda<float, float>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<float, complex_t>(SSMParamsBase &params, cudaStream_t stream);
--------------------------------------------------------------------------------
/mamba_local/csrc/selective_scan/selective_scan_fwd_fp16.cu:
--------------------------------------------------------------------------------
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

// Split into multiple files to compile in parallel

#include "selective_scan_fwd_kernel.cuh"

template void selective_scan_fwd_cuda<at::Half, float>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<at::Half, complex_t>(SSMParamsBase &params, cudaStream_t stream);
--------------------------------------------------------------------------------
/mamba_local/csrc/selective_scan/selective_scan_fwd_bf16.cu:
--------------------------------------------------------------------------------
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

// Split into multiple files to compile in parallel

#include "selective_scan_fwd_kernel.cuh"

template void selective_scan_fwd_cuda<at::BFloat16, float>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<at::BFloat16, complex_t>(SSMParamsBase &params, cudaStream_t stream);
--------------------------------------------------------------------------------
/dataloader/unzip_tar.py:
--------------------------------------------------------------------------------
import os
import time
from glob import glob
from multiprocessing import Pool

from tqdm import tqdm

path = '/braindat/lab/chenyd/DATASET/MSD'
data_dir = sorted(glob(os.path.join(path, '*tar')))

def unzip_tar(data_dir):
    os.system(f'tar -xvf {data_dir} -C {path}')
    os.system(f'rm {data_dir}')
    print(f'{data_dir} is done')

if __name__ == '__main__':
    t0 = time.time()
    pool = Pool(8)
    for _ in tqdm(pool.imap_unordered(unzip_tar, data_dir), total=len(data_dir)):
        pass
    pool.close()
    pool.join()
    t1 = time.time()
    print(f'All done, time cost: {t1 - t0}')
--------------------------------------------------------------------------------
/utils/torch_utils.py:
--------------------------------------------------------------------------------
from distutils.version import LooseVersion
import torch

# align_corners option available for v1.3.0 and above
HAS_AFFINE_ALIGN_CORNERS = LooseVersion(torch.__version__) >= LooseVersion('1.3.0')
# align_corners defaults to True before v1.4.0, False from v1.4.0 and after
AFFINE_ALIGN_CORNERS_DEFAULT = LooseVersion(torch.__version__) <= LooseVersion('1.3.0')


def affine_align_corners_kw(val):
    if HAS_AFFINE_ALIGN_CORNERS:
        return dict(align_corners=val)
    else:
        if not val:
            raise RuntimeError('align_corners not available in torch version {} so '
                               'cannot set to False'.format(torch.__version__))
        return {}
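A hedged example of how affine_align_corners_kw is meant to be threaded through the grid-sampling calls: an identity warp, so the output should reproduce the input up to float error. The tensor names are illustrative, not from this repo:

    import torch
    import torch.nn.functional as F
    from utils.torch_utils import affine_align_corners_kw

    x = torch.randn(1, 1, 8, 8)            # (N, C, H, W)
    theta = torch.eye(2, 3).unsqueeze(0)   # identity affine transform, (N, 2, 3)
    grid = F.affine_grid(theta, list(x.shape), **affine_align_corners_kw(True))
    warped = F.grid_sample(x, grid, **affine_align_corners_kw(True))
    assert torch.allclose(warped, x, atol=1e-5)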
12 |                                      pre_prodDims, pre_nHood, output_affs.flatten(), test_label.flatten(), 0.5).reshape(conn_dims)
13 |     malis = np.sum(weight * (output_affs - test_label) ** 2)
14 |     return malis
15 |
--------------------------------------------------------------------------------
/src/run_mamba_mae.sh:
--------------------------------------------------------------------------------
1 | pip3 uninstall -y timm &&
2 | pip3 install timm==0.3.2 -i https://pypi.tuna.tsinghua.edu.cn/simple &&
3 | sed -i 's/from torch._six import container_abcs/import collections.abc as container_abcs/g' /usr/local/lib/python3.8/dist-packages/timm/models/layers/helpers.py &&
4 | # pip3 install --upgrade torchvision -i https://pypi.tuna.tsinghua.edu.cn/simple &&
5 | export NCCL_P2P_DISABLE=1 &&
6 | export ROOT_DIR=/h3cstore_ns/EM_pretrain/mamba_pretrain_MAE/segmamba0502_MAE_auto_fill_1_mask_0_4 &&
7 | python3 -m torch.distributed.launch --nproc_per_node=8 /data/ydchen/VLP/EM_Mamba/mambamae_EM/main_pretrain.py --batch_size=40 \
8 | --epochs=800 --model=segmamba --use_amp=True --output_dir=$ROOT_DIR \
9 | --visual_dir=$ROOT_DIR/visual --log_dir=$ROOT_DIR/tensorboard_log \
10 | --warmup_epochs=0 --fill_mode=1 --mask_ratio=0.4
--------------------------------------------------------------------------------
/src/run_MEC_seg.sh:
--------------------------------------------------------------------------------
1 | pip3 uninstall -y timm &&
2 | pip3 install timm==0.3.2 -i https://pypi.tuna.tsinghua.edu.cn/simple &&
3 | sed -i 's/from torch._six import container_abcs/import collections.abc as container_abcs/g' /usr/local/lib/python3.8/dist-packages/timm/models/layers/helpers.py &&
4 | export NCCL_P2P_DISABLE=1 &&
5 | python3 -m torch.distributed.launch --nproc_per_node=8 /h3cstore_ns/hyshi/EM_mamba_new/EM_mamba_seg/main_finetune2.py --batch_size=20 \
6 | --epochs=800 --output_dir=/h3cstore_ns/hyshi/EM_mamba_new/result/superhuman_monai_wafer_lr5_b20_18_160_160_gaussian_8gpu \
7 | --visual_dir=/h3cstore_ns/hyshi/EM_mamba_new/result/superhuman_monai_wafer_lr5_b20_18_160_160_gaussian_8gpu/visual \
8 | --log_dir=/h3cstore_ns/hyshi/EM_mamba_new/result/superhuman_monai_wafer_lr5_b20_18_160_160_gaussian_8gpu/tensorboard_log \
9 | --warmup_epochs=0 --blr=1e-5
--------------------------------------------------------------------------------
/util_mamba/lr_sched.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import math
8 |
9 | def adjust_learning_rate(optimizer, epoch, args):
10 |     """Decay the learning rate with half-cycle cosine after warmup"""
11 |     if epoch < args.warmup_epochs:
12 |         lr = args.lr * epoch / args.warmup_epochs
13 |     else:
14 |         lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
15 |             (1.
+ math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs))) 16 | for param_group in optimizer.param_groups: 17 | if "lr_scale" in param_group: 18 | param_group["lr"] = lr * param_group["lr_scale"] 19 | else: 20 | param_group["lr"] = lr 21 | return lr 22 | -------------------------------------------------------------------------------- /src/run_mamba_mae_AR.sh: -------------------------------------------------------------------------------- 1 | pip3 uninstall -y timm && 2 | pip3 install timm==0.3.2 -i https://pypi.tuna.tsinghua.edu.cn/simple && 3 | sed -i 's/from torch._six import container_abcs/import collections.abc as container_abcs/g' /usr/local/lib/python3.8/dist-packages/timm/models/layers/helpers.py && 4 | # pip3 install --upgrade torchvision -i https://pypi.tuna.tsinghua.edu.cn/simple && 5 | export NCCL_P2P_DISABLE=1 && 6 | export ROOT_DIR=/h3cstore_ns/EM_pretrain/mamba_pretrain_autoregress_0430_amp/0428segmamba_auto_mode${1}_fill_nose${2} && 7 | python3 -m torch.distributed.launch --nproc_per_node=8 /data/ydchen/VLP/EM_Mamba/mambamae_EM/main_pretrain_autoregress.py --batch_size=20 \ 8 | --epochs=800 --model=segmamba --use_amp=True --output_dir=$ROOT_DIR \ 9 | --visual_dir=$ROOT_DIR/visual --log_dir=$ROOT_DIR/tensorboard_log \ 10 | --warmup_epochs=0 --auto_mode=$1 --fill_nose=$2 --pretrain_path=$3 -------------------------------------------------------------------------------- /src/data_visual.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | from data.data_affinity import seg_to_aff 3 | from data.data_segmentation import seg_widen_border, weight_binary_ratio 4 | import numpy as np 5 | from PIL import Image 6 | 7 | # load raw data 8 | f_raw = h5py.File('/h3cstore_ns/Backbones/data/wafer/wafer4_inputs.h5', 'r') 9 | data = f_raw['main'][:] 10 | f_raw.close() 11 | 12 | # load labels 13 | f_label = h5py.File('/h3cstore_ns/Backbones/data/wafer/wafer4_labels.h5', 'r') 14 | label = f_label['main'][:] 15 | f_label.close() 16 | 17 | label = seg_widen_border(label, tsz_h=1) 18 | 19 | gt_aff = seg_to_aff(label).astype(np.float32) 20 | 21 | data_img = Image.fromarray((data[0,:]).astype(np.uint8)) 22 | aff_img = Image.fromarray((gt_aff[:, 0]*255).astype(np.uint8).transpose(1,2,0)) 23 | 24 | data_img.save("test_raw_img.png") 25 | aff_img.save("test_aff_img.png") 26 | 27 | 28 | print('done') 29 | 30 | # return data, label, gt_aff -------------------------------------------------------------------------------- /src/run_mamba_seg.sh: -------------------------------------------------------------------------------- 1 | pip3 uninstall -y timm && 2 | pip3 install timm==0.3.2 -i https://pypi.tuna.tsinghua.edu.cn/simple && 3 | sed -i 's/from torch._six import container_abcs/import collections.abc as container_abcs/g' /usr/local/lib/python3.8/dist-packages/timm/models/layers/helpers.py && 4 | export NCCL_P2P_DISABLE=1 && 5 | export NCCL_SOCKET_TIMEOUT=3600 && 6 | python3 -m torch.distributed.launch --nproc_per_node=8 /h3cstore_ns/hyshi/EM_mamba_new/EM_mamba_seg/main_finetune.py --batch_size=6 \ 7 | --epochs=400 \ 8 | --warmup_epochs=0 --blr=1e-4 9 | 10 | 11 | 12 | # sed -i 's/from torch._six import container_abcs/import collections.abc as container_abcs/g' /public/home/heart_llm/.local/lib/python3.10/site-packages/timm/models/layers/helpers.py && 13 | # sed -i 's/from torch._six import container_abcs/import collections.abc as container_abcs/g' 
/public/home/heart_llm/.local/lib/python3.8/site-packages/timm/models/layers/helpers.py &&
14 |
--------------------------------------------------------------------------------
/mamba_local/mamba_ssm_local/utils/hf.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | import torch
4 |
5 | from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
6 | from transformers.utils.hub import cached_file
7 |
8 |
9 | def load_config_hf(model_name):
10 |     resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False)
11 |     return json.load(open(resolved_archive_file))
12 |
13 |
14 | def load_state_dict_hf(model_name, device=None, dtype=None):
15 |     # If not fp32, then we don't want to load directly to the GPU
16 |     mapped_device = "cpu" if dtype not in [torch.float32, None] else device
17 |     resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
18 |     state_dict = torch.load(resolved_archive_file, map_location=mapped_device)
19 |     # Convert dtype before moving to GPU to save memory
20 |     if dtype is not None:
21 |         state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
22 |     state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
23 |     return state_dict
24 |
--------------------------------------------------------------------------------
/mamba_local/build/lib.linux-x86_64-3.8/mamba_ssm/utils/hf.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | import torch
4 |
5 | from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
6 | from transformers.utils.hub import cached_file
7 |
8 |
9 | def load_config_hf(model_name):
10 |     resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False)
11 |     return json.load(open(resolved_archive_file))
12 |
13 |
14 | def load_state_dict_hf(model_name, device=None, dtype=None):
15 |     # If not fp32, then we don't want to load directly to the GPU
16 |     mapped_device = "cpu" if dtype not in [torch.float32, None] else device
17 |     resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
18 |     state_dict = torch.load(resolved_archive_file, map_location=mapped_device)
19 |     # Convert dtype before moving to GPU to save memory
20 |     if dtype is not None:
21 |         state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
22 |     state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
23 |     return state_dict
24 |
--------------------------------------------------------------------------------
/utils/optim_weight_ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class EMAWeightOptimizer (object):
5 |     def __init__(self, target_net, source_net, ema_alpha):
6 |         self.target_net = target_net
7 |         self.source_net = source_net
8 |         self.ema_alpha = ema_alpha
9 |         self.target_params = [p for p in target_net.state_dict().values() if p.dtype == torch.float]
10 |         self.source_params = [p for p in source_net.state_dict().values() if p.dtype == torch.float]
11 |
12 |         for tgt_p, src_p in zip(self.target_params, self.source_params):
13 |             tgt_p[...] = src_p[...]
14 |
15 |         target_keys = set(target_net.state_dict().keys())
16 |         source_keys = set(source_net.state_dict().keys())
17 |         if target_keys != source_keys:
18 |             raise ValueError('Source and target networks do not have the same state dict keys; do they have different architectures?')
19 |
20 |
21 |     def step(self):
22 |         one_minus_alpha = 1.0 - self.ema_alpha
23 |         for tgt_p, src_p in zip(self.target_params, self.source_params):
24 |             tgt_p.mul_(self.ema_alpha)
25 |             tgt_p.add_(src_p * one_minus_alpha)
26 |
--------------------------------------------------------------------------------
/src/launch.json:
--------------------------------------------------------------------------------
1 | {
2 |     // Use IntelliSense to learn about possible attributes.
3 |     // Hover to view descriptions of existing attributes.
4 |     // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
5 |     "version": "0.2.0",
6 |     "configurations": [
7 |         {
8 |             "name": "Python: ddp",
9 |             "type": "python",
10 |             "request": "launch",
11 |             "program": "/usr/local/lib/python3.8/dist-packages/torch/distributed/launch.py",
12 |             "console": "integratedTerminal",
13 |             "args": [
14 |                 "--nproc_per_node=8",
15 |                 "/data/ydchen/VLP/EM_Mamba/EM_mamba_seg/main_finetune.py",
16 |                 "--batch_size=9",
17 |                 "--warmup_epochs=0",
18 |                 "--model=mae_vit_base_patch16_EM",
19 |                 "--epochs=400",
20 |             ],
21 |             "env": {
22 |                 "NCCL_SOCKET_IFNAME": "eth0"
23 |             }
24 |         },
25 |         {
26 |             "name": "Python: Current File",
27 |             "type": "python",
28 |             "request": "launch",
29 |             "program": "${file}",
30 |             "console": "integratedTerminal",
31 |             "justMyCode": true
32 |         }
33 |     ]
34 | }
--------------------------------------------------------------------------------
/loss/loss_unlabel.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 | from torch.nn.modules.loss import _Loss
3 | import torch.nn.functional as F
4 | import torch
5 | import numpy as np
6 |
7 | class MSELoss_unlabel(_Loss):
8 |     def __init__(self):
9 |         super(MSELoss_unlabel, self).__init__()
10 |
11 |     def forward(self, input_y, target, weight):
12 |         # assert target.requires_grad is False
13 |         weight = weight.float()
14 |         target = target.float()
15 |         loss = weight * ((input_y - target) ** 2)
16 |         loss = torch.sum(loss) / torch.sum(weight)
17 |         return loss
18 |
19 |
20 | class BCELoss_unlabel(_Loss):
21 |     def __init__(self):
22 |         super(BCELoss_unlabel, self).__init__()
23 |
24 |     def forward(self, input_y, target, weight):
25 |         assert target.requires_grad is False
26 |         input_y = torch.clamp(input_y, min=0.000001, max=0.999999)
27 |         weight = weight.float()
28 |         target = target.float()
29 |         loss = -weight * (target * torch.log(input_y) + (1 - target) * torch.log(1 - input_y))
30 |         loss = torch.sum(loss) / torch.sum(weight)
31 |         return loss
32 |
--------------------------------------------------------------------------------
/model/model_para.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def model_structure(model):
4 |     blank = ' '
5 |     print('-'*90)
6 |     print('|'+' '*11+'weight name'+' '*10+'|' \
7 |           +' '*15+'weight shape'+' '*15+'|' \
8 |           +' '*3+'number'+' '*3+'|')
9 |     print('-'*90)
10 |     num_para = 0
11 |     type_size = 1
12 |
13 |     for index, (key, w_variable) in enumerate(model.named_parameters()):
14 |         if len(key) <= 30:
15 |             key = key + (30-len(key)) * blank
16 |         shape = str(w_variable.shape)
17 |         if len(shape) <= 40:
18 |             shape = shape + (40-len(shape)) * blank
19 |         each_para = 1
20 |         for k in w_variable.shape:
21 |             each_para *= k
22 |         num_para += each_para
23 |         str_num = str(each_para)
24 |         if len(str_num) <= 10:
25 |             str_num = str_num + (10-len(str_num)) * blank
26 |
27 |         print('| {} | {} | {} |'.format(key, shape, str_num))
28 |     print('-'*90)
29 |     print('The total number of parameters: ' + str(num_para))
30 |     print('The parameters of Model {}: {:.4f}M'.format(model._get_name(), num_para * type_size / 1000 / 1000))
31 |     print('-'*90)
32 |
--------------------------------------------------------------------------------
/mamba_local/mamba_ssm.egg-info/SOURCES.txt:
--------------------------------------------------------------------------------
1 | README.md
2 | setup.py
3 | csrc/selective_scan/selective_scan.cpp
4 | csrc/selective_scan/selective_scan_bwd_bf16_complex.cu
5 | csrc/selective_scan/selective_scan_bwd_bf16_real.cu
6 | csrc/selective_scan/selective_scan_bwd_fp16_complex.cu
7 | csrc/selective_scan/selective_scan_bwd_fp16_real.cu
8 | csrc/selective_scan/selective_scan_bwd_fp32_complex.cu
9 | csrc/selective_scan/selective_scan_bwd_fp32_real.cu
10 | csrc/selective_scan/selective_scan_fwd_bf16.cu
11 | csrc/selective_scan/selective_scan_fwd_fp16.cu
12 | csrc/selective_scan/selective_scan_fwd_fp32.cu
13 | mamba_ssm/__init__.py
14 | mamba_ssm.egg-info/PKG-INFO
15 | mamba_ssm.egg-info/SOURCES.txt
16 | mamba_ssm.egg-info/dependency_links.txt
17 | mamba_ssm.egg-info/requires.txt
18 | mamba_ssm.egg-info/top_level.txt
19 | mamba_ssm/models/__init__.py
20 | mamba_ssm/models/mixer_seq_simple.py
21 | mamba_ssm/modules/__init__.py
22 | mamba_ssm/modules/mamba_simple.py
23 | mamba_ssm/ops/__init__.py
24 | mamba_ssm/ops/selective_scan_interface.py
25 | mamba_ssm/ops/triton/__init__.py
26 | mamba_ssm/ops/triton/layernorm.py
27 | mamba_ssm/ops/triton/selective_state_update.py
28 | mamba_ssm/utils/__init__.py
29 | mamba_ssm/utils/generation.py
30 | mamba_ssm/utils/hf.py
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | import numpy as np
4 | import subprocess
5 |
6 | def setup_seed(seed):
7 |     torch.manual_seed(seed)
8 |     # torch.cuda.manual_seed(seed)
9 |     torch.cuda.manual_seed_all(seed)
10 |     np.random.seed(seed)
11 |     random.seed(seed)
12 |     torch.backends.cudnn.deterministic = True
13 |
14 | def execute(cmd):
15 |     popen = subprocess.Popen(cmd, stdout=subprocess.PIPE, universal_newlines=True)
16 |     for stdout_line in iter(popen.stdout.readline, ""):
17 |         yield stdout_line
18 |     popen.stdout.close()
19 |     return_code = popen.wait()
20 |     if return_code:
21 |         raise subprocess.CalledProcessError(return_code, cmd)
22 |
23 | def center_crop(image, det_shape=[18, 160, 160]):
24 |     # To prevent overflow
25 |     image = np.pad(image, ((2,2),(20,20),(20,20)), mode='reflect')
26 |     src_shape = image.shape
27 |     shift0 = (src_shape[0] - det_shape[0]) // 2
28 |     shift1 = (src_shape[1] - det_shape[1]) // 2
29 |     shift2 = (src_shape[2] - det_shape[2]) // 2
30 |     assert shift0 >= 0 and shift1 >= 0 and shift2 >= 0, "overflow in center-crop"
31 |     image = image[shift0:shift0+det_shape[0], shift1:shift1+det_shape[1], shift2:shift2+det_shape[2]]
32 |     return image
--------------------------------------------------------------------------------
/augmentation/missing_section.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | from .augmentor import DataAugment
4 |
5 | class MissingSection(DataAugment):
6 |     """Missing-section augmentation of image
stacks. 7 | 8 | Args: 9 | num_sections (int): number of missing sections. Default: 2 10 | p (float): probability of applying the augmentation. Default: 0.5 11 | """ 12 | def __init__(self, num_sections=2, p=0.5): 13 | super(MissingSection, self).__init__(p=p) 14 | self.num_sections = num_sections 15 | self.set_params() 16 | 17 | def set_params(self): 18 | self.sample_params['add'] = [int(math.ceil(self.num_sections / 2.0)), 0, 0] 19 | 20 | def missing_section(self, data, random_state): 21 | images, labels = data['image'], data['label'] 22 | new_images = images.copy() 23 | new_labels = labels.copy() 24 | 25 | idx = random_state.choice(np.array(range(1, images.shape[0]-1)), self.num_sections, replace=False) 26 | 27 | new_images = np.delete(new_images, idx, 0) 28 | new_labels = np.delete(new_labels, idx, 0) 29 | 30 | return new_images, new_labels 31 | 32 | def __call__(self, data, random_state=np.random): 33 | new_images, new_labels = self.missing_section(data, random_state) 34 | return {'image': new_images, 'label': new_labels} 35 | -------------------------------------------------------------------------------- /utils/lmc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import elf.segmentation.multicut as mc 3 | import elf.segmentation.features as feats 4 | import elf.segmentation.watershed as ws 5 | 6 | 7 | # from .elf_local.segmentation import multicut as mc 8 | # from .elf_local.segmentation import features as feats 9 | # from .elf_local.segmentation import watershed as ws 10 | 11 | def mc_baseline(affs, fragments=None): 12 | affs = 1 - affs 13 | boundary_input = np.maximum(affs[1], affs[2]) 14 | if fragments is None: 15 | fragments = np.zeros_like(boundary_input, dtype='uint64') 16 | offset = 0 17 | for z in range(fragments.shape[0]): 18 | wsz, max_id = ws.distance_transform_watershed(boundary_input[z], threshold=.25, sigma_seeds=2.) 19 | wsz += offset 20 | offset += max_id 21 | fragments[z] = wsz 22 | rag = feats.compute_rag(fragments) 23 | offsets = [[-1, 0, 0], [0, -1, 0], [0, 0, -1]] 24 | costs = feats.compute_affinity_features(rag, affs, offsets)[:, 0] 25 | edge_sizes = feats.compute_boundary_mean_and_length(rag, boundary_input)[:, 1] 26 | costs = mc.transform_probabilities_to_costs(costs, edge_sizes=edge_sizes) 27 | node_labels = mc.multicut_kernighan_lin(rag, costs) 28 | segmentation = feats.project_node_labels_to_pixels(rag, node_labels) 29 | return segmentation 30 | -------------------------------------------------------------------------------- /mamba_local/csrc/selective_scan/static_switch.h: -------------------------------------------------------------------------------- 1 | // Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h 2 | // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h 3 | 4 | #pragma once 5 | 6 | /// @param COND - a boolean expression to switch by 7 | /// @param CONST_NAME - a name given for the constexpr bool variable. 8 | /// @param ... - code to execute for true and false 9 | /// 10 | /// Usage: 11 | /// ``` 12 | /// BOOL_SWITCH(flag, BoolConst, [&] { 13 | /// some_function(...); 14 | /// }); 15 | /// ``` 16 | #define BOOL_SWITCH(COND, CONST_NAME, ...) 
\
17 |     [&] {                                        \
18 |         if (COND) {                              \
19 |             constexpr bool CONST_NAME = true;    \
20 |             return __VA_ARGS__();                \
21 |         } else {                                 \
22 |             constexpr bool CONST_NAME = false;   \
23 |             return __VA_ARGS__();                \
24 |         }                                        \
25 |     }()
26 |
--------------------------------------------------------------------------------
/mamba_local/mamba_ssm_local/causal_conv1d_local/csrc/static_switch.h:
--------------------------------------------------------------------------------
1 | // Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
2 | // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
3 |
4 | #pragma once
5 |
6 | /// @param COND       - a boolean expression to switch by
7 | /// @param CONST_NAME - a name given for the constexpr bool variable.
8 | /// @param ...        - code to execute for true and false
9 | ///
10 | /// Usage:
11 | /// ```
12 | /// BOOL_SWITCH(flag, BoolConst, [&] {
13 | ///     some_function(...);
14 | /// });
15 | /// ```
16 | #define BOOL_SWITCH(COND, CONST_NAME, ...)           \
17 |     [&] {                                            \
18 |         if (COND) {                                  \
19 |             static constexpr bool CONST_NAME = true; \
20 |             return __VA_ARGS__();                    \
21 |         } else {                                     \
22 |             static constexpr bool CONST_NAME = false; \
23 |             return __VA_ARGS__();                    \
24 |         }                                            \
25 |     }()
26 |
--------------------------------------------------------------------------------
/mamba_local/evals/lm_harness_eval.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import transformers
4 | from transformers import AutoTokenizer
5 |
6 | from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
7 |
8 | from lm_eval.api.model import LM
9 | from lm_eval.models.huggingface import HFLM
10 | from lm_eval.api.registry import register_model
11 | from lm_eval.__main__ import cli_evaluate
12 |
13 |
14 | @register_model("mamba")
15 | class MambaEvalWrapper(HFLM):
16 |
17 |     AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
18 |
19 |     def __init__(self, pretrained="state-spaces/mamba-2.8b", max_length=2048, batch_size=None, device="cuda",
20 |                  dtype=torch.float16):
21 |         LM.__init__(self)
22 |         self._model = MambaLMHeadModel.from_pretrained(pretrained, device=device, dtype=dtype)
23 |         self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
24 |         self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
25 |         self.vocab_size = self.tokenizer.vocab_size
26 |         self._batch_size = batch_size if batch_size is not None else 64
27 |         self._max_length = max_length
28 |         self._device = torch.device(device)
29 |
30 |     @property
31 |     def batch_size(self):
32 |         return self._batch_size
33 |
34 |     def _model_generate(self, context, max_length, stop, **generation_kwargs):
35 |         raise NotImplementedError()
36 |
37 |
38 | if __name__ == "__main__":
39 |     cli_evaluate()
40 |
--------------------------------------------------------------------------------
/util_mamba/crop.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import math
8 |
9 | import torch
10 |
11 | from torchvision import transforms
12 | from torchvision.transforms import functional as F
13 |
14 |
15 | class RandomResizedCrop(transforms.RandomResizedCrop):
16 |     """
17 |     RandomResizedCrop for matching TF/TPU implementation: no for-loop is used.
18 |     This may lead to results different from torchvision's version.
19 | Following BYOL's TF code: 20 | https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206 21 | """ 22 | @staticmethod 23 | def get_params(img, scale, ratio): 24 | width, height = F._get_image_size(img) 25 | area = height * width 26 | 27 | target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() 28 | log_ratio = torch.log(torch.tensor(ratio)) 29 | aspect_ratio = torch.exp( 30 | torch.empty(1).uniform_(log_ratio[0], log_ratio[1]) 31 | ).item() 32 | 33 | w = int(round(math.sqrt(target_area * aspect_ratio))) 34 | h = int(round(math.sqrt(target_area / aspect_ratio))) 35 | 36 | w = min(w, width) 37 | h = min(h, height) 38 | 39 | i = torch.randint(0, height - h + 1, size=(1,)).item() 40 | j = torch.randint(0, width - w + 1, size=(1,)).item() 41 | 42 | return i, j, h, w -------------------------------------------------------------------------------- /src/run_parallel_gaussian.sh: -------------------------------------------------------------------------------- 1 | pip3 uninstall -y timm && 2 | pip3 install timm==0.3.2 -i https://pypi.tuna.tsinghua.edu.cn/simple && 3 | sed -i 's/from torch._six import container_abcs/import collections.abc as container_abcs/g' /usr/local/lib/python3.8/dist-packages/timm/models/layers/helpers.py && 4 | export NCCL_P2P_DISABLE=1 && 5 | echo "Starting task on GPU 0 1 2 3" 6 | CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m torch.distributed.launch --nproc_per_node=4 --master_port=57780 /h3cstore_ns/hyshi/EM_mamba_new/EM_mamba_seg/main_finetune.py --batch_size=6 --crop_size=16,160,160 \ 7 | --epochs=1000 --output_dir=/h3cstore_ns/hyshi/EM_mamba_new/result/segmambaV3_ac3_lr5_b6_16_160_160_gaussian \ 8 | --visual_dir=/h3cstore_ns/hyshi/EM_mamba_new/result/segmambaV3_ac3_lr5_b6_16_160_160_gaussian/visual \ 9 | --log_dir=/h3cstore_ns/hyshi/EM_mamba_new/result/segmambaV3_ac3_lr5_b6_16_160_160_gaussian/tensorboard_log \ 10 | --warmup_epochs=0 --blr=1e-5 & 11 | echo "Starting task on GPU 4 5 6 7" 12 | CUDA_VISIBLE_DEVICES=4,5,6,7 python3 -m torch.distributed.launch --nproc_per_node=4 --master_port=57781 /h3cstore_ns/hyshi/EM_mamba_new/EM_mamba_seg/main_finetune_small.py --batch_size=6 --crop_size=16,160,160 \ 13 | --epochs=1000 --output_dir=/h3cstore_ns/hyshi/EM_mamba_new/result/segmambaV3small_ac3_lr5_b6_16_160_160_gaussian \ 14 | --visual_dir=/h3cstore_ns/hyshi/EM_mamba_new/result/segmambaV3small_ac3_lr5_b6_16_160_160_gaussian/visual \ 15 | --log_dir=/h3cstore_ns/hyshi/EM_mamba_new/result/segmambaV3small_ac3_lr5_b6_16_160_160_gaussian/tensorboard_log \ 16 | --warmup_epochs=0 --blr=1e-5 & 17 | wait 18 | echo "All tasks completed." 
-------------------------------------------------------------------------------- /dataloader/test.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | from attrdict import AttrDict 3 | import sys 4 | sys.path.append('/braindat/lab/chenyd/code/Miccai23/SSNS-Net-main') 5 | from utils.show import show_one 6 | from data_provider_pretraining import Train 7 | import numpy as np 8 | from PIL import Image 9 | import random 10 | import time 11 | import os 12 | """""" 13 | seed = 555 14 | np.random.seed(seed) 15 | random.seed(seed) 16 | cfg_file = 'pretraining_all.yaml' 17 | with open('Miccai23/SSNS-Net-main/config/' + cfg_file, 'r') as f: 18 | cfg = AttrDict( yaml.safe_load(f) ) 19 | 20 | out_path = os.path.join('/braindat/lab/chenyd/code/Miccai23/SSNS-Net-main', 'data_temp') 21 | if not os.path.exists(out_path): 22 | os.mkdir(out_path) 23 | data = Train(cfg) 24 | t = time.time() 25 | for i in range(0, 20): 26 | t1 = time.time() 27 | tmp_data1, tmp_data2, gt = iter(data).__next__() 28 | print('single cost time: ', time.time()-t1) 29 | print('tmp_data1 shape: ', tmp_data1.shape, 'tmp_data2 shape: ', tmp_data2.shape, 'gt shape: ', gt.shape) 30 | tmp_data1 = np.squeeze(tmp_data1) 31 | tmp_data2 = np.squeeze(tmp_data2) 32 | gt = np.squeeze(gt) 33 | if cfg.MODEL.model_type == 'mala': 34 | tmp_data1 = tmp_data1[14:-14,106:-106,106:-106] 35 | tmp_data2 = tmp_data2[14:-14,106:-106,106:-106] 36 | gt = gt[14:-14,106:-106,106:-106] 37 | 38 | img_data1 = show_one(tmp_data1) 39 | img_data2 = show_one(tmp_data2) 40 | img_affs = show_one(gt) 41 | im_cat = np.concatenate([img_data1, img_data2, img_affs], axis=1) 42 | 43 | Image.fromarray(im_cat).save(os.path.join(out_path, str(i).zfill(4)+'.png')) 44 | print(time.time() - t) -------------------------------------------------------------------------------- /mamba_local/mamba_ssm_local/causal_conv1d_local/csrc/causal_conv1d.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | #pragma once 6 | 7 | //////////////////////////////////////////////////////////////////////////////////////////////////// 8 | 9 | struct ConvParamsBase { 10 | using index_t = uint32_t; 11 | 12 | int batch, dim, seqlen, width; 13 | bool silu_activation; 14 | 15 | index_t x_batch_stride; 16 | index_t x_c_stride; 17 | index_t x_l_stride; 18 | index_t weight_c_stride; 19 | index_t weight_width_stride; 20 | index_t out_batch_stride; 21 | index_t out_c_stride; 22 | index_t out_l_stride; 23 | 24 | index_t conv_state_batch_stride; 25 | index_t conv_state_c_stride; 26 | index_t conv_state_l_stride; 27 | 28 | // Common data pointers. 29 | void *__restrict__ x_ptr; 30 | void *__restrict__ weight_ptr; 31 | void *__restrict__ bias_ptr; 32 | void *__restrict__ out_ptr; 33 | 34 | void *__restrict__ conv_state_ptr; 35 | }; 36 | 37 | struct ConvParamsBwd: public ConvParamsBase { 38 | index_t dx_batch_stride; 39 | index_t dx_c_stride; 40 | index_t dx_l_stride; 41 | index_t dweight_c_stride; 42 | index_t dweight_width_stride; 43 | index_t dout_batch_stride; 44 | index_t dout_c_stride; 45 | index_t dout_l_stride; 46 | 47 | // Common data pointers. 
48 | void *__restrict__ dx_ptr; 49 | void *__restrict__ dweight_ptr; 50 | void *__restrict__ dbias_ptr; 51 | void *__restrict__ dout_ptr; 52 | }; 53 | 54 | -------------------------------------------------------------------------------- /mamba_local/mamba_ssm_local/causal_conv1d_local/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file. 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /model/monitor.py: -------------------------------------------------------------------------------- 1 | 2 | class monitor_lr(object): 3 | # adaptively change the learning rate based on the validation result 4 | def __init__(self, step_bin=3, step_wait=5, thres=0.95, step_max=100): 5 | # step_bin: how many validation needed for one point 6 | # step_wait: how many points consecutively to fail to decrease enough before changing the learning rate 7 | # thres: threshold of learning rate to be changed 8 | # step_max: if reached the max number, decrease the learning rate 9 | # 1. bin the validation results for robust statistics 10 | # 2. 
stopping criteria:
11 |         self.step_bin = step_bin
12 |         self.step_wait = step_wait
13 |         self.thres = thres
14 |         self.step_max = step_max
15 |         self.num_change = 0
16 |
17 |         self.reset()
18 |
19 |     def add(self,result):
20 |         self.val_result.append(result)
21 |         self.val_id += 1
22 |         if self.val_id % self.step_bin == 0:
23 |             self.val_stat.append(sum(self.val_result[-self.step_bin:])/float(self.step_bin))
24 |
25 |     def toChange(self):
26 |         change = False
27 |         if self.val_id>self.step_max:
28 |             change = True
29 |         elif len(self.val_stat)>self.step_wait and self.val_id % self.step_bin == 0 \
30 |             and min(self.val_result[-self.step_wait:])>min(self.val_result[:-self.step_wait])*self.thres:
31 |             change = True
32 |
33 |         if change:
34 |             self.num_change += 1
35 |             self.reset()
36 |         return change
37 |
38 |     def reset(self):
39 |         self.val_id = 0
40 |         self.val_result = []
41 |         self.val_stat = []
42 |         self.change = False
43 |
44 |
--------------------------------------------------------------------------------
/augmentation/motion_blur.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | from .augmentor import DataAugment
4 |
5 | class MotionBlur(DataAugment):
6 |     """Motion blur data augmentation of image stacks.
7 |
8 |     Args:
9 |         sections (int): number of sections along z dimension to apply motion blur. Default: 2
10 |         kernel_size (int): kernel size for motion blur. Default: 11
11 |         p (float): probability of applying the augmentation. Default: 0.5
12 |     """
13 |     def __init__(self, sections=2, kernel_size=11, p=0.5):
14 |         super(MotionBlur, self).__init__(p=p)
15 |         self.size = kernel_size
16 |         self.sections = sections
17 |         self.set_params()
18 |
19 |     def set_params(self):
20 |         # No change in sample size
21 |         pass
22 |
23 |     def motion_blur(self, data, random_state):
24 |         images = data['image'].copy()
25 |         labels = data['label'].copy()
26 |
27 |         # generating the kernel
28 |         kernel_motion_blur = np.zeros((self.size, self.size))
29 |         if random_state.rand() > 0.5: # horizontal kernel
30 |             kernel_motion_blur[int((self.size-1)/2), :] = np.ones(self.size)
31 |         else: # vertical kernel
32 |             kernel_motion_blur[:, int((self.size-1)/2)] = np.ones(self.size)
33 |         kernel_motion_blur = kernel_motion_blur / self.size
34 |
35 |         # draw distinct sections with the provided random_state for reproducibility
36 |         k = min(self.sections, images.shape[0])
37 |         selected_idx = random_state.choice(images.shape[0], k, replace=False)
38 |
39 |         for idx in selected_idx:
40 |             # applying the kernel to the input image
41 |             images[idx] = cv2.filter2D(images[idx], -1, kernel_motion_blur)
42 |
43 |         return images, labels
44 |
45 |     def __call__(self, data, random_state=np.random):
46 |         new_images, new_labels = self.motion_blur(data, random_state)
47 |         return {'image': new_images, 'label': new_labels}
--------------------------------------------------------------------------------
/augmentation/rotation.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | from .augmentor import DataAugment
4 |
5 | class Rotate(DataAugment):
6 |     """
7 |     Continuous rotation of the `xy`-plane.
8 |
9 |     The sample size for `x`- and `y`-axes should be at least :math:`\sqrt{2}` times larger
10 |     than the input size to make sure there is no non-valid region after center-crop.
11 |
12 |     Args:
13 |         p (float): probability of applying the augmentation. Default: 0.5
14 |     """
15 |     def __init__(self, p=0.5):
16 |         super(Rotate, self).__init__(p=p)
17 |         self.image_interpolation = cv2.INTER_LINEAR
18 |         self.label_interpolation = cv2.INTER_NEAREST
19 |         self.border_mode = cv2.BORDER_CONSTANT
20 |         self.set_params()
21 |
22 |     def set_params(self):
23 |         # sqrt(2)
24 |         self.sample_params['ratio'] = [1.0, 1.42, 1.42]
25 |
26 |     def rotate(self, imgs, M, interpolation):
27 |         height, width = imgs.shape[-2:]
28 |         transformedimgs = np.copy(imgs)
29 |         for z in range(transformedimgs.shape[-3]):
30 |             img = transformedimgs[z, :, :]
31 |             # warpAffine expects dsize as (width, height)
32 |             dst = cv2.warpAffine(img, M, (width, height), flags=interpolation, borderMode=self.border_mode)
33 |             transformedimgs[z, :, :] = dst
34 |
35 |         return transformedimgs
36 |
37 |     def __call__(self, data, random_state=np.random):
38 |
39 |         if 'label' in data and data['label'] is not None:
40 |             image, label = data['image'], data['label']
41 |         else:
42 |             image, label = data['image'], None
43 |
44 |         height, width = image.shape[-2:]
45 |         M = cv2.getRotationMatrix2D((width/2, height/2), random_state.rand()*360.0, 1)
46 |
47 |         output = {}
48 |         output['image'] = self.rotate(image, M, self.image_interpolation)
49 |         if label is not None:
50 |             output['label'] = self.rotate(label, M, self.label_interpolation)
51 |
52 |         return output
53 |
--------------------------------------------------------------------------------
/util_mamba/lars.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # LARS optimizer, implementation from MoCo v3:
8 | # https://github.com/facebookresearch/moco-v3
9 | # --------------------------------------------------------
10 |
11 | import torch
12 |
13 |
14 | class LARS(torch.optim.Optimizer):
15 |     """
16 |     LARS optimizer, no rate scaling or weight decay for parameters <= 1D.
17 |     """
18 |     def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001):
19 |         defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient)
20 |         super().__init__(params, defaults)
21 |
22 |     @torch.no_grad()
23 |     def step(self):
24 |         for g in self.param_groups:
25 |             for p in g['params']:
26 |                 dp = p.grad
27 |
28 |                 if dp is None:
29 |                     continue
30 |
31 |                 if p.ndim > 1: # if not normalization gamma/beta or bias
32 |                     dp = dp.add(p, alpha=g['weight_decay'])
33 |                     param_norm = torch.norm(p)
34 |                     update_norm = torch.norm(dp)
35 |                     one = torch.ones_like(param_norm)
36 |                     q = torch.where(param_norm > 0.,
37 |                                     torch.where(update_norm > 0,
38 |                                     (g['trust_coefficient'] * param_norm / update_norm), one),
39 |                                     one)
40 |                     dp = dp.mul(q)
41 |
42 |                 param_state = self.state[p]
43 |                 if 'mu' not in param_state:
44 |                     param_state['mu'] = torch.zeros_like(p)
45 |                 mu = param_state['mu']
46 |                 mu.mul_(g['momentum']).add_(dp)
47 |                 p.add_(mu, alpha=-g['lr'])
--------------------------------------------------------------------------------
/mamba_local/mamba_ssm_local/causal_conv1d_local/csrc/causal_conv1d_common.h:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 |  * Copyright (c) 2023, Tri Dao.
3 |  ******************************************************************************/
4 |
5 | #pragma once
6 |
7 | #include <cuda_bf16.h>
8 | #include <cuda_fp16.h>
9 |
10 | ////////////////////////////////////////////////////////////////////////////////////////////////////
11 |
12 | template<int BYTES> struct BytesToType {};
13 |
14 | template<> struct BytesToType<16> {
15 |     using Type = uint4;
16 |     static_assert(sizeof(Type) == 16);
17 | };
18 |
19 | template<> struct BytesToType<8> {
20 |     using Type = uint64_t;
21 |     static_assert(sizeof(Type) == 8);
22 | };
23 |
24 | template<> struct BytesToType<4> {
25 |     using Type = uint32_t;
26 |     static_assert(sizeof(Type) == 4);
27 | };
28 |
29 | template<> struct BytesToType<2> {
30 |     using Type = uint16_t;
31 |     static_assert(sizeof(Type) == 2);
32 | };
33 |
34 | template<> struct BytesToType<1> {
35 |     using Type = uint8_t;
36 |     static_assert(sizeof(Type) == 1);
37 | };
38 |
39 | ////////////////////////////////////////////////////////////////////////////////////////////////////
40 |
41 | template<typename T>
42 | struct SumOp {
43 |     __device__ inline T operator()(T const & x, T const & y) { return x + y; }
44 | };
45 |
46 | template<int THREADS>
47 | struct Allreduce {
48 |     static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
49 |     template<typename T, typename Operator>
50 |     static __device__ inline T run(T x, Operator &op) {
51 |         constexpr int OFFSET = THREADS / 2;
52 |         x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
53 |         return Allreduce<OFFSET>::run(x, op);
54 |     }
55 | };
56 |
57 | template<>
58 | struct Allreduce<2> {
59 |     template<typename T, typename Operator>
60 |     static __device__ inline T run(T x, Operator &op) {
61 |         x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
62 |         return x;
63 |     }
64 | };
65 |
--------------------------------------------------------------------------------
/augmentation/augmentor.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | class DataAugment(object):
4 |     """
5 |     DataAugment interface.
6 |
7 |     A data transform needs to conduct the following steps:
8 |
9 |     1. Set :attr:`sample_params` at initialization to compute required sample size.
10 |     2. Randomly generate augmentation parameters for the current transform.
11 |     3. Apply the transform to a pair of images and corresponding labels.
12 |
13 |     All the real data augmentations should be a subclass of this class.
14 |     """
15 |     def __init__(self, p=0.5):
16 |         assert p >= 0.0 and p <= 1.0
17 |         self.p = p
18 |         self.sample_params = {
19 |             'ratio': np.array([1.0, 1.0, 1.0]),
20 |             'add': np.array([0, 0, 0])}
21 |
22 |     def set_params(self):
23 |         """
24 |         Calculate the appropriate sample size with data augmentation.
25 |
26 |         Some data augmentations (wrap, misalignment, etc.) require a larger sample
27 |         size than the original, depending on the augmentation parameters that are
28 |         randomly chosen. This function takes the data augmentation
29 |         parameters and returns an updated data sampling size accordingly.
30 |         """
31 |         raise NotImplementedError
32 |
33 |     def __call__(self, data, random_state=None):
34 |         """
35 |         Apply the data augmentation.
36 |
37 |         For a multi-CPU dataloader, one may need to use a unique index to generate
38 |         the random seed (:attr:`random_state`), otherwise different workers may generate
39 |         the same pseudo-random number for augmentation and sampling.
40 |         """
41 |         raise NotImplementedError
42 |
43 |     def apply_last(self, data):
44 |         """
45 |         Apply the last data augmentation generated by __call__(). This function can be used
46 |         for model training under the semi-supervised setting.
If this function is called 47 | before any __call__(), raise an error. 48 | """ 49 | raise NotImplementedError 50 | -------------------------------------------------------------------------------- /data/data_misc.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | from typing import Optional, Tuple, List, Union 3 | import numpy as np 4 | 5 | def get_padsize(pad_size: Union[int, List[int]], ndim: int=3) -> Tuple[int]: 6 | """Convert the padding size for 3D input volumes into numpy.pad compatible format. 7 | 8 | Args: 9 | pad_size (int, List[int]): number of values padded to the edges of each axis. 10 | ndim (int): the dimension of the array to be padded. Default: 3 11 | """ 12 | if type(pad_size) == int: 13 | pad_size = [tuple([pad_size, pad_size]) for _ in range(ndim)] 14 | return tuple(pad_size) 15 | 16 | assert len(pad_size) in [1, ndim, 2*ndim] 17 | if len(pad_size) == 1: 18 | pad_size = pad_size[0] 19 | pad_size = [tuple([pad_size, pad_size]) for _ in range(ndim)] 20 | return tuple(pad_size) 21 | elif len(pad_size) == ndim: 22 | return tuple([tuple([x, x]) for x in pad_size]) 23 | else: 24 | return tuple( 25 | [tuple([pad_size[2*i], pad_size[2*i+1]]) 26 | for i in range(len(pad_size) // 2)]) 27 | 28 | def array_unpad(data: np.ndarray, 29 | pad_size: Tuple[int]) -> np.ndarray: 30 | """Unpad a given numpy.ndarray based on the given padding size. 31 | 32 | Args: 33 | data (numpy.ndarray): the input volume to unpad. 34 | pad_size (tuple): number of values removed from the edges of each axis. 35 | Should be in the format of ((before_1, after_1), ... (before_N, after_N)) 36 | representing the unique pad widths for each axis. 37 | """ 38 | diff = data.ndim - len(pad_size) 39 | if diff > 0: 40 | extra = [(0, 0) for _ in range(diff)] 41 | pad_size = tuple(extra + list(pad_size)) 42 | 43 | assert len(pad_size) == data.ndim 44 | index = tuple([ 45 | slice(pad_size[i][0], data.shape[i]-pad_size[i][1]) 46 | for i in range(data.ndim) 47 | ]) 48 | return data[index] 49 | -------------------------------------------------------------------------------- /util_mamba/datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # -------------------------------------------------------- 10 | 11 | import os 12 | import PIL 13 | 14 | from torchvision import datasets, transforms 15 | 16 | from timm.data import create_transform 17 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 18 | 19 | 20 | def build_dataset(is_train, args): 21 | transform = build_transform(is_train, args) 22 | 23 | root = os.path.join(args.data_path, 'train' if is_train else 'val') 24 | dataset = datasets.ImageFolder(root, transform=transform) 25 | 26 | print(dataset) 27 | 28 | return dataset 29 | 30 | 31 | def build_transform(is_train, args): 32 | mean = IMAGENET_DEFAULT_MEAN 33 | std = IMAGENET_DEFAULT_STD 34 | # train transform 35 | if is_train: 36 | # this should always dispatch to transforms_imagenet_train 37 | transform = create_transform( 38 | input_size=args.input_size, 39 | is_training=True, 40 | color_jitter=args.color_jitter, 41 | auto_augment=args.aa, 42 | interpolation='bicubic', 43 | re_prob=args.reprob, 44 | re_mode=args.remode, 45 | re_count=args.recount, 46 | mean=mean, 47 | std=std, 48 | ) 49 | return transform 50 | 51 | # eval transform 52 | t = [] 53 | if args.input_size <= 224: 54 | crop_pct = 224 / 256 55 | else: 56 | crop_pct = 1.0 57 | size = int(args.input_size / crop_pct) 58 | t.append( 59 | transforms.Resize(size, interpolation=PIL.Image.BICUBIC), # to maintain same ratio w.r.t. 224 images 60 | ) 61 | t.append(transforms.CenterCrop(args.input_size)) 62 | 63 | t.append(transforms.ToTensor()) 64 | t.append(transforms.Normalize(mean, std)) 65 | return transforms.Compose(t) 66 | -------------------------------------------------------------------------------- /augmentation/cutnoise.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from .augmentor import DataAugment 3 | 4 | class CutNoise(DataAugment): 5 | """3D CutNoise data augmentation. 6 | 7 | Randomly add noise to a cuboid region in the volume to force the model 8 | to learn denoising when making predictions. 9 | 10 | Args: 11 | length_ratio (float): the ratio of the cuboid length compared with volume length. 12 | mode (string): the distribution of the noise pattern. Default: ``'uniform'``. 13 | scale (float): scale of the random noise. Default: 0.2. 14 | p (float): probability of applying the augmentation. 
15 | """ 16 | 17 | def __init__(self, 18 | length_ratio=0.25, 19 | mode='uniform', 20 | scale=0.2, 21 | p=0.5): 22 | super(CutNoise, self).__init__(p=p) 23 | self.length_ratio = length_ratio 24 | self.mode = mode 25 | self.scale = scale 26 | 27 | def set_params(self): 28 | # No change in sample size 29 | pass 30 | 31 | def cut_noise(self, data, random_state): 32 | images = data['image'].copy() 33 | labels = data['label'].copy() 34 | 35 | zl, zh = self.random_region(images.shape[0], random_state) 36 | yl, yh = self.random_region(images.shape[1], random_state) 37 | xl, xh = self.random_region(images.shape[2], random_state) 38 | 39 | temp = images[zl:zh, yl:yh, xl:xh].copy() 40 | noise = random_state.uniform(-self.scale, self.scale, temp.shape) 41 | temp = temp + noise 42 | temp = np.clip(temp, 0, 1) 43 | 44 | images[zl:zh, yl:yh, xl:xh] = temp 45 | return images, labels 46 | 47 | def random_region(self, vol_len, random_state): 48 | cuboid_len = int(self.length_ratio * vol_len) 49 | low = random_state.randint(0, vol_len-cuboid_len) 50 | high = low + cuboid_len 51 | return low, high 52 | 53 | def __call__(self, data, random_state=np.random): 54 | new_images, new_labels = self.cut_noise(data, random_state) 55 | return {'image': new_images, 'label': new_labels} -------------------------------------------------------------------------------- /mamba_local/tests/ops/triton/test_selective_state_update.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2023, Tri Dao. 2 | 3 | import math 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | import pytest 8 | 9 | from einops import rearrange 10 | 11 | from mamba_ssm.ops.triton.selective_state_update import selective_state_update, selective_state_update_ref 12 | 13 | 14 | @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) 15 | # @pytest.mark.parametrize('itype', [torch.float16]) 16 | @pytest.mark.parametrize("has_z", [False, True]) 17 | # @pytest.mark.parametrize('has_z', [True]) 18 | @pytest.mark.parametrize("dstate", [16, 32, 64]) 19 | # @pytest.mark.parametrize("dstate", [16]) 20 | @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) 21 | # @pytest.mark.parametrize("dim", [2048]) 22 | def test_causal_conv1d_update(dim, dstate, has_z, itype): 23 | device = "cuda" 24 | rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2) 25 | if itype == torch.bfloat16: 26 | rtol, atol = 1e-2, 5e-2 27 | # set seed 28 | torch.random.manual_seed(0) 29 | batch_size = 2 30 | state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device) 31 | x = torch.randn(batch_size, dim, device=device, dtype=itype) 32 | dt = torch.randn(batch_size, dim, device=device, dtype=itype) 33 | dt_bias = torch.rand(dim, device=device) - 4.0 34 | A = -torch.rand(dim, dstate, device=device) - 1.0 35 | B = torch.randn(batch_size, dstate, device=device) 36 | C = torch.randn(batch_size, dstate, device=device) 37 | D = torch.randn(dim, device=device) 38 | if has_z: 39 | z = torch.randn_like(x) 40 | else: 41 | z = None 42 | state_ref = state.detach().clone() 43 | out = selective_state_update(state, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True) 44 | out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True) 45 | 46 | print(f"Output max diff: {(out - out_ref).abs().max().item()}") 47 | print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") 48 | assert torch.allclose(state, state_ref, rtol=rtol, atol=atol) 49 
| assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) 50 | -------------------------------------------------------------------------------- /config/seg_3d_cremiA_data100.yaml: -------------------------------------------------------------------------------- 1 | NAME: 'seg_3d_cremiA_data100' 2 | 3 | MODEL: 4 | model_type: 'superhuman' # 'mala' or 'superhuman' 5 | input_nc: 1 6 | output_nc: 3 7 | if_sigmoid: True 8 | # for 'mala': 9 | init_mode_mala: 'kaiming' 10 | # for 'superhuman': 11 | if_skip: 'False' 12 | filters: 13 | - 28 14 | - 36 15 | - 48 16 | - 64 17 | - 80 18 | upsample_mode: 'bilinear' # 'bilinear', 'nearest', 'transpose', 'transposeS' 19 | decode_ratio: 1 20 | merge_mode: 'add' # 'add', 'cat' 21 | pad_mode: 'zero' # 'zero', 'replicate' 22 | bn_mode: 'async' # 'sync', 'async' 23 | relu_mode: 'elu' # 'elu', 'relu', 'leaky' 24 | init_mode: 'kaiming_normal' # 'kaiming_normal', 'kaiming_uniform', 'xavier_normal', 'xavier_uniform' 25 | 26 | pre_train: False 27 | trained_gpus: 1 28 | pre_train_mode: 'finetune' # 'finetune', 'extract_feature' 29 | trained_model_name: '2020-12-21--08-27-49_ssl_suhu_noskip_mse_lr0001_snemi3d_ulb5' 30 | trained_model_id: 400000 31 | 32 | TRAIN: 33 | resume: False 34 | if_valid: True 35 | if_seg: True 36 | cache_path: '../caches/' 37 | save_path: '../models/' 38 | pad: 0 39 | loss_func: 'WeightedBCELoss' # 'WeightedBCELoss', 'BCELoss' 40 | 41 | opt_type: 'adam' 42 | total_iters: 200000 43 | warmup_iters: 0 44 | base_lr: 0.0001 45 | end_lr: 0.0001 46 | display_freq: 100 47 | valid_freq: 1000 48 | save_freq: 1000 49 | decay_iters: 100000 50 | weight_decay: ~ 51 | power: 1.5 52 | 53 | batch_size: 4 54 | num_workers: 4 55 | if_cuda: True 56 | 57 | random_seed: 555 # -1 is none 58 | 59 | DATA: 60 | dataset_name: 'cremiA' # 'snemi3d-ac3', 'cremi-A', 'cremi' 61 | train_split: 100 62 | test_split: 25 63 | data_folder: '../data' 64 | padding: 20 65 | shift_channels: ~ 66 | if_dilate: True 67 | if_scale_aug: False 68 | if_filp_aug: True 69 | if_rotation_aug: True 70 | if_intensity_aug: True 71 | if_elastic_aug: True 72 | 73 | TEST: 74 | pad: 0 75 | model_name: '2021-04-24--13-18-01_seg_3d_ac4_data80' 76 | -------------------------------------------------------------------------------- /config/seg_3d_cremiB_data100.yaml: -------------------------------------------------------------------------------- 1 | NAME: 'seg_3d_cremiB_data100' 2 | 3 | MODEL: 4 | model_type: 'superhuman' # 'mala' or 'superhuman' 5 | input_nc: 1 6 | output_nc: 3 7 | if_sigmoid: True 8 | # for 'mala': 9 | init_mode_mala: 'kaiming' 10 | # for 'superhuman': 11 | if_skip: 'False' 12 | filters: 13 | - 28 14 | - 36 15 | - 48 16 | - 64 17 | - 80 18 | upsample_mode: 'bilinear' # 'bilinear', 'nearest', 'transpose', 'transposeS' 19 | decode_ratio: 1 20 | merge_mode: 'add' # 'add', 'cat' 21 | pad_mode: 'zero' # 'zero', 'replicate' 22 | bn_mode: 'async' # 'sync', 'async' 23 | relu_mode: 'elu' # 'elu', 'relu', 'leaky' 24 | init_mode: 'kaiming_normal' # 'kaiming_normal', 'kaiming_uniform', 'xavier_normal', 'xavier_uniform' 25 | 26 | pre_train: False 27 | trained_gpus: 1 28 | pre_train_mode: 'finetune' # 'finetune', 'extract_feature' 29 | trained_model_name: '2020-12-21--08-27-49_ssl_suhu_noskip_mse_lr0001_snemi3d_ulb5' 30 | trained_model_id: 400000 31 | 32 | TRAIN: 33 | resume: False 34 | if_valid: True 35 | if_seg: True 36 | cache_path: '../caches/' 37 | save_path: '../models/' 38 | pad: 0 39 | loss_func: 'WeightedBCELoss' # 'WeightedBCELoss', 'BCELoss' 40 | 41 | opt_type: 'adam' 42 | total_iters: 
200000 43 | warmup_iters: 0 44 | base_lr: 0.0001 45 | end_lr: 0.0001 46 | display_freq: 100 47 | valid_freq: 1000 48 | save_freq: 1000 49 | decay_iters: 100000 50 | weight_decay: ~ 51 | power: 1.5 52 | 53 | batch_size: 4 54 | num_workers: 4 55 | if_cuda: True 56 | 57 | random_seed: 555 # -1 is none 58 | 59 | DATA: 60 | dataset_name: 'cremiB' # 'snemi3d-ac3', 'cremi-A', 'cremi' 61 | train_split: 100 62 | test_split: 25 63 | data_folder: '../data' 64 | padding: 20 65 | shift_channels: ~ 66 | if_dilate: True 67 | if_scale_aug: False 68 | if_filp_aug: True 69 | if_rotation_aug: True 70 | if_intensity_aug: True 71 | if_elastic_aug: True 72 | 73 | TEST: 74 | pad: 0 75 | model_name: '2021-04-24--13-18-01_seg_3d_ac4_data80' 76 | -------------------------------------------------------------------------------- /config/seg_3d_wafer26_data100.yaml: -------------------------------------------------------------------------------- 1 | NAME: 'seg_3d_wafer26_data100' 2 | 3 | MODEL: 4 | model_type: 'superhuman' # 'mala' or 'superhuman' 5 | input_nc: 1 6 | output_nc: 3 7 | if_sigmoid: True 8 | # for 'mala': 9 | init_mode_mala: 'kaiming' 10 | # for 'superhuman': 11 | if_skip: 'False' 12 | filters: 13 | - 28 14 | - 36 15 | - 48 16 | - 64 17 | - 80 18 | upsample_mode: 'bilinear' # 'bilinear', 'nearest', 'transpose', 'transposeS' 19 | decode_ratio: 1 20 | merge_mode: 'add' # 'add', 'cat' 21 | pad_mode: 'zero' # 'zero', 'replicate' 22 | bn_mode: 'async' # 'sync', 'async' 23 | relu_mode: 'elu' # 'elu', 'relu', 'leaky' 24 | init_mode: 'kaiming_normal' # 'kaiming_normal', 'kaiming_uniform', 'xavier_normal', 'xavier_uniform' 25 | 26 | pre_train: False 27 | trained_gpus: 1 28 | pre_train_mode: 'finetune' # 'finetune', 'extract_feature' 29 | trained_model_name: '2020-12-21--08-27-49_ssl_suhu_noskip_mse_lr0001_snemi3d_ulb5' 30 | trained_model_id: 400000 31 | 32 | TRAIN: 33 | resume: False 34 | if_valid: True 35 | if_seg: True 36 | cache_path: '../caches/' 37 | save_path: '../models/' 38 | pad: 0 39 | loss_func: 'WeightedBCELoss' # 'WeightedBCELoss', 'BCELoss' 40 | 41 | opt_type: 'adam' 42 | total_iters: 200000 43 | warmup_iters: 0 44 | base_lr: 0.0001 45 | end_lr: 0.0001 46 | display_freq: 100 47 | valid_freq: 1000 48 | save_freq: 1000 49 | decay_iters: 100000 50 | weight_decay: ~ 51 | power: 1.5 52 | 53 | batch_size: 4 54 | num_workers: 4 55 | if_cuda: True 56 | 57 | random_seed: 555 # -1 is none 58 | 59 | DATA: 60 | dataset_name: 'wafer' # 'snemi3d-ac3', 'cremi-A', 'cremi' 61 | train_split: 100 62 | test_split: 125 63 | data_folder: '/h3cstore_ns/EM_data' 64 | padding: 20 65 | shift_channels: ~ 66 | if_dilate: True 67 | if_scale_aug: False 68 | if_filp_aug: True 69 | if_rotation_aug: True 70 | if_intensity_aug: True 71 | if_elastic_aug: True 72 | 73 | TEST: 74 | pad: 0 75 | model_name: '2021-04-24--13-18-01_seg_3d_ac4_data80' 76 | -------------------------------------------------------------------------------- /config/seg_3d_cremiC_data100.yaml: -------------------------------------------------------------------------------- 1 | NAME: 'seg_3d_cremiC_data100' 2 | 3 | MODEL: 4 | model_type: 'segmamba' # 'mala' or 'superhuman' 5 | input_nc: 1 6 | output_nc: 3 7 | if_sigmoid: True 8 | train_model_path: False 9 | # for 'mala': 10 | init_mode_mala: 'kaiming' 11 | # for 'superhuman': 12 | if_skip: 'False' 13 | filters: 14 | - 28 15 | - 36 16 | - 48 17 | - 64 18 | - 80 19 | upsample_mode: 'bilinear' # 'bilinear', 'nearest', 'transpose', 'transposeS' 20 | decode_ratio: 1 21 | merge_mode: 'add' # 'add', 'cat' 22 | pad_mode: 
'zero' # 'zero', 'replicate' 23 | bn_mode: 'async' # 'sync', 'async' 24 | relu_mode: 'elu' # 'elu', 'relu', 'leaky' 25 | init_mode: 'kaiming_normal' # 'kaiming_normal', 'kaiming_uniform', 'xavier_normal', 'xavier_uniform' 26 | 27 | pre_train: False 28 | trained_gpus: 1 29 | pre_train_mode: 'finetune' # 'finetune', 'extract_feature' 30 | trained_model_name: '2020-12-21--08-27-49_ssl_suhu_noskip_mse_lr0001_snemi3d_ulb5' 31 | trained_model_id: 400000 32 | 33 | TRAIN: 34 | resume: False 35 | if_valid: True 36 | if_seg: True 37 | cache_path: '/h3cstore_ns/EM_seg/CREMIC_segmamba_exp/caches/' 38 | save_path: '/h3cstore_ns/EM_seg/CREMIC_segmamba_exp/models/' 39 | pad: 0 40 | loss_func: 'WeightedMSELoss' # 'WeightedBCELoss', 'BCELoss' 41 | 42 | opt_type: 'adam' 43 | total_iters: 200000 44 | warmup_iters: 0 45 | base_lr: 0.0001 46 | end_lr: 0.0001 47 | display_freq: 100 48 | valid_freq: 1000 49 | save_freq: 1000 50 | decay_iters: 100000 51 | weight_decay: ~ 52 | power: 1.5 53 | 54 | batch_size: 8 55 | num_workers: 4 56 | if_cuda: True 57 | 58 | random_seed: 555 # -1 is none 59 | 60 | DATA: 61 | dataset_name: 'cremiC' # 'snemi3d-ac3', 'cremi-A', 'cremi' 62 | train_split: 100 63 | test_split: 25 64 | data_folder: '/h3cstore_ns/Backbones/data' 65 | padding: 20 66 | shift_channels: ~ 67 | if_dilate: True 68 | if_scale_aug: False 69 | if_filp_aug: True 70 | if_rotation_aug: True 71 | if_intensity_aug: True 72 | if_elastic_aug: True 73 | 74 | TEST: 75 | pad: 0 76 | model_name: '2021-04-24--13-18-01_seg_3d_ac4_data80' 77 | -------------------------------------------------------------------------------- /model/squeeze_excite.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from em_net.libs.sync import SynchronizedBatchNorm1d, SynchronizedBatchNorm3d 6 | 7 | # -- squeeze-and-excitation layer -- 8 | class SELayer(nn.Module): 9 | def __init__(self, channel, reduction=4): 10 | super(SELayer, self).__init__() 11 | self.avg_pool = nn.AdaptiveAvgPool3d(1) 12 | self.fc = nn.Sequential( 13 | nn.Linear(channel, channel // reduction), 14 | SynchronizedBatchNorm1d(channel // reduction), 15 | nn.ELU(inplace=True), 16 | nn.Linear(channel // reduction, channel), 17 | SynchronizedBatchNorm1d(channel), 18 | nn.Sigmoid()) 19 | 20 | def forward(self, x): 21 | b, c, _, _, _ = x.size() 22 | y = self.avg_pool(x).view(b, c) 23 | y = self.fc(y).view(b, c, 1, 1, 1) 24 | return x * y 25 | 26 | class SELayerCS(nn.Module): 27 | # Squeeze-and-excitation layer (channel & spatial) 28 | def __init__(self, channel, reduction=4): 29 | super(SELayerCS, self).__init__() 30 | self.avg_pool = nn.AdaptiveAvgPool3d(1) 31 | self.fc = nn.Sequential( 32 | nn.Linear(channel, channel // reduction), 33 | SynchronizedBatchNorm1d(channel // reduction), 34 | nn.ELU(inplace=True), 35 | nn.Linear(channel // reduction, channel), 36 | SynchronizedBatchNorm1d(channel), 37 | nn.Sigmoid()) 38 | 39 | self.sc = nn.Sequential( 40 | nn.Conv3d(channel, 1, kernel_size=(1, 1, 1)), 41 | SynchronizedBatchNorm3d(1), 42 | nn.ELU(inplace=True), 43 | nn.MaxPool3d(kernel_size=(1, 8, 8), stride=(1, 8, 8)), 44 | conv3d_bn_elu(1, 1, kernel_size=(3, 3, 3), padding=(1, 1, 1)), 45 | nn.Upsample(scale_factor=(1, 8, 8), mode='trilinear', align_corners=False), 46 | nn.Conv3d(1, channel, kernel_size=(1, 1, 1)), 47 | SynchronizedBatchNorm3d(channel), 48 | nn.Sigmoid()) 49 | 50 | def forward(self, x): 51 | b, c, _, _, _ = x.size() 52 | y = 
self.avg_pool(x).view(b, c) 53 | y = self.fc(y).view(b, c, 1, 1, 1) 54 | z = self.sc(x) 55 | return (x * y) + (x * z) 56 | 57 | -------------------------------------------------------------------------------- /config/seg_3d_ac4_data80.yaml: -------------------------------------------------------------------------------- 1 | NAME: 'seg_3d_ac4_data80' 2 | 3 | MODEL: 4 | model_type: 'superhuman' # 'mala' or 'superhuman' 5 | input_nc: 1 6 | output_nc: 3 7 | if_sigmoid: True 8 | # for 'mala': 9 | init_mode_mala: 'kaiming' 10 | # for 'superhuman': 11 | if_skip: 'False' 12 | filters: 13 | - 28 14 | - 36 15 | - 48 16 | - 64 17 | - 80 18 | upsample_mode: 'bilinear' # 'bilinear', 'nearest', 'transpose', 'transposeS' 19 | decode_ratio: 1 20 | merge_mode: 'add' # 'add', 'cat' 21 | pad_mode: 'zero' # 'zero', 'replicate' 22 | bn_mode: 'async' # 'sync', 'async' 23 | relu_mode: 'elu' # 'elu', 'relu', 'leaky' 24 | init_mode: 'kaiming_normal' # 'kaiming_normal', 'kaiming_uniform', 'xavier_normal', 'xavier_uniform' 25 | 26 | pre_train: True 27 | train_model_path: /h3cstore_ns/EM_seg_models/models/2023-10-02--15-36-33_seg_3d_ac4_data80/model-200000.ckpt 28 | trained_gpus: 1 29 | load_encoder: False 30 | pre_train_mode: 'finetune' # 'finetune', 'extract_feature' 31 | # trained_model_name: '2020-12-21--08-27-49_ssl_suhu_noskip_mse_lr0001_snemi3d_ulb5' 32 | # trained_model_id: 400000 33 | 34 | TRAIN: 35 | resume: False 36 | if_valid: True 37 | if_seg: True 38 | cache_path: '../caches/' 39 | save_path: '../models/' 40 | pad: 0 41 | loss_func: 'WeightedBCELoss' # 'WeightedBCELoss', 'BCELoss' 42 | 43 | opt_type: 'adam' 44 | total_iters: 200000 45 | warmup_iters: 0 46 | base_lr: 0.0001 47 | end_lr: 0.0001 48 | display_freq: 100 49 | valid_freq: 1000 50 | save_freq: 1000 51 | decay_iters: 100000 52 | weight_decay: ~ 53 | power: 1.5 54 | 55 | batch_size: 2 56 | num_workers: 4 57 | if_cuda: True 58 | 59 | random_seed: 555 # -1 is none 60 | 61 | DATA: 62 | dataset_name: 'ac4' # 'snemi3d-ac3', 'cremi-A', 'cremi' 63 | train_split: 80 64 | test_split: 20 65 | data_folder: '/h3cstore_ns/hyshi/EM_seg/data' 66 | padding: 20 67 | shift_channels: ~ 68 | if_dilate: True 69 | if_scale_aug: False 70 | if_filp_aug: True 71 | if_rotation_aug: True 72 | if_intensity_aug: True 73 | if_elastic_aug: True 74 | 75 | TEST: 76 | pad: 0 77 | model_name: '2021-04-24--13-18-01_seg_3d_ac4_data80' 78 | window_size: [4,32,32] 79 | -------------------------------------------------------------------------------- /config/seg_3d_wafer4_data100.yaml: -------------------------------------------------------------------------------- 1 | NAME: 'seg_3d_wafer4_data100' 2 | 3 | MODEL: 4 | model_type: 'superhuman' # 'mala' or 'superhuman' 5 | input_nc: 1 6 | output_nc: 3 7 | if_sigmoid: True 8 | # for 'mala': 9 | init_mode_mala: 'kaiming' 10 | # for 'superhuman': 11 | if_skip: 'False' 12 | filters: 13 | - 28 14 | - 36 15 | - 48 16 | - 64 17 | - 80 18 | upsample_mode: 'bilinear' # 'bilinear', 'nearest', 'transpose', 'transposeS' 19 | decode_ratio: 1 20 | merge_mode: 'add' # 'add', 'cat' 21 | pad_mode: 'zero' # 'zero', 'replicate' 22 | bn_mode: 'async' # 'sync', 'async' 23 | relu_mode: 'elu' # 'elu', 'relu', 'leaky' 24 | init_mode: 'kaiming_normal' # 'kaiming_normal', 'kaiming_uniform', 'xavier_normal', 'xavier_uniform' 25 | 26 | pre_train: False 27 | trained_gpus: 1 28 | pre_train_mode: 'finetune' # 'finetune', 'extract_feature' 29 | trained_model_name: '2020-12-21--08-27-49_ssl_suhu_noskip_mse_lr0001_snemi3d_ulb5' 30 | trained_model_id: 400000 31 | 
train_model_path: '' # False # '/h3cstore_ns/Backbones/models/2023-06-01--09-17-08_superhuman_3d_ac4_data80_long234927_aff10/model-129000.ckpt'
32 | 
33 | TRAIN:
34 |   resume: False
35 |   if_valid: True
36 |   if_seg: True
37 |   cache_path: '/h3cstore_ns/EM_seg_wafer_total_superhuman/caches/'
38 |   save_path: '/h3cstore_ns/EM_seg_wafer_total_superhuman/models/'
39 |   pad: 0
40 |   loss_func: 'WeightedBCELoss' # 'WeightedBCELoss', 'BCELoss'
41 | 
42 |   opt_type: 'adam'
43 |   total_iters: 200000
44 |   warmup_iters: 0
45 |   base_lr: 0.0001
46 |   end_lr: 0.0001
47 |   display_freq: 100
48 |   valid_freq: 2000
49 |   save_freq: 2000
50 |   decay_iters: 100000
51 |   weight_decay: ~
52 |   power: 1.5
53 | 
54 |   batch_size: 8
55 |   num_workers: 4
56 |   if_cuda: True
57 | 
58 |   random_seed: 555 # -1 is none
59 | 
60 | DATA:
61 |   dataset_name: 'wafer' # 'snemi3d-ac3', 'cremi-A', 'cremi'
62 |   train_split: 100
63 |   test_split: 25
64 |   data_folder: '/h3cstore_ns/EM_data'
65 |   padding: 20
66 |   shift_channels: ~
67 |   if_dilate: True
68 |   if_scale_aug: False
69 |   if_filp_aug: True
70 |   if_rotation_aug: True
71 |   if_intensity_aug: True
72 |   if_elastic_aug: True
73 | 
74 | TEST:
75 |   pad: 0
76 |   model_name: '2021-04-24--13-18-01_seg_3d_ac4_data80'
77 | 
-------------------------------------------------------------------------------- /utils/shift_channels.py: --------------------------------------------------------------------------------
1 | 
2 | def shift_func(shift_channels=3):
3 |     if shift_channels == 3:
4 |         shift = [[-1, 0, 0], [0, -1, 0], [0, 0, -1]]
5 |     elif shift_channels == 7:
6 |         shift = [[-1, 0, 0], [0, -1, 0], [0, 0, -1],
7 |                  # direct 3d nhood for attractive edges
8 |                  [-1, -1, -1], [-1, 1, 1], [-1, -1, 1], [-1, 1, -1]]
9 |     elif shift_channels == 9:
10 |         shift = [[-1, 0, 0], [0, -1, 0], [0, 0, -1],
11 |                  # direct 3d nhood for attractive edges
12 |                  [-1, -1, -1], [-1, 1, 1], [-1, -1, 1], [-1, 1, -1],
13 |                  # indirect 3d nhood for dam edges
14 |                  [0, -9, 0], [0, 0, -9]]
15 |     elif shift_channels == 15:
16 |         shift = [[-1, 0, 0], [0, -1, 0], [0, 0, -1],
17 |                  # direct 3d nhood for attractive edges
18 |                  [-1, -1, -1], [-1, 1, 1], [-1, -1, 1], [-1, 1, -1],
19 |                  # indirect 3d nhood for dam edges
20 |                  [0, -9, 0], [0, 0, -9],
21 |                  # long range direct hood
22 |                  [0, -9, -9], [0, 9, -9], [0, -9, -4], [0, -4, -9], [0, 4, -9], [0, 9, -4]]
23 |     elif shift_channels == 17:
24 |         shift = [[-1, 0, 0], [0, -1, 0], [0, 0, -1],
25 |                  # direct 3d nhood for attractive edges
26 |                  [-1, -1, -1], [-1, 1, 1], [-1, -1, 1], [-1, 1, -1],
27 |                  # indirect 3d nhood for dam edges
28 |                  [0, -9, 0], [0, 0, -9],
29 |                  # long range direct hood
30 |                  [0, -9, -9], [0, 9, -9], [0, -9, -4], [0, -4, -9], [0, 4, -9], [0, 9, -4],
31 |                  # inplane diagonal dam edges
32 |                  [0, -27, 0], [0, 0, -27]]
33 |     elif shift_channels == 23:
34 |         shift = [[-1, 0, 0], [0, -1, 0], [0, 0, -1],
35 |                  # direct 3d nhood for attractive edges
36 |                  [-1, -1, -1], [-1, 1, 1], [-1, -1, 1], [-1, 1, -1],
37 |                  # indirect 3d nhood for dam edges
38 |                  [0, -9, 0], [0, 0, -9],
39 |                  # long range direct hood
40 |                  [0, -9, -9], [0, 9, -9], [0, -9, -4], [0, -4, -9], [0, 4, -9], [0, 9, -4],
41 |                  # inplane diagonal dam edges
42 |                  [0, -27, 0], [0, 0, -27],
43 |                  # new
44 |                  [0, -27, -27], [0, 27, -27], [0, -27, -9], [0, -9, -27], [0, 9, -27], [0, 27, -9]]
45 |     else:
46 |         raise NotImplementedError
47 |     return shift
48 | 
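The offsets above are (z, y, x) neighbor displacements, one affinity channel per offset. A minimal sketch of how such a list can be turned into an affinity map from a segmentation (the helper seg_to_affinity and its wrap-around handling are illustrative, not code from this repo):

import numpy as np

def seg_to_affinity(seg, shifts):
    # One channel per offset: 1 where a voxel and its neighbor at (dz, dy, dx)
    # carry the same non-zero segment id. np.roll wraps at the borders, so a
    # real pipeline would additionally mask the first |offset| voxels per axis.
    affs = np.zeros((len(shifts),) + seg.shape, dtype=np.float32)
    for k, (dz, dy, dx) in enumerate(shifts):
        neighbor = np.roll(seg, shift=(-dz, -dy, -dx), axis=(0, 1, 2))
        affs[k] = (seg == neighbor) & (seg > 0)
    return affs

shifts = shift_func(shift_channels=3)            # [[-1,0,0], [0,-1,0], [0,0,-1]]
seg = np.random.randint(0, 5, size=(8, 32, 32))  # toy (Z, Y, X) label volume
print(seg_to_affinity(seg, shifts).shape)        # (3, 8, 32, 32)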
-------------------------------------------------------------------------------- /augmentation/mixup.py: --------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | 
4 | import torch
5 | from itertools import combinations
6 | 
7 | class MixupAugmentor(object):
8 |     """Mixup augmentor (experimental).
9 | 
10 |     The input can be a `numpy.ndarray` or `torch.Tensor` of shape :math:`(B, C, Z, Y, X)`.
11 | 
12 |     Args:
13 |         min_ratio (float): minimal interpolation ratio of the target volume. Default: 0.7
14 |         max_ratio (float): maximal interpolation ratio of the target volume. Default: 0.9
15 |         num_aug (int): number of volumes to be augmented in a batch. Default: 2
16 | 
17 |     Examples::
18 |         >>> from connectomics.data.augmentation import MixupAugmentor
19 |         >>> mixup_augmentor = MixupAugmentor(num_aug=2)
20 |         >>> volume = mixup_augmentor(volume)
21 |         >>> pred = model(volume)
22 |     """
23 |     def __init__(self, min_ratio=0.7, max_ratio=0.9, num_aug=2):
24 |         self.min_ratio = min_ratio
25 |         self.max_ratio = max_ratio
26 |         self.num_aug = num_aug
27 | 
28 |     def __call__(self, volume):
29 |         if isinstance(volume, torch.Tensor):
30 |             num_vol = volume.size(0)
31 |         elif isinstance(volume, np.ndarray):
32 |             num_vol = volume.shape[0]
33 |         else:
34 |             raise TypeError("Type {} is not supported in MixupAugmentor".format(type(volume)))
35 | 
36 |         num_aug = self.num_aug if self.num_aug <= num_vol else num_vol
37 |         indices = list(range(num_vol))
38 |         major_idx = random.sample(indices, num_aug)
39 |         minor_idx = []
40 |         for x in major_idx:
41 |             temp = indices.copy()
42 |             temp.remove(x)
43 |             minor_idx.append(random.sample(temp, 1)[0])
44 | 
45 |         for i in range(len(major_idx)):
46 |             ratio = random.uniform(self.min_ratio, self.max_ratio)
47 |             volume[major_idx[i]] = volume[major_idx[i]] * ratio + volume[minor_idx[i]] * (1-ratio)
48 | 
49 |         return volume
50 | 
51 | def test():
52 |     mixup_augmentor = MixupAugmentor(num_aug=2)
53 |     volume = np.ones((4,1,8,32,32))
54 |     volume = mixup_augmentor(volume)
55 |     print('Tested numpy.ndarray.')
56 | 
57 |     volume = torch.ones(4,1,8,32,32)
58 |     volume = mixup_augmentor(volume)
59 |     print('Tested torch.Tensor.')
60 | 
61 |     volume = [1,2,3,4,5]  # unsupported input type: expected to raise TypeError
62 |     volume = mixup_augmentor(volume)
63 | 
64 | if __name__ == '__main__':
65 |     test()
66 | 
-------------------------------------------------------------------------------- /augmentation/flip.py: --------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | from .augmentor import DataAugment
4 | 
5 | class Flip(DataAugment):
6 |     """
7 |     Randomly flip along `z`-, `y`- and `x`-axes as well as swap `y`- and `x`-axes
8 |     for anisotropic image volumes. For learning on isotropic image volumes set
9 |     :attr:`do_ztrans` to 1 to swap `z`- and `x`-axes (the inputs need to be cubic).
10 | 
11 |     Args:
12 |         p (float): probability of applying the augmentation. Default: 0.5
13 |         do_ztrans (int): set to 1 to swap z- and x-axes for isotropic data. Default: 0
14 |     """
15 |     def __init__(self, p=0.5, do_ztrans=0):
16 |         super(Flip, self).__init__(p)
17 |         self.do_ztrans = do_ztrans
18 | 
19 |     def set_params(self):
20 |         # No change in sample size
21 |         pass
22 | 
23 |     def flip_and_swap(self, data, rule):
24 |         assert data.ndim==3 or data.ndim==4
25 |         if data.ndim == 3: # 3D input in (z, y, x), no channel axis
26 |             # z reflection.
27 |             if rule[0]:
28 |                 data = data[::-1, :, :]
29 |             # y reflection.
30 |             if rule[1]:
31 |                 data = data[:, ::-1, :]
32 |             # x reflection.
33 |             if rule[2]:
34 |                 data = data[:, :, ::-1]
35 |             # Transpose in xy.
36 |             if rule[3]:
37 |                 data = data.transpose(0, 2, 1)
38 |             # Transpose in xz.
39 |             if self.do_ztrans==1 and rule[4]:
40 |                 data = data.transpose(2, 1, 0)
41 |         else: # 4D input in (c, z, y, x), with a channel axis
42 |             # z reflection.
43 |             if rule[0]:
44 |                 data = data[:, ::-1, :, :]
45 |             # y reflection.
46 |             if rule[1]:
47 |                 data = data[:, :, ::-1, :]
48 |             # x reflection.
49 |             if rule[2]:
50 |                 data = data[:, :, :, ::-1]
51 |             # Transpose in xy.
52 |             if rule[3]:
53 |                 data = data.transpose(0, 1, 3, 2)
54 |             # Transpose in xz.
55 |             if self.do_ztrans==1 and rule[4]:
56 |                 data = data.transpose(0, 3, 2, 1)
57 |         return data
58 | 
59 |     def __call__(self, data, random_state=np.random):
60 |         output = {}
61 | 
62 |         rule = random_state.randint(2, size=4+self.do_ztrans)
63 |         augmented_image = self.flip_and_swap(data['image'], rule)
64 |         augmented_label = self.flip_and_swap(data['label'], rule)
65 |         output['image'] = augmented_image
66 |         output['label'] = augmented_label
67 | 
68 |         return output
69 | 
-------------------------------------------------------------------------------- /src/run_parallel.sh: --------------------------------------------------------------------------------
1 | pip3 uninstall -y timm &&
2 | pip3 install timm==0.3.2 -i https://pypi.tuna.tsinghua.edu.cn/simple &&
3 | sed -i 's/from torch._six import container_abcs/import collections.abc as container_abcs/g' /usr/local/lib/python3.8/dist-packages/timm/models/layers/helpers.py &&
4 | export NCCL_P2P_DISABLE=1 &&
5 | echo "Starting task on GPU 0 1"
6 | CUDA_VISIBLE_DEVICES=0,1 python3 -m torch.distributed.launch --nproc_per_node=2 --master_port=56780 /h3cstore_ns/hyshi/EM_mamba_new/EM_mamba_seg/main_finetune.py --batch_size=6 --crop_size=16,160,160 \
7 | --epochs=800 --output_dir=/h3cstore_ns/hyshi/EM_mamba_new/result/segmambaV3_ac3_lr5_b6_16_160_160 \
8 | --visual_dir=/h3cstore_ns/hyshi/EM_mamba_new/result/segmambaV3_ac3_lr5_b6_16_160_160/visual \
9 | --log_dir=/h3cstore_ns/hyshi/EM_mamba_new/result/segmambaV3_ac3_lr5_b6_16_160_160/tensorboard_log \
10 | --warmup_epochs=0 --blr=1e-5 &
11 | echo "Starting task on GPU 2 3"
12 | CUDA_VISIBLE_DEVICES=2,3 python3 -m torch.distributed.launch --nproc_per_node=2 --master_port=56781 /h3cstore_ns/hyshi/EM_mamba_new/EM_mamba_seg/main_finetune.py --batch_size=2 --crop_size=16,320,320 \
13 | --epochs=800 --output_dir=/h3cstore_ns/hyshi/EM_mamba_new/result/segmambaV3_ac3_lr5_16_320_320 \
14 | --visual_dir=/h3cstore_ns/hyshi/EM_mamba_new/result/segmambaV3_ac3_lr5_16_320_320/visual \
15 | --log_dir=/h3cstore_ns/hyshi/EM_mamba_new/result/segmambaV3_ac3_lr5_16_320_320/tensorboard_log \
16 | --warmup_epochs=0 --blr=1e-5 &
17 | echo "Starting task on GPU 4 5"
18 | CUDA_VISIBLE_DEVICES=4,5 python3 -m torch.distributed.launch --nproc_per_node=2 --master_port=56782 /h3cstore_ns/hyshi/EM_mamba_new/EM_mamba_seg/main_finetune.py --batch_size=2 --crop_size=32,320,320 \
19 | --epochs=800 --output_dir=/h3cstore_ns/hyshi/EM_mamba_new/result/segmambaV3_ac3_lr5_32_320_320 \
20 | --visual_dir=/h3cstore_ns/hyshi/EM_mamba_new/result/segmambaV3_ac3_lr5_32_320_320/visual \
21 | --log_dir=/h3cstore_ns/hyshi/EM_mamba_new/result/segmambaV3_ac3_lr5_32_320_320/tensorboard_log \
22 | --warmup_epochs=0 --blr=1e-5 &
23 | echo "Starting task on GPU 6 7"
24 | CUDA_VISIBLE_DEVICES=6,7 python3 -m torch.distributed.launch --nproc_per_node=2 --master_port=56783 /h3cstore_ns/hyshi/EM_mamba_new/EM_mamba_seg/main_finetune.py --batch_size=2 --crop_size=16,160,160 \
25 | --epochs=800 --output_dir=/h3cstore_ns/hyshi/EM_mamba_new/result/segmambaV3_ac3_lr5_b2_16_160_160 \
26 | 
--visual_dir=/h3cstore_ns/hyshi/EM_mamba_new/result/segmambaV3_ac3_lr5_b2_16_160_160/visual \ 27 | --log_dir=/h3cstore_ns/hyshi/EM_mamba_new/result/segmambaV3_ac3_lr5_b2_16_160_160/tensorboard_log \ 28 | --warmup_epochs=0 --blr=1e-5 & 29 | wait 30 | echo "All tasks completed." -------------------------------------------------------------------------------- /src/run_parallel_superhuman.sh: -------------------------------------------------------------------------------- 1 | pip3 uninstall -y timm && 2 | pip3 install timm==0.3.2 -i https://pypi.tuna.tsinghua.edu.cn/simple && 3 | sed -i 's/from torch._six import container_abcs/import collections.abc as container_abcs/g' /usr/local/lib/python3.8/dist-packages/timm/models/layers/helpers.py && 4 | export NCCL_P2P_DISABLE=1 && 5 | echo "Starting task on GPU 0 1" 6 | CUDA_VISIBLE_DEVICES=0,1 python3 -m torch.distributed.launch --nproc_per_node=2 --master_port=46667 /h3cstore_ns/hyshi/EM_mamba_new/EM_mamba_seg/main_finetune.py --batch_size=6 --crop_size=18,160,160 \ 7 | --epochs=800 --output_dir=/h3cstore_ns/hyshi/EM_mamba_new/result/superhuman_ac3_lr5_b6_18_160_160 \ 8 | --visual_dir=/h3cstore_ns/hyshi/EM_mamba_new/result/superhuman_ac3_lr5_b6_18_160_160/visual \ 9 | --log_dir=/h3cstore_ns/hyshi/EM_mamba_new/result/superhuman_ac3_lr5_b6_18_160_160/tensorboard_log \ 10 | --warmup_epochs=0 --blr=1e-5 & 11 | echo "Starting task on GPU 2 3" 12 | CUDA_VISIBLE_DEVICES=2,3 python3 -m torch.distributed.launch --nproc_per_node=2 --master_port=46668 /h3cstore_ns/hyshi/EM_mamba_new/EM_mamba_seg/main_finetune.py --batch_size=6 --crop_size=16,160,160 \ 13 | --epochs=800 --output_dir=/h3cstore_ns/hyshi/EM_mamba_new/result/superhuman_ac3_lr5_b6_16_160_160 \ 14 | --visual_dir=/h3cstore_ns/hyshi/EM_mamba_new/result/superhuman_ac3_lr5_b6_16_160_160/visual \ 15 | --log_dir=/h3cstore_ns/hyshi/EM_mamba_new/result/superhuman_ac3_lr5_b6_16_160_160/tensorboard_log \ 16 | --warmup_epochs=0 --blr=1e-5 & 17 | echo "Starting task on GPU 4 5" 18 | CUDA_VISIBLE_DEVICES=4,5 python3 -m torch.distributed.launch --nproc_per_node=2 --master_port=46669 /h3cstore_ns/hyshi/EM_mamba_new/EM_mamba_seg/main_finetune.py --batch_size=2 --crop_size=18,160,160 \ 19 | --epochs=800 --output_dir=/h3cstore_ns/hyshi/EM_mamba_new/result/superhuman_ac3_lr5_b2_18_160_160 \ 20 | --visual_dir=/h3cstore_ns/hyshi/EM_mamba_new/result/superhuman_ac3_lr5_b2_18_160_160/visual \ 21 | --log_dir=/h3cstore_ns/hyshi/EM_mamba_new/result/superhuman_ac3_lr5_b2_18_160_160/tensorboard_log \ 22 | --warmup_epochs=0 --blr=1e-5 & 23 | echo "Starting task on GPU 6 7" 24 | CUDA_VISIBLE_DEVICES=6,7 python3 -m torch.distributed.launch --nproc_per_node=2 --master_port=46670 /h3cstore_ns/hyshi/EM_mamba_new/EM_mamba_seg/main_finetune.py --batch_size=2 --crop_size=16,160,160 \ 25 | --epochs=800 --output_dir=/h3cstore_ns/hyshi/EM_mamba_new/result/superhuman_ac3_lr5_b2_16_160_160 \ 26 | --visual_dir=/h3cstore_ns/hyshi/EM_mamba_new/result/superhuman_ac3_lr5_b2_16_160_160/visual \ 27 | --log_dir=/h3cstore_ns/hyshi/EM_mamba_new/result/superhuman_ac3_lr5_b2_16_160_160/tensorboard_log \ 28 | --warmup_epochs=0 --blr=1e-5 & 29 | wait 30 | echo "All tasks completed." -------------------------------------------------------------------------------- /util_mamba/lr_decay.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # ELECTRA https://github.com/google-research/electra 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import json 13 | 14 | 15 | def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75): 16 | """ 17 | Parameter groups for layer-wise lr decay 18 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58 19 | """ 20 | param_group_names = {} 21 | param_groups = {} 22 | 23 | num_layers = len(model.blocks) + 1 24 | 25 | layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1)) 26 | 27 | for n, p in model.named_parameters(): 28 | if not p.requires_grad: 29 | continue 30 | 31 | # no decay: all 1D parameters and model specific ones 32 | if p.ndim == 1 or n in no_weight_decay_list: 33 | g_decay = "no_decay" 34 | this_decay = 0. 35 | else: 36 | g_decay = "decay" 37 | this_decay = weight_decay 38 | 39 | layer_id = get_layer_id_for_vit(n, num_layers) 40 | group_name = "layer_%d_%s" % (layer_id, g_decay) 41 | 42 | if group_name not in param_group_names: 43 | this_scale = layer_scales[layer_id] 44 | 45 | param_group_names[group_name] = { 46 | "lr_scale": this_scale, 47 | "weight_decay": this_decay, 48 | "params": [], 49 | } 50 | param_groups[group_name] = { 51 | "lr_scale": this_scale, 52 | "weight_decay": this_decay, 53 | "params": [], 54 | } 55 | 56 | param_group_names[group_name]["params"].append(n) 57 | param_groups[group_name]["params"].append(p) 58 | 59 | # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2)) 60 | 61 | return list(param_groups.values()) 62 | 63 | 64 | def get_layer_id_for_vit(name, num_layers): 65 | """ 66 | Assign a parameter with its layer id 67 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 68 | """ 69 | if name in ['cls_token', 'pos_embed']: 70 | return 0 71 | elif name.startswith('patch_embed'): 72 | return 0 73 | elif name.startswith('blocks'): 74 | return int(name.split('.')[1]) + 1 75 | else: 76 | return num_layers -------------------------------------------------------------------------------- /mamba_local/csrc/selective_scan/uninitialized_copy.cuh: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2011-2022, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Redistribution and use in source and binary forms, with or without 5 | * modification, are permitted provided that the following conditions are met: 6 | * * Redistributions of source code must retain the above copyright 7 | * notice, this list of conditions and the following disclaimer. 8 | * * Redistributions in binary form must reproduce the above copyright 9 | * notice, this list of conditions and the following disclaimer in the 10 | * documentation and/or other materials provided with the distribution. 11 | * * Neither the name of the NVIDIA CORPORATION nor the 12 | * names of its contributors may be used to endorse or promote products 13 | * derived from this software without specific prior written permission. 
14 |  *
15 |  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
16 |  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17 |  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
18 |  * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
19 |  * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20 |  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21 |  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
22 |  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23 |  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24 |  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25 |  *
26 |  ******************************************************************************/
27 | 
28 | #pragma once
29 | 
30 | #include <cub/config.cuh>
31 | 
32 | #include <cuda/std/type_traits>
33 | 
34 | 
35 | namespace detail
36 | {
37 | 
38 | #if defined(_NVHPC_CUDA)
39 | template <typename T, typename U>
40 | __host__ __device__ void uninitialized_copy(T *ptr, U &&val)
41 | {
42 |   // NVBug 3384810
43 |   new (ptr) T(::cuda::std::forward<U>(val));
44 | }
45 | #else
46 | template <typename T,
47 |           typename U,
48 |           typename ::cuda::std::enable_if<
49 |             ::cuda::std::is_trivially_copyable<T>::value,
50 |             int
51 |           >::type = 0>
52 | __host__ __device__ void uninitialized_copy(T *ptr, U &&val)
53 | {
54 |   *ptr = ::cuda::std::forward<U>(val);
55 | }
56 | 
57 | template <typename T,
58 |           typename U,
59 |           typename ::cuda::std::enable_if<
60 |             !::cuda::std::is_trivially_copyable<T>::value,
61 |             int
62 |           >::type = 0>
63 | __host__ __device__ void uninitialized_copy(T *ptr, U &&val)
64 | {
65 |   new (ptr) T(::cuda::std::forward<U>(val));
66 | }
67 | #endif
68 | 
69 | } // namespace detail
70 | 
-------------------------------------------------------------------------------- /mamba_local/csrc/selective_scan/selective_scan.h: --------------------------------------------------------------------------------
1 | /******************************************************************************
2 |  * Copyright (c) 2023, Tri Dao.
3 |  ******************************************************************************/
4 | 
5 | #pragma once
6 | 
7 | ////////////////////////////////////////////////////////////////////////////////////////////////////
8 | 
9 | struct SSMScanParamsBase {
10 |     using index_t = uint32_t;
11 | 
12 |     int batch, seqlen, n_chunks;
13 |     index_t a_batch_stride;
14 |     index_t b_batch_stride;
15 |     index_t out_batch_stride;
16 | 
17 |     // Common data pointers.
18 |     void *__restrict__ a_ptr;
19 |     void *__restrict__ b_ptr;
20 |     void *__restrict__ out_ptr;
21 |     void *__restrict__ x_ptr;
22 | };
23 | 
24 | ////////////////////////////////////////////////////////////////////////////////////////////////////
25 | 
26 | struct SSMParamsBase {
27 |     using index_t = uint32_t;
28 | 
29 |     int batch, dim, seqlen, dstate, n_groups, n_chunks;
30 |     int dim_ngroups_ratio;
31 |     bool is_variable_B;
32 |     bool is_variable_C;
33 | 
34 |     bool delta_softplus;
35 | 
36 |     index_t A_d_stride;
37 |     index_t A_dstate_stride;
38 |     index_t B_batch_stride;
39 |     index_t B_d_stride;
40 |     index_t B_dstate_stride;
41 |     index_t B_group_stride;
42 |     index_t C_batch_stride;
43 |     index_t C_d_stride;
44 |     index_t C_dstate_stride;
45 |     index_t C_group_stride;
46 |     index_t u_batch_stride;
47 |     index_t u_d_stride;
48 |     index_t delta_batch_stride;
49 |     index_t delta_d_stride;
50 |     index_t z_batch_stride;
51 |     index_t z_d_stride;
52 |     index_t out_batch_stride;
53 |     index_t out_d_stride;
54 |     index_t out_z_batch_stride;
55 |     index_t out_z_d_stride;
56 | 
57 |     // Common data pointers.
58 | void *__restrict__ A_ptr; 59 | void *__restrict__ B_ptr; 60 | void *__restrict__ C_ptr; 61 | void *__restrict__ D_ptr; 62 | void *__restrict__ u_ptr; 63 | void *__restrict__ delta_ptr; 64 | void *__restrict__ delta_bias_ptr; 65 | void *__restrict__ out_ptr; 66 | void *__restrict__ x_ptr; 67 | void *__restrict__ z_ptr; 68 | void *__restrict__ out_z_ptr; 69 | }; 70 | 71 | struct SSMParamsBwd: public SSMParamsBase { 72 | index_t dout_batch_stride; 73 | index_t dout_d_stride; 74 | index_t dA_d_stride; 75 | index_t dA_dstate_stride; 76 | index_t dB_batch_stride; 77 | index_t dB_group_stride; 78 | index_t dB_d_stride; 79 | index_t dB_dstate_stride; 80 | index_t dC_batch_stride; 81 | index_t dC_group_stride; 82 | index_t dC_d_stride; 83 | index_t dC_dstate_stride; 84 | index_t du_batch_stride; 85 | index_t du_d_stride; 86 | index_t dz_batch_stride; 87 | index_t dz_d_stride; 88 | index_t ddelta_batch_stride; 89 | index_t ddelta_d_stride; 90 | 91 | // Common data pointers. 92 | void *__restrict__ dout_ptr; 93 | void *__restrict__ dA_ptr; 94 | void *__restrict__ dB_ptr; 95 | void *__restrict__ dC_ptr; 96 | void *__restrict__ dD_ptr; 97 | void *__restrict__ du_ptr; 98 | void *__restrict__ dz_ptr; 99 | void *__restrict__ ddelta_ptr; 100 | void *__restrict__ ddelta_bias_ptr; 101 | }; 102 | -------------------------------------------------------------------------------- /augmentation/cutblur.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from .augmentor import DataAugment 3 | from skimage.transform import resize 4 | 5 | class CutBlur(DataAugment): 6 | """3D CutBlur data augmentation, adapted from https://arxiv.org/abs/2004.00448. 7 | 8 | Randomly downsample a cuboid region in the volume to force the model 9 | to learn super-resolution when making predictions. 10 | 11 | Args: 12 | length_ratio (float): the ratio of the cuboid length compared with volume length. 13 | down_ratio_min (float): minimal downsample ratio to generate low-res region. 14 | down_ratio_max (float): maximal downsample ratio to generate low-res region. 15 | downsample_z (bool): downsample along the z axis (default: False). 16 | p (float): probability of applying the augmentation. 
17 | """ 18 | 19 | def __init__(self, 20 | length_ratio=0.25, 21 | down_ratio_min=2.0, 22 | down_ratio_max=8.0, 23 | downsample_z=False, 24 | p=0.5): 25 | super(CutBlur, self).__init__(p=p) 26 | self.length_ratio = length_ratio 27 | self.down_ratio_min = down_ratio_min 28 | self.down_ratio_max = down_ratio_max 29 | self.downsample_z = downsample_z 30 | 31 | def set_params(self): 32 | # No change in sample size 33 | pass 34 | 35 | def cut_blur(self, data, random_state): 36 | images = data['image'].copy() 37 | labels = data['label'].copy() 38 | 39 | zdim = images.shape[0] 40 | 41 | if zdim > 1: 42 | zl, zh = self.random_region(images.shape[0], random_state) 43 | yl, yh = self.random_region(images.shape[1], random_state) 44 | xl, xh = self.random_region(images.shape[2], random_state) 45 | 46 | if zdim == 1: 47 | temp = images[:, yl:yh, xl:xh].copy() 48 | else: 49 | temp = images[zl:zh, yl:yh, xl:xh].copy() 50 | 51 | down_ratio = random_state.uniform(self.down_ratio_min, self.down_ratio_max) 52 | if zdim > 1 and self.downsample_z: 53 | out_shape = np.array(temp.shape) / down_ratio 54 | else: 55 | out_shape = np.array(temp.shape) / np.array([1, down_ratio, down_ratio]) 56 | 57 | out_shape = out_shape.astype(int) 58 | downsampled = resize(temp, out_shape, order=1, mode='reflect', 59 | clip=True, preserve_range=True, anti_aliasing=True) 60 | upsampled = resize(downsampled, temp.shape, order=0, mode='reflect', 61 | clip=True, preserve_range=True, anti_aliasing=False) 62 | 63 | if zdim == 1: 64 | images[:, yl:yh, xl:xh] = upsampled 65 | else: 66 | images[zl:zh, yl:yh, xl:xh] = upsampled 67 | return images, labels 68 | 69 | 70 | def random_region(self, vol_len, random_state): 71 | cuboid_len = int(self.length_ratio * vol_len) 72 | low = random_state.randint(0, vol_len-cuboid_len) 73 | high = low + cuboid_len 74 | return low, high 75 | 76 | def __call__(self, data, random_state=np.random): 77 | new_images, new_labels = self.cut_blur(data, random_state) 78 | return {'image': new_images, 'label': new_labels} 79 | -------------------------------------------------------------------------------- /augmentation/warp.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | from scipy.ndimage.filters import gaussian_filter 4 | 5 | from .augmentor import DataAugment 6 | 7 | class Elastic(DataAugment): 8 | """Elastic deformation of images as described in [Simard2003]_ (with modifications). 9 | The implementation is based on https://gist.github.com/erniejunior/601cdf56d2b424757de5. 10 | 11 | .. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for 12 | Convolutional Neural Networks applied to Visual Document Analysis", in 13 | Proc. of the International Conference on Document Analysis and 14 | Recognition, 2003. 15 | 16 | Args: 17 | alpha (float): maximum pixel-moving distance of elastic deformation. Default: 10.0 18 | sigma (float): standard deviation of the Gaussian filter. Default: 4.0 19 | p (float): probability of applying the augmentation. 
Default: 0.5 20 | """ 21 | def __init__(self, 22 | alpha=10.0, 23 | sigma=4.0, 24 | p=0.5): 25 | 26 | super(Elastic, self).__init__(p) 27 | self.alpha = alpha 28 | self.sigma = sigma 29 | self.image_interpolation = cv2.INTER_LINEAR 30 | self.label_interpolation = cv2.INTER_NEAREST 31 | self.border_mode = cv2.BORDER_CONSTANT 32 | self.set_params() 33 | 34 | def set_params(self): 35 | max_margin = int(self.alpha) + 1 36 | self.sample_params['add'] = [0, max_margin, max_margin] 37 | 38 | def __call__(self, data, random_state=np.random): 39 | if 'label' in data and data['label'] is not None: 40 | image, label = data['image'], data['label'] 41 | else: 42 | image = data['image'] 43 | 44 | height, width = image.shape[-2:] # (c, z, y, x) 45 | 46 | dx = np.float32(gaussian_filter((random_state.rand(height, width) * 2 - 1), self.sigma) * self.alpha) 47 | dy = np.float32(gaussian_filter((random_state.rand(height, width) * 2 - 1), self.sigma) * self.alpha) 48 | 49 | x, y = np.meshgrid(np.arange(width), np.arange(height)) 50 | mapx, mapy = np.float32(x + dx), np.float32(y + dy) 51 | 52 | output = {} 53 | transformed_image = [] 54 | transformed_label = [] 55 | 56 | for i in range(image.shape[-3]): 57 | if image.ndim == 3: 58 | transformed_image.append(cv2.remap(image[i], mapx, mapy, 59 | self.image_interpolation, borderMode=self.border_mode)) 60 | else: 61 | temp = [cv2.remap(image[channel, i], mapx, mapy, self.image_interpolation, 62 | borderMode=self.border_mode) for channel in range(image.shape[0])] 63 | transformed_image.append(np.stack(temp, 0)) 64 | if 'label' in data and data['label'] is not None: 65 | transformed_label.append(cv2.remap(label[i], mapx, mapy, self.label_interpolation, borderMode=self.border_mode)) 66 | 67 | if image.ndim == 3: # (z,y,x) 68 | transformed_image = np.stack(transformed_image, 0) 69 | else: # (c,z,y,x) 70 | transformed_image = np.stack(transformed_image, 1) 71 | 72 | transformed_label = np.stack(transformed_label, 0) 73 | output['image'] = transformed_image 74 | output['label'] = transformed_label 75 | 76 | return output 77 | -------------------------------------------------------------------------------- /mamba_local/benchmarks/benchmark_generation_mamba_simple.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Tri Dao, Albert Gu. 
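The benchmark below warms the model up with one untimed call and brackets the timed loop with torch.cuda.synchronize(), since CUDA kernel launches return before the work finishes. A self-contained sketch of that timing pattern (the helper name time_cuda_ms is an assumption, not part of this script):

import time
import torch

def time_cuda_ms(fn, repeats=3):
    # One warm-up call keeps compilation/allocator effects out of the timing.
    fn()
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(repeats):
        fn()
    torch.cuda.synchronize()  # wait for all queued kernels before reading the clock
    return (time.time() - start) / repeats * 1000.0  # average milliseconds per call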
2 | 3 | import argparse 4 | import time 5 | import json 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | from einops import rearrange 11 | 12 | from transformers import AutoTokenizer, AutoModelForCausalLM 13 | 14 | from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel 15 | 16 | 17 | parser = argparse.ArgumentParser(description="Generation benchmarking") 18 | parser.add_argument("--model-name", type=str, default="state-spaces/mamba-130m") 19 | parser.add_argument("--prompt", type=str, default=None) 20 | parser.add_argument("--promptlen", type=int, default=100) 21 | parser.add_argument("--genlen", type=int, default=100) 22 | parser.add_argument("--temperature", type=float, default=1.0) 23 | parser.add_argument("--topk", type=int, default=1) 24 | parser.add_argument("--topp", type=float, default=1.0) 25 | parser.add_argument("--batch", type=int, default=1) 26 | args = parser.parse_args() 27 | 28 | repeats = 3 29 | device = "cuda" 30 | dtype = torch.float16 31 | 32 | print(f"Loading model {args.model_name}") 33 | is_mamba = args.model_name.startswith("state-spaces/mamba-") or "mamba" in args.model_name 34 | 35 | if is_mamba: 36 | tokenizer = AutoTokenizer.from_pretrained("/home/zhulianghui/VisionProjects/mamba/ckpts/gpt-neox-20b-tokenizer") 37 | model = MambaLMHeadModel.from_pretrained(args.model_name, device=device, dtype=dtype) 38 | else: 39 | tokenizer = AutoTokenizer.from_pretrained(args.model_name) 40 | model = AutoModelForCausalLM.from_pretrained(args.model_name, device_map={"": device}, torch_dtype=dtype) 41 | model.eval() 42 | print(f"Number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") 43 | 44 | torch.random.manual_seed(0) 45 | if args.prompt is None: 46 | input_ids = torch.randint(1, 1000, (args.batch, args.promptlen), dtype=torch.long, device="cuda") 47 | attn_mask = torch.ones_like(input_ids, dtype=torch.long, device="cuda") 48 | else: 49 | tokens = tokenizer(args.prompt, return_tensors="pt") 50 | input_ids = tokens.input_ids.to(device=device) 51 | attn_mask = tokens.attention_mask.to(device=device) 52 | max_length = input_ids.shape[1] + args.genlen 53 | 54 | if is_mamba: 55 | fn = lambda: model.generate( 56 | input_ids=input_ids, 57 | max_length=max_length, 58 | cg=True, 59 | return_dict_in_generate=True, 60 | output_scores=True, 61 | enable_timing=False, 62 | temperature=args.temperature, 63 | top_k=args.topk, 64 | top_p=args.topp, 65 | ) 66 | else: 67 | fn = lambda: model.generate( 68 | input_ids=input_ids, 69 | attention_mask=attn_mask, 70 | max_length=max_length, 71 | return_dict_in_generate=True, 72 | pad_token_id=tokenizer.eos_token_id, 73 | do_sample=True, 74 | temperature=args.temperature, 75 | top_k=args.topk, 76 | top_p=args.topp, 77 | ) 78 | out = fn() 79 | if args.prompt is not None: 80 | print(tokenizer.batch_decode(out.sequences.tolist())) 81 | 82 | torch.cuda.synchronize() 83 | start = time.time() 84 | for _ in range(repeats): 85 | fn() 86 | torch.cuda.synchronize() 87 | print(f"Prompt length: {len(input_ids[0])}, generation length: {len(out.sequences[0]) - len(input_ids[0])}") 88 | print(f"{args.model_name} prompt processing + decoding time: {(time.time() - start) / repeats * 1000:.0f}ms") 89 | -------------------------------------------------------------------------------- /utils/gen_pseudo.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import numpy as np 4 | from yaml.events import NodeEvent 5 | 6 | class GenPseudo(object): 7 | 
def __init__(self, mode='threshold', 8 | threshold=0.99, 9 | proportion=0.20): 10 | super(GenPseudo, self).__init__() 11 | self.mode = mode 12 | self.threshold = threshold 13 | self.proportion = proportion 14 | 15 | def __call__(self, inputs): 16 | return self.forward(inputs) 17 | 18 | def forward(self, inputs): 19 | inputs = inputs.detach().clone() 20 | mask = torch.zeros_like(inputs) 21 | if self.mode == 'threshold': 22 | inputs[inputs > self.threshold] = 1 23 | mask[inputs == 1] = 1 24 | inputs[inputs < (1-self.threshold)] = 0 25 | mask[inputs == 0] = 1 26 | return inputs, mask 27 | else: 28 | num_classes = 2 # binary classification 29 | pseudo_lb = [] 30 | masks = [] 31 | batch_size = inputs.shape[0] 32 | for k in range(batch_size): 33 | output_affs = inputs[k] 34 | output_affs_0 = 1 - output_affs.clone() 35 | output_affs_1 = output_affs.clone() 36 | output_affs_all = torch.stack([output_affs_0, output_affs_1], dim=0) 37 | probmap_max, pred_label = torch.max(output_affs_all, dim=0) 38 | for idx_cls in range(num_classes): 39 | out_div_all = [] 40 | for i in range(3): 41 | pred_label_temp = pred_label[i] 42 | probmap_max_temp = probmap_max[i] 43 | probmap_max_cls_temp = probmap_max_temp[pred_label_temp == idx_cls] 44 | if len(probmap_max_cls_temp) > 0: 45 | # probmap_max_cls_temp = probmap_max_cls_temp.view(probmap_max_cls_temp.size(0), -1) 46 | probmap_max_cls_temp = probmap_max_cls_temp[0:len(probmap_max_cls_temp)] 47 | probmap_max_cls_temp, _ = torch.sort(probmap_max_cls_temp, descending=True) 48 | len_cls = len(probmap_max_cls_temp) 49 | thresh_len = int(math.floor(len_cls * self.proportion)) 50 | thresh_temp = probmap_max_cls_temp[thresh_len - 1] 51 | out_div = torch.div(output_affs_all[idx_cls, i], thresh_temp) 52 | else: 53 | out_div = output_affs_all[idx_cls, i] 54 | out_div_all.append(out_div) 55 | out_div_all = torch.stack(out_div_all, dim=0) 56 | output_affs_all[idx_cls] = out_div_all 57 | 58 | rw_probmap_max, pseudo_label = torch.max(output_affs_all, dim=0) 59 | mask = torch.zeros_like(rw_probmap_max) 60 | mask[rw_probmap_max>=1] = 1 61 | pseudo_lb.append(pseudo_label) 62 | masks.append(mask) 63 | pseudo_lb = torch.stack(pseudo_lb, dim=0) 64 | masks = torch.stack(masks, dim=0) 65 | 66 | return pseudo_lb, masks 67 | 68 | 69 | if __name__ == "__main__": 70 | gen_pseudo = GenPseudo(mode='prop') 71 | pred = np.random.random((2,3,18,160,160)).astype(np.float32) 72 | pred = torch.tensor(pred).to('cuda:0') 73 | 74 | pseudo_lb, masks = gen_pseudo(pred) 75 | -------------------------------------------------------------------------------- /utils/fragment.py: -------------------------------------------------------------------------------- 1 | import mahotas 2 | import numpy as np 3 | from scipy import ndimage 4 | 5 | def randomlabel(segmentation): 6 | segmentation = segmentation.astype(np.uint32) 7 | uid = np.unique(segmentation) 8 | mid = int(uid.max()) + 1 9 | mapping = np.zeros(mid, dtype=segmentation.dtype) 10 | mapping[uid] = np.random.choice(len(uid), len(uid), replace=False).astype(segmentation.dtype)#(len(uid), dtype=segmentation.dtype) 11 | out = mapping[segmentation] 12 | out[segmentation==0] = 0 13 | return out 14 | 15 | def watershed(affs, seed_method, use_mahotas_watershed=True): 16 | affs_xy = 1.0 - 0.5*(affs[1] + affs[2]) 17 | depth = affs_xy.shape[0] 18 | fragments = np.zeros_like(affs[0]).astype(np.uint64) 19 | next_id = 1 20 | for z in range(depth): 21 | seeds, num_seeds = get_seeds(affs_xy[z], next_id=next_id, method=seed_method) 22 | if use_mahotas_watershed: 23 | 
fragments[z] = mahotas.cwatershed(affs_xy[z], seeds)
24 |         else:
25 |             fragments[z] = ndimage.watershed_ift((255.0*affs_xy[z]).astype(np.uint8), seeds)
26 |         next_id += num_seeds
27 |     return fragments
28 | 
29 | def get_seeds(boundary, method='grid', next_id=1, seed_distance=10):
30 |     if method == 'grid':
31 |         height = boundary.shape[0]
32 |         width = boundary.shape[1]
33 |         seed_positions = np.ogrid[0:height:seed_distance, 0:width:seed_distance]
34 |         num_seeds_y = seed_positions[0].size
35 |         num_seeds_x = seed_positions[1].size
36 |         num_seeds = num_seeds_x*num_seeds_y
37 |         seeds = np.zeros_like(boundary).astype(np.int32)
38 |         seeds[seed_positions] = np.arange(next_id, next_id + num_seeds).reshape((num_seeds_y,num_seeds_x))
39 | 
40 |     if method == 'minima':
41 |         minima = mahotas.regmin(boundary)
42 |         seeds, num_seeds = mahotas.label(minima)
43 |         seeds += next_id
44 |         seeds[seeds==next_id] = 0
45 | 
46 |     if method == 'maxima_distance':
47 |         distance = mahotas.distance(boundary<0.5)
48 |         maxima = mahotas.regmax(distance)
49 |         seeds, num_seeds = mahotas.label(maxima)
50 |         seeds += next_id
51 |         seeds[seeds==next_id] = 0
52 | 
53 |     return seeds, num_seeds
54 | 
55 | 
56 | def elf_watershed(affs):
57 |     import elf.segmentation.watershed as ws
58 |     affs = 1 - affs
59 |     boundary_input = np.maximum(affs[1], affs[2])
60 |     fragments = np.zeros_like(boundary_input, dtype='uint64')
61 |     offset = 0
62 |     for z in range(fragments.shape[0]):
63 |         wsz, max_id = ws.distance_transform_watershed(boundary_input[z], threshold=.25, sigma_seeds=2.)
64 |         wsz += offset
65 |         offset += max_id
66 |         fragments[z] = wsz
67 |     return fragments
68 | 
69 | def relabel(seg):
70 |     # get the unique labels
71 |     uid = np.unique(seg)
72 |     # ignore all-background samples
73 |     if len(uid)==1 and uid[0] == 0:
74 |         return seg
75 | 
76 |     uid = uid[uid > 0]
77 |     mid = int(uid.max()) + 1 # get the maximum label for the segment
78 | 
79 |     # create an array from original segment id to reduced id
80 |     m_type = seg.dtype
81 |     mapping = np.zeros(mid, dtype=m_type)
82 |     mapping[uid] = np.arange(1, len(uid) + 1, dtype=m_type)
83 |     return mapping[seg]
84 | 
85 | def remove_small(seg, thres=100):
86 |     sz = seg.shape
87 |     seg = seg.reshape(-1)
88 |     uid, uc = np.unique(seg, return_counts=True)
89 |     seg[np.in1d(seg, uid[uc < thres])] = 0
90 |     return seg.reshape(sz)
91 | 
-------------------------------------------------------------------------------- /augmentation/grayscale.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | from .augmentor import DataAugment
3 | 
4 | class Grayscale(DataAugment):
5 |     """
6 |     Grayscale intensity augmentation.
7 | 
8 |     Randomly adjust the contrast and brightness of the input volume and
9 |     apply a random gamma correction, either slice-by-slice ('2D'), on the
10 |     whole volume ('3D'), or choosing between the two at random ('mix').
11 | 
12 |     Args:
13 |         mode (string): one of ``'2D'``, ``'3D'`` or ``'mix'``. Default: ``'mix'``
14 |         p (float): probability of applying the augmentation. Default: 0.5
15 |     """
16 | 
17 |     CONTRAST_FACTOR   = 0.1
18 |     BRIGHTNESS_FACTOR = 0.1
19 | 
20 |     def __init__(self, mode='mix', p=0.5):
21 |         super(Grayscale, self).__init__(p=p)
22 |         self._set_mode(mode)
23 | 
24 |     def set_params(self):
25 |         # No change in sample size
26 |         pass
27 | 
28 |     def __call__(self, data, random_state=np.random):
29 |         # Pick the per-call mode; when self.mode is 'mix', randomly choose
30 |         # between slice-wise (2D) and volume-wise (3D) augmentation.
31 |         if self.mode == 'mix':
32 |             mode = '3D' if random_state.rand() > 0.5 else '2D'
33 |         else:
34 |             mode = self.mode
35 | 
36 |         # apply augmentations
37 |         if mode == '2D':
38 |             data = self._augment2D(data, random_state)
39 |         if mode == '3D':
40 |             data = self._augment3D(data, random_state)
41 |         return data
42 | 
43 |     def _augment2D(self, data, random_state=np.random):
44 |         """
45 |         Adapted from ELEKTRONN (http://elektronn.org/).
46 |         """
47 |         imgs = data['image']
48 |         transformedimgs = np.copy(imgs)
49 |         ran = random_state.rand(transformedimgs.shape[-3]*3)
50 | 
51 |         for z in range(transformedimgs.shape[-3]):
52 |             img = transformedimgs[z, :, :]
53 |             img *= 1 + (ran[z*3] - 0.5)*self.CONTRAST_FACTOR
54 |             img += (ran[z*3+1] - 0.5)*self.BRIGHTNESS_FACTOR
55 |             img = np.clip(img, 0, 1)
56 |             img **= 2.0**(ran[z*3+2]*2 - 1)
57 |             transformedimgs[z, :, :] = img
58 | 
59 |         data['image'] = transformedimgs
60 |         return data
61 | 
62 |     def _augment3D(self, data, random_state=np.random):
63 |         """
64 |         Adapted from ELEKTRONN (http://elektronn.org/).
65 | """ 66 | ran = random_state.rand(3) 67 | 68 | imgs = data['image'] 69 | transformedimgs = np.copy(imgs) 70 | transformedimgs *= 1 + (ran[0] - 0.5)*self.CONTRAST_FACTOR 71 | transformedimgs += (ran[1] - 0.5)*self.BRIGHTNESS_FACTOR 72 | transformedimgs = np.clip(transformedimgs, 0, 1) 73 | transformedimgs **= 2.0**(ran[2]*2 - 1) 74 | 75 | data['image'] = transformedimgs 76 | return data 77 | 78 | def _invert(self, data, random_state=np.random): 79 | """ 80 | Invert input images 81 | """ 82 | imgs = data['image'] 83 | transformedimgs = np.copy(imgs) 84 | transformedimgs = 1.0-transformedimgs 85 | transformedimgs = np.clip(transformedimgs, 0, 1) 86 | 87 | data['image'] = transformedimgs 88 | return data 89 | 90 | #################################################################### 91 | ## Setters. 92 | #################################################################### 93 | 94 | def _set_mode(self, mode): 95 | """Set 2D/3D/mix greyscale value augmentation mode.""" 96 | assert mode=='2D' or mode=='3D' or mode=='mix' 97 | self.mode = mode 98 | -------------------------------------------------------------------------------- /augmentation/test_augmentor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import itertools 3 | import torch 4 | 5 | class TestAugmentor(object): 6 | """Test-time augmentor. 7 | 8 | Our test-time augmentation includes horizontal/vertical flips 9 | over the `xy`-plane, swap of `x` and `y` axes, and flip in `z`-dimension, 10 | resulting in 16 variants. Considering inference efficiency, we also 11 | provide the option to apply only `x-y` swap and `z`-flip, resulting in 4 variants. 12 | By default the test-time augmentor returns the pixel-wise mean value of the predictions. 13 | 14 | Args: 15 | mode (str): one of ``'min'``, ``'max'`` or ``'mean'``. Default: ``'mean'`` 16 | num_aug (int): number of data augmentation variants: 0, 4 or 16. Default: 4 17 | 18 | Examples:: 19 | >>> from connectomics.data.augmentation import TestAugmentor 20 | >>> test_augmentor = TestAugmentor(mode='mean', num_aug=16) 21 | >>> output = test_augmentor(model, inputs) # output is a numpy.ndarray on CPU 22 | """ 23 | def __init__(self, mode='mean', num_aug=4): 24 | self.mode = mode 25 | self.num_aug = num_aug 26 | assert num_aug in [0, 4, 16], "TestAugmentor.num_aug should be either 0, 4 or 16!" 
27 | 28 | def __call__(self, model, data): 29 | out = None 30 | cc = 0 31 | if self.num_aug == 0: 32 | opts = itertools.product((False, ), (False, ), (False, ), (False, )) 33 | elif self.num_aug == 4: 34 | opts = itertools.product((False, ), (False, ), (False, True), (False, True)) 35 | else: 36 | opts = itertools.product((False, True), (False, True), (False, True), (False, True)) 37 | 38 | for xflip, yflip, zflip, transpose in opts: 39 | volume = data.clone() 40 | # b,c,z,y,x 41 | 42 | if xflip: 43 | volume = torch.flip(volume, [4]) 44 | if yflip: 45 | volume = torch.flip(volume, [3]) 46 | if zflip: 47 | volume = torch.flip(volume, [2]) 48 | if transpose: 49 | volume = torch.transpose(volume, 3, 4) 50 | # aff: 3*z*y*x 51 | vout = model(volume).detach().cpu() 52 | 53 | if transpose: # swap x-/y-affinity 54 | vout = torch.transpose(vout, 3, 4) 55 | if zflip: 56 | vout = torch.flip(vout, [2]) 57 | if yflip: 58 | vout = torch.flip(vout, [3]) 59 | if xflip: 60 | vout = torch.flip(vout, [4]) 61 | 62 | # cast to numpy array 63 | vout = vout.numpy() 64 | if out is None: 65 | if self.mode == 'min': 66 | out = np.ones(vout.shape, dtype=np.float32) 67 | elif self.mode == 'max': 68 | out = np.zeros(vout.shape, dtype=np.float32) 69 | elif self.mode == 'mean': 70 | out = np.zeros(vout.shape, dtype=np.float32) 71 | 72 | if self.mode == 'min': 73 | out = np.minimum(out, vout) 74 | elif self.mode == 'max': 75 | out = np.maximum(out, vout) 76 | elif self.mode == 'mean': 77 | out += vout 78 | cc+=1 79 | 80 | if self.mode == 'mean': 81 | out = out/cc 82 | 83 | return out 84 | 85 | def update_name(self, name): 86 | extension = "_" 87 | if self.num_aug == 4: 88 | extension += "tz" 89 | elif self.num_aug == 16: 90 | extension += "tzyx" 91 | else: 92 | return name 93 | 94 | # Update the suffix of the output filename to indicate 95 | # the use of test-time data augmentation. 96 | name_list = name.split('.') 97 | new_filename = name_list[0] + extension + "." + name_list[1] 98 | return new_filename 99 | -------------------------------------------------------------------------------- /augmentation/rescale.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import cv2 3 | import numpy as np 4 | from .augmentor import DataAugment 5 | from skimage.transform import resize 6 | 7 | class Rescale(DataAugment): 8 | """ 9 | Rescale augmentation. 10 | 11 | Args: 12 | low (float): lower bound of the random scale factor. Default: 0.8 13 | high (float): higher bound of the random scale factor. Default: 1.2 14 | fix_aspect (bool): fix aspect ratio or not. Default: False 15 | p (float): probability of applying the augmentation. 
Default: 0.5 16 | """ 17 | def __init__(self, low=0.8, high=1.2, fix_aspect=False, p=0.5): 18 | super(Rescale, self).__init__(p=p) 19 | self.low = low 20 | self.high = high 21 | self.fix_aspect = fix_aspect 22 | 23 | self.image_interpolation = 1 24 | self.label_interpolation = 0 25 | self.set_params() 26 | 27 | def set_params(self): 28 | assert (self.low >= 0.5) 29 | assert (self.low <= 1.0) 30 | ratio = 1.0 / self.low 31 | self.sample_params['ratio'] = [1.0, ratio, ratio] 32 | 33 | def random_scale(self, random_state): 34 | rand_scale = random_state.rand() * (self.high - self.low) + self.low 35 | return rand_scale 36 | 37 | def apply_rescale(self, image, label, sf_x, sf_y, random_state): 38 | # apply image and mask at the same time 39 | transformed_image = image.copy() 40 | transformed_label = label.copy() 41 | 42 | y_length = int(sf_y * image.shape[1]) 43 | if y_length <= image.shape[1]: 44 | y0 = random_state.randint(low=0, high=image.shape[1]-y_length+1) 45 | y1 = y0 + y_length 46 | transformed_image = transformed_image[:, y0:y1, :] 47 | transformed_label = transformed_label[:, y0:y1, :] 48 | else: 49 | y0 = int(np.floor((y_length - image.shape[1]) / 2)) 50 | y1 = int(np.ceil((y_length - image.shape[1]) / 2)) 51 | transformed_image = np.pad(transformed_image, ((0, 0),(y0, y1),(0, 0)), mode='constant') 52 | transformed_label = np.pad(transformed_label, ((0, 0),(y0, y1),(0, 0)), mode='constant') 53 | 54 | x_length = int(sf_x * image.shape[2]) 55 | if x_length <= image.shape[2]: 56 | x0 = random_state.randint(low=0, high=image.shape[2]-x_length+1) 57 | x1 = x0 + x_length 58 | transformed_image = transformed_image[:, :, x0:x1] 59 | transformed_label = transformed_label[:, :, x0:x1] 60 | else: 61 | x0 = int(np.floor((x_length - image.shape[2]) / 2)) 62 | x1 = int(np.ceil((x_length - image.shape[2]) / 2)) 63 | transformed_image = np.pad(transformed_image, ((0, 0),(0, 0),(x0, x1)), mode='constant') 64 | transformed_label = np.pad(transformed_label, ((0, 0),(0, 0),(x0, x1)), mode='constant') 65 | 66 | output_image = resize(transformed_image, image.shape, order=self.image_interpolation, mode='constant', cval=0, 67 | clip=True, preserve_range=True, anti_aliasing=True) 68 | output_label = resize(transformed_label, image.shape, order=self.label_interpolation, mode='constant', cval=0, 69 | clip=True, preserve_range=True, anti_aliasing=False) 70 | return output_image, output_label 71 | 72 | def __call__(self, data, random_state=np.random): 73 | 74 | if 'label' in data and data['label'] is not None: 75 | image, label = data['image'], data['label'] 76 | else: 77 | image, label = data['image'], None 78 | 79 | if self.fix_aspect: 80 | sf_x = self.random_scale(random_state) 81 | sf_y = sf_x 82 | else: 83 | sf_x = self.random_scale(random_state) 84 | sf_y = self.random_scale(random_state) 85 | 86 | output = {} 87 | output['image'], output['label'] = self.apply_rescale(image, label, sf_x, sf_y, random_state) 88 | 89 | return output 90 | -------------------------------------------------------------------------------- /mamba_local/mamba_ssm_local/causal_conv1d_local/causal_conv1d/causal_conv1d_interface.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Tri Dao. 
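The bindings below expose a left-padded depthwise convolution: the output at time t may only depend on inputs up to t. A self-contained sketch of that contract, mirroring causal_conv1d_ref further down (the helper name causal_depthwise_conv is illustrative):

import torch
import torch.nn.functional as F

def causal_depthwise_conv(x, weight, bias=None):
    # Depthwise conv with left padding of (width - 1); trimming the tail keeps
    # output length equal to the input length and preserves causality.
    dim, width = weight.shape
    out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
    return out[..., : x.shape[-1]]

x = torch.randn(2, 4, 16)      # (batch, dim, seqlen)
w = torch.randn(4, 3)          # (dim, width)
y = causal_depthwise_conv(x, w)
x_future = x.clone()
x_future[..., 10:] = 0.0       # perturb only the "future" part of the sequence
y_future = causal_depthwise_conv(x_future, w)
assert torch.allclose(y[..., :10], y_future[..., :10])  # past outputs unchanged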
2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | import causal_conv1d_cuda 8 | 9 | 10 | class CausalConv1dFn(torch.autograd.Function): 11 | @staticmethod 12 | def forward(ctx, x, weight, bias=None, activation=None): 13 | if activation not in [None, "silu", "swish"]: 14 | raise NotImplementedError("activation must be None, silu, or swish") 15 | if x.stride(2) != 1 and x.stride(1) != 1: 16 | x = x.contiguous() 17 | bias = bias.contiguous() if bias is not None else None 18 | ctx.save_for_backward(x, weight, bias) 19 | ctx.activation = activation in ["silu", "swish"] 20 | out = causal_conv1d_cuda.causal_conv1d_fwd(x, weight, bias, ctx.activation) 21 | return out 22 | 23 | @staticmethod 24 | def backward(ctx, dout): 25 | x, weight, bias = ctx.saved_tensors 26 | if dout.stride(2) != 1 and dout.stride(1) != 1: 27 | dout = dout.contiguous() 28 | # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the 29 | # backward of conv1d with the backward of chunk). 30 | # Here we just pass in None and dx will be allocated in the C++ code. 31 | dx, dweight, dbias = causal_conv1d_cuda.causal_conv1d_bwd( 32 | x, weight, bias, dout, None, ctx.activation 33 | ) 34 | return dx, dweight, dbias if bias is not None else None, None 35 | 36 | 37 | def causal_conv1d_fn(x, weight, bias=None, activation=None): 38 | """ 39 | x: (batch, dim, seqlen) 40 | weight: (dim, width) 41 | bias: (dim,) 42 | activation: either None or "silu" or "swish" 43 | 44 | out: (batch, dim, seqlen) 45 | """ 46 | return CausalConv1dFn.apply(x, weight, bias, activation) 47 | 48 | 49 | def causal_conv1d_ref(x, weight, bias=None, activation=None): 50 | """ 51 | x: (batch, dim, seqlen) 52 | weight: (dim, width) 53 | bias: (dim,) 54 | 55 | out: (batch, dim, seqlen) 56 | """ 57 | if activation not in [None, "silu", "swish"]: 58 | raise NotImplementedError("activation must be None, silu, or swish") 59 | dtype_in = x.dtype 60 | x = x.to(weight.dtype) 61 | seqlen = x.shape[-1] 62 | dim, width = weight.shape 63 | out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) 64 | out = out[..., :seqlen] 65 | return (out if activation is None else F.silu(out)).to(dtype=dtype_in) 66 | 67 | 68 | def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None): 69 | """ 70 | x: (batch, dim) 71 | conv_state: (batch, dim, width) 72 | weight: (dim, width) 73 | bias: (dim,) 74 | 75 | out: (batch, dim) 76 | """ 77 | if activation not in [None, "silu", "swish"]: 78 | raise NotImplementedError("activation must be None, silu, or swish") 79 | activation = activation in ["silu", "swish"] 80 | return causal_conv1d_cuda.causal_conv1d_update(x, conv_state, weight, bias, activation) 81 | 82 | 83 | def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None): 84 | """ 85 | x: (batch, dim) 86 | conv_state: (batch, dim, width) 87 | weight: (dim, width) 88 | bias: (dim,) 89 | 90 | out: (batch, dim) 91 | """ 92 | if activation not in [None, "silu", "swish"]: 93 | raise NotImplementedError("activation must be None, silu, or swish") 94 | dtype_in = x.dtype 95 | batch, dim = x.shape 96 | width = weight.shape[1] 97 | assert conv_state.shape == (batch, dim, width) 98 | assert weight.shape == (dim, width) 99 | conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) 100 | conv_state[:, :, -1] = x 101 | out = torch.sum(conv_state * weight, dim=-1) # (B D) 102 | if bias is not None: 103 | out += bias 104 | return (out if activation is None else 
F.silu(out)).to(dtype=dtype_in) 105 | -------------------------------------------------------------------------------- /augmentation/__init__.py: -------------------------------------------------------------------------------- 1 | from .composition import Compose 2 | from .augmentor import DataAugment 3 | from .test_augmentor import TestAugmentor 4 | 5 | # augmentation methods 6 | from .warp import Elastic 7 | from .grayscale import Grayscale 8 | from .flip import Flip 9 | from .rotation import Rotate 10 | from .rescale import Rescale 11 | from .misalign import MisAlignment 12 | from .missing_section import MissingSection 13 | from .missing_parts import MissingParts 14 | from .motion_blur import MotionBlur 15 | from .cutblur import CutBlur 16 | from .cutnoise import CutNoise 17 | from .mixup import MixupAugmentor 18 | 19 | __all__ = ['Compose', 20 | 'DataAugment', 21 | 'Elastic', 22 | 'Grayscale', 23 | 'Rotate', 24 | 'Rescale', 25 | 'MisAlignment', 26 | 'MissingSection', 27 | 'MissingParts', 28 | 'Flip', 29 | 'MotionBlur', 30 | 'CutBlur', 31 | 'CutNoise', 32 | 'MixupAugmentor', 33 | 'TestAugmentor'] 34 | 35 | 36 | def build_train_augmentor(cfg, keep_uncropped=False, keep_non_smoothed=False): 37 | # The two arguments, keep_uncropped and keep_non_smoothed, are used only 38 | # for debugging; they are False by default and cannot be adjusted 39 | # in the config files. 40 | aug_list = [] 41 | #1. rotate 42 | if cfg.AUGMENTOR.ROTATE.ENABLED: 43 | aug_list.append(Rotate(p=cfg.AUGMENTOR.ROTATE.P)) 44 | 45 | #2. rescale 46 | if cfg.AUGMENTOR.RESCALE.ENABLED: 47 | aug_list.append(Rescale(p=cfg.AUGMENTOR.RESCALE.P)) 48 | 49 | #3. flip 50 | if cfg.AUGMENTOR.FLIP.ENABLED: 51 | aug_list.append(Flip(p=cfg.AUGMENTOR.FLIP.P, 52 | do_ztrans=cfg.AUGMENTOR.FLIP.DO_ZTRANS)) 53 | 54 | #4. elastic 55 | if cfg.AUGMENTOR.ELASTIC.ENABLED: 56 | aug_list.append(Elastic(alpha=cfg.AUGMENTOR.ELASTIC.ALPHA, 57 | sigma = cfg.AUGMENTOR.ELASTIC.SIGMA, 58 | p=cfg.AUGMENTOR.ELASTIC.P)) 59 | 60 | #5. grayscale 61 | if cfg.AUGMENTOR.GRAYSCALE.ENABLED: 62 | aug_list.append(Grayscale(p=cfg.AUGMENTOR.GRAYSCALE.P)) 63 | 64 | #6. missingparts 65 | if cfg.AUGMENTOR.MISSINGPARTS.ENABLED: 66 | aug_list.append(MissingParts(p=cfg.AUGMENTOR.MISSINGPARTS.P)) 67 | 68 | #7. missingsection 69 | if cfg.AUGMENTOR.MISSINGSECTION.ENABLED and not cfg.DATASET.DO_2D: 70 | aug_list.append(MissingSection(p=cfg.AUGMENTOR.MISSINGSECTION.P, 71 | num_sections=cfg.AUGMENTOR.MISSINGSECTION.NUM_SECTION)) 72 | 73 | #8. misalignment 74 | if cfg.AUGMENTOR.MISALIGNMENT.ENABLED and not cfg.DATASET.DO_2D: 75 | aug_list.append(MisAlignment(p=cfg.AUGMENTOR.MISALIGNMENT.P, 76 | displacement=cfg.AUGMENTOR.MISALIGNMENT.DISPLACEMENT, 77 | rotate_ratio=cfg.AUGMENTOR.MISALIGNMENT.ROTATE_RATIO)) 78 | #9. motion-blur 79 | if cfg.AUGMENTOR.MOTIONBLUR.ENABLED: 80 | aug_list.append(MotionBlur(p=cfg.AUGMENTOR.MOTIONBLUR.P, 81 | sections=cfg.AUGMENTOR.MOTIONBLUR.SECTIONS, 82 | kernel_size=cfg.AUGMENTOR.MOTIONBLUR.KERNEL_SIZE)) 83 | 84 | #10. cut-blur 85 | if cfg.AUGMENTOR.CUTBLUR.ENABLED: 86 | aug_list.append(CutBlur(p=cfg.AUGMENTOR.CUTBLUR.P, 87 | length_ratio=cfg.AUGMENTOR.CUTBLUR.LENGTH_RATIO, 88 | down_ratio_min=cfg.AUGMENTOR.CUTBLUR.DOWN_RATIO_MIN, 89 | down_ratio_max=cfg.AUGMENTOR.CUTBLUR.DOWN_RATIO_MAX, 90 | downsample_z=cfg.AUGMENTOR.CUTBLUR.DOWNSAMPLE_Z)) 91 | 92 | #11.
cut-noise 93 | if cfg.AUGMENTOR.CUTNOISE.ENABLED: 94 | aug_list.append(CutNoise(p=cfg.AUGMENTOR.CUTNOISE.P, 95 | length_ratio=cfg.AUGMENTOR.CUTNOISE.LENGTH_RATIO, 96 | scale=cfg.AUGMENTOR.CUTNOISE.SCALE)) 97 | 98 | augmentor = Compose(aug_list, input_size=cfg.MODEL.INPUT_SIZE, smooth=cfg.AUGMENTOR.SMOOTH, 99 | keep_uncropped=keep_uncropped, keep_non_smoothed=keep_non_smoothed) 100 | 101 | return augmentor 102 | -------------------------------------------------------------------------------- /data/total_data_provider.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.utils.data as data 5 | from glob import glob 6 | import os 7 | import h5py 8 | 9 | class TotalDataProvider(data.Dataset): 10 | def __init__(self, data_path, partition_number=0, total_partitions=6): 11 | self.data_path = data_path 12 | self.partition_number = partition_number 13 | self.total_partitions = total_partitions 14 | self.data_list = sorted(glob(os.path.join(data_path, '*h5'))) 15 | 16 | # Calculate partition size 17 | total_size = len(self.data_list) 18 | self.partition_size = total_size // self.total_partitions 19 | self.start_index = self.partition_size * partition_number 20 | self.end_index = (self.start_index + self.partition_size 21 | if partition_number < self.total_partitions - 1 22 | else total_size) 23 | 24 | def __len__(self): 25 | return self.end_index - self.start_index 26 | 27 | def scaler01(self, x): 28 | return (x - x.min()) / (x.max() - x.min() + 1e-8) 29 | 30 | def __getitem__(self, index): 31 | # Adjust index for the partition 32 | adjusted_index = self.start_index + index 33 | data_path = self.data_list[adjusted_index] 34 | with h5py.File(data_path, 'r') as f: 35 | data = f['main'][:] 36 | data = self.scaler01(data) 37 | data = torch.from_numpy(data).float() 38 | data = data.unsqueeze(0) 39 | return data, data_path 40 | 41 | if __name__ == '__main__': 42 | import sys 43 | sys.path.append('/data/ydchen/VLP/wafer4') 44 | from model_superhuman2 import UNet_PNI 45 | from monai.inferers import sliding_window_inference 46 | from omegaconf import OmegaConf 47 | import random 48 | from matplotlib import pyplot as plt 49 | from collections import OrderedDict 50 | 51 | cfg = OmegaConf.load('/data/ydchen/VLP/wafer4/config/seg_3d_ac4_data80.yaml') 52 | device = torch.device('cuda:0') 53 | model = UNet_PNI(in_planes=cfg.MODEL.input_nc, 54 | out_planes=cfg.MODEL.output_nc, 55 | filters=cfg.MODEL.filters, 56 | upsample_mode=cfg.MODEL.upsample_mode, 57 | decode_ratio=cfg.MODEL.decode_ratio, 58 | merge_mode=cfg.MODEL.merge_mode, 59 | pad_mode=cfg.MODEL.pad_mode, 60 | bn_mode=cfg.MODEL.bn_mode, 61 | relu_mode=cfg.MODEL.relu_mode, 62 | init_mode=cfg.MODEL.init_mode) 63 | ckpt_path = os.path.join('/LSEM/wafer_seg/model_trained_superhuman/model/model-132000.ckpt') 64 | checkpoint = torch.load(ckpt_path) 65 | 66 | new_state_dict = OrderedDict() 67 | state_dict = checkpoint['model_weights'] 68 | for k, v in state_dict.items(): 69 | if 'module' in k: 70 | name = k[7:] 71 | else: 72 | name = k 73 | new_state_dict[name] = v 74 | 75 | model.load_state_dict(new_state_dict) 76 | model = model.to(device) 77 | data_path = '/LSEM/user/chenyinda/total_mec_seg_final/affinity' 78 | save_dir = '/data/ydchen/VLP/wafer4/data/temp' 79 | os.makedirs(save_dir, exist_ok=True) 80 | dataset = TotalDataProvider(data_path) 81 | provider = data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0) 82 | 
model.eval() 83 | with torch.no_grad(): 84 | for i, batch in enumerate(provider): 85 | data, path = batch 86 | data = data.to(device) 87 | data = data.squeeze() 88 | if 'affinity' not in data_path: 89 | output = sliding_window_inference(data, (18, 160, 160), 8, model, overlap=0.25) 90 | else: 91 | output = data 92 | print(output.shape) 93 | total_len = output.shape[1] 94 | for slice_num in range(0, total_len, 10): 95 | 96 | # slice_num = random.randint(0, total_len-1) 97 | plt.subplot(1, 2, 1) 98 | plt.imshow(data[:, slice_num, :, :].cpu().numpy().transpose(1,2,0)) 99 | plt.subplot(1, 2, 2) 100 | plt.imshow(output[:, slice_num, :, :].cpu().numpy().transpose(1, 2, 0)) 101 | plt.savefig(os.path.join(save_dir, f'{i}_{slice_num}.png')) 102 | 103 | 104 | break -------------------------------------------------------------------------------- /src/eval_signle.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import waterz 3 | import os 4 | from utils.fragment import watershed, randomlabel, relabel 5 | from skimage.metrics import adapted_rand_error as adapted_rand_ref 6 | from skimage.metrics import variation_of_information as voi_ref 7 | from utils.lmc import mc_baseline 8 | from PIL import Image 9 | 10 | def draw_fragments_3d(pred): 11 | d,m,n = pred.shape 12 | ids = np.unique(pred) 13 | size = len(ids) 14 | print("the number of neurons in pred is %d" % size) 15 | color_pred = np.zeros([d, m, n, 3]) 16 | idx = np.searchsorted(ids, pred) 17 | for i in range(3): 18 | color_val = np.random.randint(0, 255, ids.shape) 19 | if ids[0] == 0: 20 | color_val[0] = 0 21 | color_pred[:,:,:,i] = color_val[idx] 22 | color_pred = color_pred 23 | return color_pred 24 | 25 | print(f'load data...') 26 | data_zip = np.load('/h3cstore_ns/hyshi/wafer4_errorbar/unetr_MAE/unetr_MAE_720.npz') 27 | pred_affs = data_zip['pred_affs'] 28 | gt_seg = data_zip['gt_seg'] 29 | gt_affs = data_zip['gt_affs'] 30 | 31 | # valid_data = np.load('/h3cstore_ns/hyshi/valid_data.npy') 32 | # valid_label = np.load('/h3cstore_ns/hyshi/valid_label.npy') 33 | 34 | # valid_data = np.load('/h3cstore_ns/hyshi/valid_data_wafer4.npy') 35 | # valid_label = np.load('/h3cstore_ns/hyshi/valid_label_wafer4.npy') 36 | 37 | # valid_data = np.load('/h3cstore_ns/hyshi/valid_data_ac3.npy') 38 | # valid_label = np.load('/h3cstore_ns/hyshi/valid_label_ac3.npy') 39 | 40 | # valid_data = np.load('/h3cstore_ns/hyshi/valid_data_wafer36_2.npy') 41 | # valid_label = np.load('/h3cstore_ns/hyshi/valid_label_wafer36_2.npy') 42 | 43 | print('Waterz Segmentation...') 44 | fragments = watershed(pred_affs, 'maxima_distance') 45 | sf = 'OneMinus>' 46 | # sf = 'OneMinus>>' 47 | seg_waterz = list(waterz.agglomerate(pred_affs, [0.50], 48 | fragments=fragments, 49 | scoring_function=sf, 50 | discretize_queue=256))[0] 51 | seg_waterz = relabel(seg_waterz).astype(np.uint64) 52 | arand_waterz = adapted_rand_ref(gt_seg, seg_waterz, ignore_labels=(0,))[0] 53 | voi_split_waterz, voi_merge_waterz = voi_ref(gt_seg, seg_waterz, ignore_labels=(0,)) 54 | voi_sum_waterz = voi_split_waterz + voi_merge_waterz 55 | 56 | print('LMC Segmentation...') 57 | seg_lmc = mc_baseline(pred_affs) 58 | arand_lmc = adapted_rand_ref(gt_seg, seg_lmc, ignore_labels=(0,))[0] 59 | voi_split_lmc, voi_merge_lmc = voi_ref(gt_seg, seg_lmc, ignore_labels=(0,)) 60 | voi_sum_lmc = voi_split_lmc + voi_merge_lmc
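# Metric sketch (hedged): voi_ref returns the split/merge conditional-entropy
# pair unpacked above, so their sum is the full Variation of Information, and
# adapted_rand_ref returns (error, precision, recall); lower is better for both.
# Toy check on identical segmentations:
#
#     toy = np.array([[[1, 1, 2, 2]]], dtype=np.uint64)
#     assert adapted_rand_ref(toy, toy)[0] == 0.0
#     assert sum(voi_ref(toy, toy)) == 0.0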
61 | 62 | # print('Write the results...') 63 | # outfile.write(f'{file_name}: \n') 64 | # outfile.write('VOIm-waterz=%.6f, VOIs-waterz=%.6f, VOI-waterz=%.6f, ARAND-waterz=%.6f, VOIm-lmc=%.6f, VOIs-lmc=%.6f, VOI-lmc=%.6f, ARAND-lmc=%.6f\n' % \ 65 | # (voi_merge_waterz, voi_split_waterz, voi_sum_waterz, arand_waterz, voi_merge_lmc, voi_split_lmc, voi_sum_lmc, arand_lmc)) 66 | 67 | # print('Visualize...') 68 | # waterz_seg_color = draw_fragments_3d(seg_waterz) 69 | # label_color = draw_fragments_3d(valid_label) 70 | # zero_positions = (label_color == 0) 71 | # waterz_seg_color[zero_positions] = 0 72 | # label_img = Image.fromarray(label_color[0].astype(np.uint8)) 73 | # waterz_img = Image.fromarray(waterz_seg_color[0].astype(np.uint8)) 74 | # h, w = label_img.size 75 | # white_line = Image.new('RGB', (w, 16), (255, 255, 255)) 76 | # pred_aff_img = Image.fromarray((pred_affs[:,0]*255).astype(np.uint8).transpose(1,2,0)) 77 | # data_img = Image.fromarray((valid_data[0]*255).astype(np.uint8)) 78 | # affs_img = Image.fromarray((gt_affs[:,0]*255).astype(np.uint8).transpose(1,2,0)) 79 | # visual_img = Image.new('RGB', (w*5 + 16*4, h), (255, 255, 255)) 80 | # visual_img.paste(label_img, (0, 0)) 81 | # visual_img.paste(waterz_img, (w+16, 0)) 82 | # visual_img.paste(white_line, (w*2+16, 0)) 83 | # visual_img.paste(pred_aff_img, (w*2+16*2, 0)) 84 | # visual_img.paste(affs_img, (w*3+16*3, 0)) 85 | # visual_img.paste(white_line, (0, h)) 86 | # visual_img.paste(data_img, (w*4+16*4, 0)) 87 | # # visual_img.save('/h3cstore_ns/hyshi/InferenceWafer36_2_monai_visual/0520' + '/PEA_random_25000.png') 88 | # visual_img.save('/h3cstore_ns/hyshi/Visual_wafer4_result' + '/mamba3_ar11_1150test.png') 89 | 90 | print('VOIm-waterz=%.6f, VOIs-waterz=%.6f, VOI-waterz=%.6f, ARAND-waterz=%.6f, VOIm-lmc=%.6f, VOIs-lmc=%.6f, VOI-lmc=%.6f, ARAND-lmc=%.6f' % \ 91 | (voi_merge_waterz, voi_split_waterz, voi_sum_waterz, arand_waterz, voi_merge_lmc, voi_split_lmc, voi_sum_lmc, arand_lmc), flush=True) -------------------------------------------------------------------------------- /src/eval_single_25.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import waterz 3 | import os 4 | from utils.fragment import watershed, randomlabel, relabel 5 | from skimage.metrics import adapted_rand_error as adapted_rand_ref 6 | from skimage.metrics import variation_of_information as voi_ref 7 | from utils.lmc import mc_baseline 8 | from PIL import Image 9 | 10 | def draw_fragments_3d(pred): 11 | d,m,n = pred.shape 12 | ids = np.unique(pred) 13 | size = len(ids) 14 | print("the number of neurons in pred is %d" % size) 15 | color_pred = np.zeros([d, m, n, 3]) 16 | idx = np.searchsorted(ids, pred) 17 | for i in range(3): 18 | color_val = np.random.randint(0, 255, ids.shape) 19 | if ids[0] == 0: 20 | color_val[0] = 0 21 | color_pred[:,:,:,i] = color_val[idx] 22 | color_pred = color_pred 23 | return color_pred 24 | 25 | print(f'load data...') 26 | data_zip = np.load('/h3cstore_ns/hyshi/wafer4_errorbar/unetr_MAE/unetr_MAE_740.npz') 27 | pred_affs = data_zip['pred_affs'] 28 | gt_seg = data_zip['gt_seg'] 29 | gt_affs = data_zip['gt_affs'] 30 | 31 | # valid_data = np.load('/h3cstore_ns/hyshi/valid_data.npy') 32 | # valid_label = np.load('/h3cstore_ns/hyshi/valid_label.npy') 33 | 34 | # valid_data = np.load('/h3cstore_ns/hyshi/valid_data_wafer4.npy') 35 | # valid_label = np.load('/h3cstore_ns/hyshi/valid_label_wafer4.npy') 36 | 37 | # valid_data = np.load('/h3cstore_ns/hyshi/valid_data_ac3.npy') 38 | # valid_label = np.load('/h3cstore_ns/hyshi/valid_label_ac3.npy') 39 | 40 | # valid_data =
np.load('/h3cstore_ns/hyshi/valid_data_wafer36_2.npy') 41 | # valid_label = np.load('/h3cstore_ns/hyshi/valid_label_wafer36_2.npy') 42 | 43 | print('Waterz Segmentation...') 44 | fragments = watershed(pred_affs, 'maxima_distance') 45 | sf = 'OneMinus>' 46 | # sf = 'OneMinus>>' 47 | seg_waterz = list(waterz.agglomerate(pred_affs, [0.50], 48 | fragments=fragments, 49 | scoring_function=sf, 50 | discretize_queue=256))[0] 51 | seg_waterz = relabel(seg_waterz).astype(np.uint64) 52 | arand_waterz = adapted_rand_ref(gt_seg, seg_waterz, ignore_labels=(0,))[0] 53 | voi_split_waterz, voi_merge_waterz = voi_ref(gt_seg, seg_waterz, ignore_labels=(0,)) 54 | voi_sum_waterz = voi_split_waterz + voi_merge_waterz 55 | 56 | print('LMC Segmentation...') 57 | seg_lmc = mc_baseline(pred_affs) 58 | arand_lmc = adapted_rand_ref(gt_seg, seg_lmc, ignore_labels=(0,))[0] 59 | voi_split_lmc, voi_merge_lmc = voi_ref(gt_seg, seg_lmc, ignore_labels=(0,)) 60 | voi_sum_lmc = voi_split_lmc + voi_merge_lmc 61 | 62 | # print('Write the results...') 63 | # outfile.write(f'{file_name}: \n') 64 | # outfile.write('VOIm-waterz=%.6f, VOIs-waterz=%.6f, VOI-waterz=%.6f, ARAND-waterz=%.6f, VOIm-lmc=%.6f, VOIs-lmc=%.6f, VOI-lmc=%.6f, ARAND-lmc=%.6f\n' % \ 65 | # (voi_merge_waterz, voi_split_waterz, voi_sum_waterz, arand_waterz, voi_merge_lmc, voi_split_lmc, voi_sum_lmc, arand_lmc)) 66 | 67 | # print('Visualize...') 68 | # waterz_seg_color = draw_fragments_3d(seg_waterz) 69 | # label_color = draw_fragments_3d(valid_label) 70 | # zero_positions = (label_color == 0) 71 | # waterz_seg_color[zero_positions] = 0 72 | # label_img = Image.fromarray(label_color[0].astype(np.uint8)) 73 | # waterz_img = Image.fromarray(waterz_seg_color[0].astype(np.uint8)) 74 | # h, w = label_img.size 75 | # white_line = Image.new('RGB', (w, 16), (255, 255, 255)) 76 | # pred_aff_img = Image.fromarray((pred_affs[:,0]*255).astype(np.uint8).transpose(1,2,0)) 77 | # data_img = Image.fromarray((valid_data[0]*255).astype(np.uint8)) 78 | # affs_img = Image.fromarray((gt_affs[:,0]*255).astype(np.uint8).transpose(1,2,0)) 79 | # visual_img = Image.new('RGB', (w*5 + 16*4, h), (255, 255, 255)) 80 | # visual_img.paste(label_img, (0, 0)) 81 | # visual_img.paste(waterz_img, (w+16, 0)) 82 | # visual_img.paste(white_line, (w*2+16, 0)) 83 | # visual_img.paste(pred_aff_img, (w*2+16*2, 0)) 84 | # visual_img.paste(affs_img, (w*3+16*3, 0)) 85 | # visual_img.paste(white_line, (0, h)) 86 | # visual_img.paste(data_img, (w*4+16*4, 0)) 87 | # # visual_img.save('/h3cstore_ns/hyshi/InferenceWafer36_2_monai_visual/0520' + '/PEA_random_25000.png') 88 | # visual_img.save('/h3cstore_ns/hyshi/Visual_wafer4_result' + '/mamba3_ar11_1150test.png') 89 | 90 | print('VOIm-waterz=%.6f, VOIs-waterz=%.6f, VOI-waterz=%.6f, ARAND-waterz=%.6f, VOIm-lmc=%.6f, VOIs-lmc=%.6f, VOI-lmc=%.6f, ARAND-lmc=%.6f' % \ 91 | (voi_merge_waterz, voi_split_waterz, voi_sum_waterz, arand_waterz, voi_merge_lmc, voi_split_lmc, voi_sum_lmc, arand_lmc), flush=True) -------------------------------------------------------------------------------- /augmentation/misalign.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import math 3 | import numpy as np 4 | from .augmentor import DataAugment 5 | 6 | class MisAlignment(DataAugment): 7 | """Mis-alignment data augmentation of image stacks. 8 | 9 | Args: 10 | displacement (int): maximum pixel displacement in `xy`-plane. Default: 16 11 | p (float): probability of applying the augmentation.
Default: 0.5 12 | """ 13 | def __init__(self, 14 | displacement=16, 15 | rotate_ratio=0.0, 16 | p=0.5): 17 | super(MisAlignment, self).__init__(p=p) 18 | self.displacement = displacement 19 | self.rotate_ratio = rotate_ratio 20 | self.set_params() 21 | 22 | def set_params(self): 23 | self.sample_params['add'] = [0, 24 | int(math.ceil(self.displacement / 2.0)), 25 | int(math.ceil(self.displacement / 2.0))] 26 | 27 | def misalignment(self, data, random_state): 28 | images = data['image'].copy() 29 | labels = data['label'].copy() 30 | 31 | out_shape = (images.shape[0], 32 | images.shape[1]-self.displacement, 33 | images.shape[2]-self.displacement) 34 | new_images = np.zeros(out_shape, images.dtype) 35 | new_labels = np.zeros(out_shape, labels.dtype) 36 | 37 | x0 = random_state.randint(self.displacement) 38 | y0 = random_state.randint(self.displacement) 39 | x1 = random_state.randint(self.displacement) 40 | y1 = random_state.randint(self.displacement) 41 | idx = random_state.choice(np.array(range(1, out_shape[0]-1)), 1)[0] 42 | 43 | if random_state.rand() < 0.5: 44 | # slip misalignment 45 | new_images = images[:, y0:y0+out_shape[1], x0:x0+out_shape[2]] 46 | new_labels = labels[:, y0:y0+out_shape[1], x0:x0+out_shape[2]] 47 | new_images[idx] = images[idx, y1:y1+out_shape[1], x1:x1+out_shape[2]] 48 | new_labels[idx] = labels[idx, y1:y1+out_shape[1], x1:x1+out_shape[2]] 49 | else: 50 | # translation misalignment 51 | new_images[:idx] = images[:idx, y0:y0+out_shape[1], x0:x0+out_shape[2]] 52 | new_labels[:idx] = labels[:idx, y0:y0+out_shape[1], x0:x0+out_shape[2]] 53 | new_images[idx:] = images[idx:, y1:y1+out_shape[1], x1:x1+out_shape[2]] 54 | new_labels[idx:] = labels[idx:, y1:y1+out_shape[1], x1:x1+out_shape[2]] 55 | 56 | return new_images, new_labels 57 | 58 | def misalignment_rotate(self, data, random_state): 59 | images = data['image'].copy() 60 | labels = data['label'].copy() 61 | 62 | height, width = images.shape[-2:] 63 | assert height == width 64 | M = self.random_rotate_matrix(height, random_state) 65 | idx = random_state.choice(np.array(range(1, images.shape[0]-1)), 1)[0] 66 | 67 | if random_state.rand() < 0.5: 68 | # slip misalignment 69 | images[idx] = cv2.warpAffine(images[idx], M, (height,width), 1.0, 70 | flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT) 71 | labels[idx] = cv2.warpAffine(labels[idx], M, (height,width), 1.0, 72 | flags=cv2.INTER_NEAREST, borderMode=cv2.BORDER_CONSTANT) 73 | else: 74 | # translation misalignment 75 | for i in range(idx, images.shape[0]): 76 | images[i] = cv2.warpAffine(images[i], M, (height,width), 1.0, 77 | flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT) 78 | labels[i] = cv2.warpAffine(labels[i], M, (height,width), 1.0, 79 | flags=cv2.INTER_NEAREST, borderMode=cv2.BORDER_CONSTANT) 80 | 81 | new_images = images.copy() 82 | new_labels = labels.copy() 83 | 84 | return new_images, new_labels 85 | 86 | def random_rotate_matrix(self, height, random_state): 87 | x = (self.displacement / 2.0) 88 | y = ((height - self.displacement) / 2.0) * 1.42 89 | angle = math.asin(x/y) * 2.0 * 57.2958 # convert radians to degrees 90 | rand_angle = (random_state.rand() - 0.5) * 2.0 * angle 91 | M = cv2.getRotationMatrix2D((height/2, height/2), rand_angle, 1) 92 | return M 93 | 94 | def __call__(self, data, random_state=np.random): 95 | if random_state.rand() < self.rotate_ratio: 96 | new_images, new_labels = self.misalignment_rotate(data, random_state) 97 | else: 98 | new_images, new_labels = self.misalignment(data, random_state) 99 | return {'image': 
new_images, 'label': new_labels} 100 | -------------------------------------------------------------------------------- /util_mamba/MultiScaleAttention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class AdaptiveScale(nn.Module): 6 | """Adaptive scale module.""" 7 | def __init__(self, embed_dim): 8 | super().__init__() 9 | self.fc = nn.Linear(embed_dim, 1) 10 | 11 | def forward(self, x): 12 | scale = torch.sigmoid(self.fc(x.mean(dim=1))) # BxNxC -> Bx1 13 | return scale 14 | 15 | class MultiScaleAttention(nn.Module): 16 | """ Multi-scale Attention Module with Frequency Emphasis and Adaptive Scaling """ 17 | def __init__(self, embed_dim, num_heads=12, max_scale=3, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): 18 | super().__init__() 19 | self.max_scale = max_scale 20 | self.num_heads = num_heads 21 | self.embed_dim = embed_dim 22 | self.head_dim = embed_dim // num_heads 23 | self.scale = qk_scale or self.head_dim ** -0.5 24 | 25 | self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=qkv_bias) 26 | self.attn_drop = nn.Dropout(attn_drop) 27 | self.proj = nn.Linear(embed_dim, embed_dim) 28 | self.proj_drop = nn.Dropout(proj_drop) 29 | 30 | self.norm = nn.LayerNorm(embed_dim) 31 | self.adaptive_scale = AdaptiveScale(embed_dim) 32 | 33 | def forward(self, x): 34 | B, N, C = x.shape 35 | shortcut = x 36 | # adaptive_scale is a (B, 1) tensor; its batch mean picks the number of scales 37 | adaptive_scale = self.adaptive_scale(x) * self.max_scale 38 | num_scales = max(1, int(adaptive_scale.mean().item())) 39 | 40 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 41 | q, k, v = qkv[0], qkv[1], qkv[2] 42 | 43 | # multi-scale attention: q/k are subsampled along the head dimension per scale, 44 | # while attending over the full v keeps every per-scale output at (B, N, C) 45 | outs = [] 46 | for i in range(num_scales): 47 | scale = 2 ** i 48 | q_s = q[:, :, :, ::scale] 49 | k_s = k[:, :, :, ::scale] 50 | 51 | attn = (q_s @ k_s.transpose(-2, -1)) * self.scale 52 | attn = attn.softmax(dim=-1) 53 | attn = self.attn_drop(attn) 54 | 55 | outs.append((attn @ v).transpose(1, 2).reshape(B, N, C)) 56 | 57 | out = torch.stack(outs, dim=0).mean(dim=0) # fuse the scales 58 | 59 | # projection + residual from the block input (the former x = x + x only doubled x) 60 | out = self.proj(out) 61 | out = self.proj_drop(out) 62 | out = self.norm(shortcut + out) 63 | return out
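# Hedged shape-check for the block above (illustrative sizes only):
#
#     m = MultiScaleAttention(embed_dim=768, num_heads=12, max_scale=3)
#     y = m(torch.randn(2, 196, 768))
#     assert y.shape == (2, 196, 768)   # residual + LayerNorm keep (B, N, C)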
64 | 65 | class MultiScaleAttentionHighFre(nn.Module): 66 | """ Multi-scale Attention Module with Enhanced Focus on High-Frequency Information """ 67 | def __init__(self, embed_dim, num_heads=12, max_scale=3, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): 68 | super().__init__() 69 | 70 | self.max_scale = max_scale 71 | self.num_heads = num_heads 72 | self.embed_dim = embed_dim 73 | self.head_dim = embed_dim // num_heads 74 | self.scale = qk_scale or self.head_dim ** -0.5 75 | 76 | self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=qkv_bias) 77 | self.attn_drop = nn.Dropout(attn_drop) 78 | self.proj = nn.Linear(embed_dim, embed_dim) 79 | self.proj_drop = nn.Dropout(proj_drop) 80 | 81 | self.norm = nn.LayerNorm(embed_dim) 82 | self.adaptive_scale = AdaptiveScale(embed_dim) 83 | 84 | self.freq_amplification_factor = 2 # amplification factor for high frequencies 85 | self.freq_threshold = 0.5 # frequency-magnitude threshold 86 | 87 | def forward(self, x): 88 | B, N, C = x.shape 89 | shortcut = x 90 | adaptive_scale = self.adaptive_scale(x) * self.max_scale 91 | num_scales = max(1, int(adaptive_scale.mean().item())) 92 | 93 | # per-token gate from the FFT magnitude along the sequence, averaged over 94 | # channels so the (B, N, 1) weight broadcasts over the attention outputs 95 | freq = torch.fft.fft(x, dim=1) 96 | freq_amp = torch.abs(freq).mean(dim=-1, keepdim=True) 97 | freq_weights = torch.sigmoid((freq_amp - self.freq_threshold) * self.freq_amplification_factor) 98 | 99 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 100 | q, k, v = qkv[0], qkv[1], qkv[2] 101 | 102 | outs = [] 103 | for i in range(num_scales): 104 | scale = 2 ** i 105 | q_s = q[:, :, :, ::scale] 106 | k_s = k[:, :, :, ::scale] 107 | 108 | attn = (q_s @ k_s.transpose(-2, -1)) * self.scale 109 | attn = attn.softmax(dim=-1) 110 | attn = self.attn_drop(attn) 111 | 112 | outs.append((attn @ v).transpose(1, 2).reshape(B, N, C) * freq_weights) 113 | 114 | out = torch.stack(outs, dim=0).mean(dim=0) 115 | 116 | # projection + residual 117 | out = self.proj(out) 118 | out = self.proj_drop(out) 119 | out = self.norm(shortcut + out) 120 | return out 121 | -------------------------------------------------------------------------------- /utils/coordinate.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | 3 | class Coordinate(tuple): 4 | '''A ``tuple`` of integers. 5 | Allows the following element-wise operators: addition, subtraction, 6 | multiplication, division, absolute value, and negation. This allows simple 7 | arithmetic with coordinates, e.g.:: 8 | shape = Coordinate((2, 3, 4)) 9 | voxel_size = Coordinate((10, 5, 1)) 10 | size = shape*voxel_size # == Coordinate((20, 15, 4)) 11 | ''' 12 | def __new__(cls, array_like): 13 | return super(Coordinate, cls).__new__( 14 | cls, 15 | [ 16 | int(x) 17 | if x is not None 18 | else None 19 | for x in array_like]) 20 | 21 | def dims(self): 22 | return len(self) 23 | 24 | def __neg__(self): 25 | return Coordinate( 26 | -a 27 | if a is not None 28 | else None 29 | for a in self) 30 | 31 | def __abs__(self): 32 | return Coordinate( 33 | abs(a) 34 | if a is not None 35 | else None 36 | for a in self) 37 | 38 | def __add__(self, other): 39 | assert isinstance( 40 | other, tuple), "can only add Coordinate or tuples to Coordinate" 41 | assert self.dims() == len(other), "can only add Coordinate of equal dimensions" 42 | return Coordinate( 43 | a+b 44 | if a is not None and b is not None 45 | else None 46 | for a, b in zip(self, other)) 47 | 48 | def __sub__(self, other): 49 | assert isinstance( 50 | other, tuple), "can only subtract Coordinate or tuples to Coordinate" 51 | assert self.dims() == len(other), "can only subtract Coordinate of equal dimensions" 52 | return Coordinate( 53 | a-b 54 | if a is not None and b is not None 55 | else None 56 | for a, b in zip(self, other)) 57 | 58 | def __mul__(self, other): 59 | if isinstance(other, tuple): 60 | assert self.dims() == len(other), "can only multiply Coordinate of equal dimensions" 61 | return Coordinate( 62 | a*b 63 | if a is not None and b is not None 64 | else None 65 | for a, b in zip(self, other)) 66 | 67 | elif isinstance(other, numbers.Number): 68 | return Coordinate( 69 | a*other 70 | if a is not None 71 | else None 72 | for a in self) 73 | else: 74 | raise TypeError( 75 | "multiplication of Coordinate with type %s not supported" % type(other)) 76 | 77 | def __div__(self, other): 78 | if isinstance(other, tuple): 79 | assert self.dims() == len(other), "can only divide Coordinate of equal dimensions" 80 | return Coordinate( 81 | a/b 82 | if a is not None and b is not None 83 | else None 84 | for a, b in zip(self, other)) 85 | elif isinstance(other, numbers.Number): 86 | return Coordinate( 87 | a/other
88 | if a is not None 89 | else None 90 | for a in self) 91 | else: 92 | raise TypeError( 93 | "division of Coordinate with type %s not supported" % type(other)) 94 | 95 | def __truediv__(self, other): 96 | if isinstance(other, tuple): 97 | assert self.dims() == len(other), "can only divide Coordinate of equal dimensions" 98 | return Coordinate( 99 | a/b 100 | if a is not None and b is not None 101 | else None 102 | for a, b in zip(self, other)) 103 | elif isinstance(other, numbers.Number): 104 | return Coordinate( 105 | a/other 106 | if a is not None 107 | else None 108 | for a in self) 109 | else: 110 | raise TypeError( 111 | "division of Coordinate with type %s not supported" % type(other)) 112 | 113 | def __floordiv__(self, other): 114 | if isinstance(other, tuple): 115 | assert self.dims() == len(other), "can only divide Coordinate of equal dimensions" 116 | return Coordinate( 117 | a//b 118 | if a is not None and b is not None 119 | else None 120 | for a, b in zip(self, other)) 121 | elif isinstance(other, numbers.Number): 122 | return Coordinate( 123 | a//other 124 | if a is not None 125 | else None 126 | for a in self) 127 | else: 128 | raise TypeError( 129 | "division of Coordinate with type %s not supported" % type(other)) 130 | 131 | 132 | -------------------------------------------------------------------------------- /mamba_local/mamba_ssm_local/causal_conv1d_local/csrc/causal_conv1d_update.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | #include <c10/util/BFloat16.h> 6 | #include <c10/util/Half.h> 7 | #include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK 8 | 9 | #include <cub/block/block_load.cuh> 10 | #include <cub/block/block_store.cuh> 11 | 12 | #include "causal_conv1d.h" 13 | #include "causal_conv1d_common.h" 14 | #include "static_switch.h" 15 | 16 | template <int kNThreads_, int kWidth_, typename input_t_, typename weight_t_> 17 | struct Causal_conv1d_update_kernel_traits { 18 | using input_t = input_t_; 19 | using weight_t = weight_t_; 20 | static constexpr int kNThreads = kNThreads_; 21 | static constexpr int kWidth = kWidth_; 22 | static constexpr int kNBytes = sizeof(input_t); 23 | static_assert(kNBytes == 2 || kNBytes == 4); 24 | }; 25 | 26 | template <typename Ktraits> 27 | __global__ __launch_bounds__(Ktraits::kNThreads) 28 | void causal_conv1d_update_kernel(ConvParamsBase params) { 29 | constexpr int kWidth = Ktraits::kWidth; 30 | constexpr int kNThreads = Ktraits::kNThreads; 31 | using input_t = typename Ktraits::input_t; 32 | using weight_t = typename Ktraits::weight_t; 33 | 34 | const int tidx = threadIdx.x; 35 | const int batch_id = blockIdx.x; 36 | const int channel_id = blockIdx.y * kNThreads + tidx; 37 | input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride 38 | + channel_id * params.x_c_stride; 39 | input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr) + batch_id * params.conv_state_batch_stride 40 | + channel_id * params.conv_state_c_stride; 41 | weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride; 42 | input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride 43 | + channel_id * params.out_c_stride; 44 | float bias_val = params.bias_ptr == nullptr || channel_id >= params.dim ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]); 45 | 46 | float weight_vals[kWidth] = {0}; 47 | if (channel_id < params.dim) { 48 | #pragma unroll 49 | for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); } 50 | } 51 | 52 | float x_vals[kWidth] = {0}; 53 | if (channel_id < params.dim) { 54 | #pragma unroll 55 | for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = float(conv_state[(i + 1) * params.conv_state_l_stride]); } 56 | x_vals[kWidth - 1] = float(x[0]); 57 | #pragma unroll 58 | for (int i = 0; i < kWidth; ++i) { conv_state[i * params.conv_state_l_stride] = input_t(x_vals[i]); } 59 | } 60 | 61 | float out_val = bias_val; 62 | #pragma unroll 63 | for (int i = 0; i < kWidth; ++i) { out_val += weight_vals[i] * x_vals[i]; } 64 | if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); } 65 | if (channel_id < params.dim) { out[0] = input_t(out_val); } 66 | }
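// Worked example of the single-step update above (kWidth = 4): the kernel reads
// the last kWidth-1 entries of conv_state plus the new scalar x, so
// x_vals = [s1, s2, s3, x_new]; it then writes that shifted window back into
// conv_state and emits out = bias + sum_i weight[i] * x_vals[i], optionally
// passed through SiLU. conv_state therefore always holds the kWidth most
// recent inputs of the causal convolution, one timestep per launch.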
67 | 68 | template <int kNThreads, int kWidth, typename input_t, typename weight_t> 69 | void causal_conv1d_update_launch(ConvParamsBase &params, cudaStream_t stream) { 70 | using Ktraits = Causal_conv1d_update_kernel_traits<kNThreads, kWidth, input_t, weight_t>; 71 | dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads); 72 | auto kernel = &causal_conv1d_update_kernel<Ktraits>; 73 | kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params); 74 | C10_CUDA_KERNEL_LAUNCH_CHECK(); 75 | } 76 | 77 | template <typename input_t, typename weight_t> 78 | void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream) { 79 | if (params.width == 2) { 80 | causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream); 81 | } else if (params.width == 3) { 82 | causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream); 83 | } else if (params.width == 4) { 84 | causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream); 85 | } 86 | } 87 | 88 | template void causal_conv1d_update_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream); 89 | template void causal_conv1d_update_cuda<at::Half, float>(ConvParamsBase &params, cudaStream_t stream); 90 | template void causal_conv1d_update_cuda<at::BFloat16, float>(ConvParamsBase &params, cudaStream_t stream); 91 | template void causal_conv1d_update_cuda<float, at::Half>(ConvParamsBase &params, cudaStream_t stream); 92 | template void causal_conv1d_update_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream); 93 | template void causal_conv1d_update_cuda<at::BFloat16, at::Half>(ConvParamsBase &params, cudaStream_t stream); 94 | template void causal_conv1d_update_cuda<float, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream); 95 | template void causal_conv1d_update_cuda<at::Half, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream); 96 | template void causal_conv1d_update_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream); -------------------------------------------------------------------------------- /augmentation/missing_parts.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from .augmentor import DataAugment 3 | 4 | from scipy.ndimage.interpolation import map_coordinates, zoom 5 | import numbers 6 | from skimage.draw import line 7 | from scipy.ndimage.filters import gaussian_filter 8 | from scipy.ndimage.measurements import label 9 | from scipy.ndimage.morphology import binary_dilation 10 | 11 | class MissingParts(DataAugment): 12 | """Missing-parts augmentation of image stacks. 13 | 14 | Args: 15 | deformation_strength (int): Default: 0 16 | iterations (int): Default: 40 17 | deform_ratio (float): Default: 0.25 18 | p (float): probability of applying the augmentation.
Default: 0.5 19 | """ 20 | def __init__(self, 21 | deformation_strength=0, 22 | iterations=40, 23 | deform_ratio=0.25, 24 | p=0.5): 25 | super(MissingParts, self).__init__(p=p) 26 | self.deformation_strength = deformation_strength 27 | self.iterations = iterations 28 | self.deform_ratio = deform_ratio 29 | self.set_params() 30 | 31 | def set_params(self): 32 | # No change in sample size 33 | pass 34 | 35 | def prepare_deform_slice(self, slice_shape, random_state): 36 | # grow slice shape by 2 x deformation strength 37 | grow_by = 2 * self.deformation_strength 38 | #print ('sliceshape: '+str(slice_shape[0])+' growby: '+str(grow_by)+ ' strength: '+str(deformation_strength)) 39 | shape = (slice_shape[0] + grow_by, slice_shape[1] + grow_by) 40 | # randomly choose fixed x or fixed y with p = 1/2 41 | fixed_x = random_state.rand() < 0.5 42 | if fixed_x: 43 | x0, y0 = 0, random_state.randint(1, shape[1] - 2) 44 | x1, y1 = shape[0] - 1, random_state.randint(1, shape[1] - 2) 45 | else: 46 | x0, y0 = random_state.randint(1, shape[0] - 2), 0 47 | x1, y1 = random_state.randint(1, shape[0] - 2), shape[1] - 1 48 | 49 | ## generate the mask of the line that should be blacked out 50 | #print (shape) 51 | line_mask = np.zeros(shape, dtype='bool') 52 | rr, cc = line(x0, y0, x1, y1) 53 | line_mask[rr, cc] = 1 54 | 55 | # generate vectorfield pointing towards the line to compress the image 56 | # first we get the unit vector representing the line 57 | line_vector = np.array([x1 - x0, y1 - y0], dtype='float32') 58 | line_vector /= np.linalg.norm(line_vector) 59 | # next, we generate the normal to the line 60 | normal_vector = np.zeros_like(line_vector) 61 | normal_vector[0] = - line_vector[1] 62 | normal_vector[1] = line_vector[0] 63 | 64 | # make meshgrid 65 | x, y = np.meshgrid(np.arange(shape[1]), np.arange(shape[0])) 66 | # generate the vector field 67 | flow_x, flow_y = np.zeros(shape), np.zeros(shape) 68 | 69 | # find the 2 components where coordinates are bigger / smaller than the line 70 | # to apply normal vector in the correct direction 71 | components, n_components = label(np.logical_not(line_mask).view('uint8')) 72 | assert n_components == 2, "%i" % n_components 73 | neg_val = components[0, 0] if fixed_x else components[-1, -1] 74 | pos_val = components[-1, -1] if fixed_x else components[0, 0] 75 | 76 | flow_x[components == pos_val] = self.deformation_strength * normal_vector[1] 77 | flow_y[components == pos_val] = self.deformation_strength * normal_vector[0] 78 | flow_x[components == neg_val] = - self.deformation_strength * normal_vector[1] 79 | flow_y[components == neg_val] = - self.deformation_strength * normal_vector[0] 80 | 81 | # generate the flow fields 82 | flow_x, flow_y = (x + flow_x).reshape(-1, 1), (y + flow_y).reshape(-1, 1) 83 | 84 | # dilate the line mask 85 | line_mask = binary_dilation(line_mask, iterations=self.iterations) #default=10 86 | 87 | return flow_x, flow_y, line_mask 88 | 89 | def deform_2d(self, image2d, random_state): 90 | flow_x, flow_y, line_mask = self.prepare_deform_slice(image2d.shape, random_state) 91 | section = image2d.squeeze() 92 | mean = section.mean() 93 | shape = section.shape 94 | #interpolation=3 95 | section = map_coordinates(section, (flow_y, flow_x), mode='constant', 96 | order=3).reshape(int(flow_x.shape[0]**0.5),int(flow_x.shape[0]**0.5)) 97 | section = np.clip(section, 0., 1.) 98 | section[line_mask] = mean 99 | return section
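    # Illustrative use of the deformation above via the public interface (hedged
    # sketch; assumes square float slices normalized to [0, 1], shaped (z, y, x)):
    #
    #     aug = MissingParts(deformation_strength=0, iterations=40, deform_ratio=0.25, p=1.0)
    #     out = aug({'image': np.random.rand(8, 160, 160)})
    #     # roughly deform_ratio of the z-slices get a dilated random line
    #     # blanked to the slice mean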
100 | 101 | def apply_deform(self, imgs, random_state): 102 | transformedimgs = np.copy(imgs) 103 | sectionsnum = imgs.shape[0] 104 | i=0 105 | while i < sectionsnum: 106 | if random_state.rand() < self.deform_ratio: 107 | transformedimgs[i] = self.deform_2d(imgs[i], random_state) 108 | i += 2 # only one deformed image in any consecutive 3 images 109 | i += 1 110 | return transformedimgs 111 | 112 | def __call__(self, data, random_state=np.random): 113 | augmented = self.apply_deform(data['image'], random_state) 114 | data['image'] = augmented 115 | return data 116 | -------------------------------------------------------------------------------- /utils/affinity_ours.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def gen_affs(map1, map2=None, dir=0, shift=1, padding=True, background=False): 5 | if dir == 0 and map2 is None: 6 | raise AttributeError('map2 is none') 7 | map1 = map1.astype(np.float32) 8 | h, w = map1.shape 9 | if dir == 0: 10 | map2 = map2.astype(np.float32) 11 | elif dir == 1: 12 | map2 = np.zeros_like(map1, dtype=np.float32) 13 | map2[shift:, :] = map1[:h-shift, :] 14 | elif dir == 2: 15 | map2 = np.zeros_like(map1, dtype=np.float32) 16 | map2[:, shift:] = map1[:, :w-shift] 17 | else: 18 | raise AttributeError('dir must be 0, 1 or 2') 19 | dif = map2 - map1 20 | out = dif.copy() 21 | out[dif == 0] = 1 22 | out[dif != 0] = 0 23 | if background: 24 | out[map1 == 0] = 0 25 | out[map2 == 0] = 0 26 | if padding: 27 | if dir == 1: 28 | # out[:shift, :] = (map1[:shift, :] > 0).astype(np.float32) 29 | out[:shift, :] = out[2*shift:shift:-1, :] 30 | if dir == 2: 31 | # out[:, :shift] = (map1[:, :shift] > 0).astype(np.float32) 32 | out[:, :shift] = out[:, 2*shift:shift:-1] 33 | else: 34 | if dir == 1: 35 | out[:shift, :] = 0 36 | if dir == 2: 37 | out[:, :shift] = 0 38 | return out 39 | 40 | def gen_affs_mutex(map1, map2, shift=(0, 0, 0), padding=True, background=False): 41 | assert len(shift) == 3, 'the len(shift) must be 3' 42 | h, w = map1.shape 43 | map1 = map1.astype(np.float32) 44 | map2 = map2.astype(np.float32) 45 | 46 | if shift[1] <= 0 and shift[2] <= 0: 47 | map1[-shift[1]:, -shift[2]:] = map1[:h+shift[1], :w+shift[2]] 48 | elif shift[1] <= 0 and shift[2] > 0: 49 | map1[-shift[1]:, :w-shift[2]] = map1[:h+shift[1], shift[2]:] 50 | elif shift[1] > 0 and shift[2] <= 0: 51 | map1[:h-shift[1], -shift[2]:] = map1[shift[1]:, :w+shift[2]] 52 | elif shift[1] > 0 and shift[2] > 0: 53 | map1[:h-shift[1], :w-shift[2]] = map1[shift[1]:, shift[2]:] 54 | else: 55 | pass 56 | 57 | dif = map1 - map2 58 | out = dif.copy() 59 | out[dif == 0] = 1 60 | out[dif != 0] = 0 61 | if background: 62 | out[map1 == 0] = 0 63 | out[map2 == 0] = 0 64 | if padding: 65 | if shift[1] < 0: 66 | out[:-shift[1], :] = out[-2*shift[1]:-shift[1]:-1, :] 67 | elif shift[1] > 0: 68 | out[h-shift[1]:, :] = out[h-shift[1]-2:h-2*shift[1]-2:-1, :] 69 | else: 70 | pass 71 | if shift[2] < 0: 72 | out[:, :-shift[2]] = out[:, -2*shift[2]:-shift[2]:-1] 73 | elif shift[2] > 0: 74 | out[:, w-shift[2]:] = out[:, w-shift[2]-2:w-2*shift[2]-2:-1] 75 | else: 76 | pass 77 | else: 78 | if shift[1] < 0: 79 | out[:-shift[1], :] = 0 80 | elif shift[1] > 0: 81 | out[h-shift[1]:, :] = 0 82 | else: 83 | pass 84 | if shift[2] < 0: 85 | out[:, :-shift[2]] = 0 86 | elif shift[2] > 0: 87 | out[:, w-shift[2]:] = 0 88 | else: 89 | pass 90 | return out 91 | 92 | def gen_affs_3d(labels, shift=1, padding=True, background=False): 93 | assert
len(labels.shape) == 3, '3D input' 94 | out = [] 95 | for i in range(labels.shape[0]): 96 | if i == 0: 97 | if padding: 98 | # affs0 = (labels[0] > 0).astype(np.float32) 99 | affs0 = gen_affs(labels[i], labels[i+1], dir=0, shift=shift, padding=padding, background=background) 100 | else: 101 | affs0 = np.zeros_like(labels[0], dtype=np.float32) 102 | else: 103 | affs0 = gen_affs(labels[i-1], labels[i], dir=0, shift=shift, padding=padding, background=background) 104 | affs1 = gen_affs(labels[i], None, dir=1, shift=shift, padding=padding, background=background) 105 | affs2 = gen_affs(labels[i], None, dir=2, shift=shift, padding=padding, background=background) 106 | affs = np.stack([affs0, affs1, affs2], axis=0) 107 | out.append(affs) 108 | out = np.asarray(out, dtype=np.float32) 109 | out = np.transpose(out, (1, 0, 2, 3)) 110 | return out 111 | 112 | def gen_affs_mutex_3d(labels, shift=[[-1, 0, 0], [0, -1, 0], [0, 0, -1]], padding=True, background=False): 113 | affs = [] 114 | for shift_k in shift: 115 | affs_k = [] 116 | for i in range(labels.shape[0]): 117 | if shift_k[0] != 0: 118 | if i == 0: 119 | if padding: 120 | temp = gen_affs_mutex(labels[0], labels[1], shift=shift_k, padding=padding, background=background) 121 | else: 122 | temp = np.zeros_like(labels[0], dtype=np.float32) 123 | else: 124 | temp = gen_affs_mutex(labels[i-1], labels[i], shift=shift_k, padding=padding, background=background) 125 | else: 126 | temp = gen_affs_mutex(labels[i], labels[i], shift=shift_k, padding=padding, background=background) 127 | affs_k.append(temp) 128 | affs.append(affs_k) 129 | affs = np.asarray(affs) 130 | return affs 131 | -------------------------------------------------------------------------------- /data/data_affinity.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # from Janelia pyGreentea 4 | # https://github.com/naibaf7/PyGreentea 5 | def mknhood2d(radius=1): 6 | # Makes nhood structures for some most used dense graphs. 7 | 8 | ceilrad = np.ceil(radius) 9 | x = np.arange(-ceilrad,ceilrad+1,1) 10 | y = np.arange(-ceilrad,ceilrad+1,1) 11 | [i,j] = np.meshgrid(y,x) 12 | 13 | idxkeep = (i**2+j**2)<=radius**2 14 | i=i[idxkeep].ravel(); j=j[idxkeep].ravel(); 15 | zeroIdx = np.ceil(len(i)/2).astype(np.int32); 16 | 17 | nhood = np.vstack((i[:zeroIdx],j[:zeroIdx])).T.astype(np.int32) 18 | nhood = np.ascontiguousarray(np.flipud(nhood)) 19 | nhood = nhood[1:] 20 | return nhood 21 | 22 | def mknhood3d(radius=1): 23 | # Makes nhood structures for some most used dense graphs. 24 | # The neighborhood reference for the dense graph representation we use 25 | # nhood(1,:) is a 3 vector that describe the node that conn(:,:,:,1) connects to 26 | # so to use it: conn(23,12,42,3) is the edge between node [23 12 42] and [23 12 42]+nhood(3,:) 27 | # See? It's simple! nhood is just the offset vector that the edge corresponds to. 28 | 29 | ceilrad = np.ceil(radius) 30 | x = np.arange(-ceilrad,ceilrad+1,1) 31 | y = np.arange(-ceilrad,ceilrad+1,1) 32 | z = np.arange(-ceilrad,ceilrad+1,1) 33 | [i,j,k] = np.meshgrid(z,y,x) 34 | 35 | idxkeep = (i**2+j**2+k**2)<=radius**2 36 | i=i[idxkeep].ravel(); j=j[idxkeep].ravel(); k=k[idxkeep].ravel(); 37 | zeroIdx = np.array(len(i) // 2).astype(np.int32); 38 | 39 | nhood = np.vstack((k[:zeroIdx],i[:zeroIdx],j[:zeroIdx])).T.astype(np.int32) 40 | return np.ascontiguousarray(np.flipud(nhood)) 41 | 42 | def mknhood3d_aniso(radiusxy=1,radiusxy_zminus1=1.8): 43 | # Makes nhood structures for some most used dense graphs. 
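    # For the defaults (radiusxy=1, radiusxy_zminus1=1.8) this stacks the three
    # canonical offsets from mknhood3d(1), [[-1, 0, 0], [0, -1, 0], [0, 0, -1]],
    # with the four in-plane offsets of mknhood2d(1.8) and their negations placed
    # one section down (z = -1), giving an 11-edge anisotropic neighborhood.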
44 | nhoodxyz = mknhood3d(radiusxy) 45 | nhoodxy_zminus1 = mknhood2d(radiusxy_zminus1) 46 | nhood = np.zeros((nhoodxyz.shape[0]+2*nhoodxy_zminus1.shape[0],3),dtype=np.int32) 47 | nhood[:3,:3] = nhoodxyz 48 | nhood[3:,0] = -1 49 | nhood[3:,1:] = np.vstack((nhoodxy_zminus1,-nhoodxy_zminus1)) 50 | 51 | return np.ascontiguousarray(nhood) 52 | 53 | def seg_to_aff(seg, nhood=mknhood3d(1), pad='replicate'): 54 | # constructs an affinity graph from a segmentation 55 | # assume affinity graph is represented as: 56 | # shape = (e, z, y, x) 57 | # nhood.shape = (edges, 3) 58 | shape = seg.shape 59 | nEdge = nhood.shape[0] 60 | aff = np.zeros((nEdge,)+shape,dtype=np.float32) 61 | 62 | if len(shape) == 3: # 3D affinity 63 | for e in range(nEdge): 64 | aff[e, \ 65 | max(0,-nhood[e,0]):min(shape[0],shape[0]-nhood[e,0]), \ 66 | max(0,-nhood[e,1]):min(shape[1],shape[1]-nhood[e,1]), \ 67 | max(0,-nhood[e,2]):min(shape[2],shape[2]-nhood[e,2])] = \ 68 | (seg[max(0,-nhood[e,0]):min(shape[0],shape[0]-nhood[e,0]), \ 69 | max(0,-nhood[e,1]):min(shape[1],shape[1]-nhood[e,1]), \ 70 | max(0,-nhood[e,2]):min(shape[2],shape[2]-nhood[e,2])] == \ 71 | seg[max(0,nhood[e,0]):min(shape[0],shape[0]+nhood[e,0]), \ 72 | max(0,nhood[e,1]):min(shape[1],shape[1]+nhood[e,1]), \ 73 | max(0,nhood[e,2]):min(shape[2],shape[2]+nhood[e,2])] ) \ 74 | * ( seg[max(0,-nhood[e,0]):min(shape[0],shape[0]-nhood[e,0]), \ 75 | max(0,-nhood[e,1]):min(shape[1],shape[1]-nhood[e,1]), \ 76 | max(0,-nhood[e,2]):min(shape[2],shape[2]-nhood[e,2])] > 0 ) \ 77 | * ( seg[max(0,nhood[e,0]):min(shape[0],shape[0]+nhood[e,0]), \ 78 | max(0,nhood[e,1]):min(shape[1],shape[1]+nhood[e,1]), \ 79 | max(0,nhood[e,2]):min(shape[2],shape[2]+nhood[e,2])] > 0 ) 80 | elif len(shape) == 2: # 2D affinity 81 | for e in range(nEdge): 82 | aff[e, \ 83 | max(0,-nhood[e,0]):min(shape[0],shape[0]-nhood[e,0]), \ 84 | max(0,-nhood[e,1]):min(shape[1],shape[1]-nhood[e,1])] = \ 85 | (seg[max(0,-nhood[e,0]):min(shape[0],shape[0]-nhood[e,0]), \ 86 | max(0,-nhood[e,1]):min(shape[1],shape[1]-nhood[e,1])] == \ 87 | seg[max(0,nhood[e,0]):min(shape[0],shape[0]+nhood[e,0]), \ 88 | max(0,nhood[e,1]):min(shape[1],shape[1]+nhood[e,1])] ) \ 89 | * ( seg[max(0,-nhood[e,0]):min(shape[0],shape[0]-nhood[e,0]), \ 90 | max(0,-nhood[e,1]):min(shape[1],shape[1]-nhood[e,1])] > 0 ) \ 91 | * ( seg[max(0,nhood[e,0]):min(shape[0],shape[0]+nhood[e,0]), \ 92 | max(0,nhood[e,1]):min(shape[1],shape[1]+nhood[e,1])] > 0 ) 93 | 94 | if nEdge==3 and pad == 'replicate': # pad the boundary affinity 95 | aff[0,0] = (seg[0]>0).astype(aff.dtype) 96 | aff[1,:,0] = (seg[:,0]>0).astype(aff.dtype) 97 | aff[2,:,:,0] = (seg[:,:,0]>0).astype(aff.dtype) 98 | elif nEdge==2 and pad == 'replicate': # pad the boundary affinity 99 | aff[0,0] = (seg[0]>0).astype(aff.dtype) 100 | aff[1,:,0] = (seg[:,0]>0).astype(aff.dtype) 101 | 102 | return aff 103 | --------------------------------------------------------------------------------
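# End-to-end sketch (hedged) of turning a toy segmentation into affinity
# targets with the helpers above:
#
#     import numpy as np
#     seg = np.zeros((2, 4, 4), dtype=np.uint32)
#     seg[:, :, :2] = 1
#     seg[:, :, 2:] = 2
#     aff = seg_to_aff(seg)        # (3, 2, 4, 4) with the default mknhood3d(1)
#     assert aff[2, :, :, 2].max() == 0   # x-affinity is cut at the 1|2 boundary
#     assert aff[2, :, :, 1].min() == 1   # and 1 inside each segment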