├── LICENSE ├── README.md ├── causal-conv1d ├── AUTHORS ├── LICENSE ├── README.md ├── causal_conv1d │ ├── __init__.py │ └── causal_conv1d_interface.py ├── csrc │ ├── causal_conv1d.cpp │ ├── causal_conv1d.h │ ├── causal_conv1d_bwd.cu │ ├── causal_conv1d_common.h │ ├── causal_conv1d_fwd.cu │ ├── causal_conv1d_update.cu │ └── static_switch.h ├── setup.py └── tests │ └── test_causal_conv1d.py ├── code ├── augmentations │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ └── ctaugment.cpython-38.pyc │ └── ctaugment.py ├── config.py ├── configs │ ├── swin_tiny_patch4_window7_224_lite.yaml │ └── vmamba_tiny.yaml ├── dataloaders │ ├── __pycache__ │ │ ├── dataset.cpython-310.pyc │ │ ├── dataset.cpython-311.pyc │ │ ├── dataset.cpython-38.pyc │ │ ├── dataset_s2l.cpython-38.pyc │ │ ├── dataset_semi.cpython-38.pyc │ │ ├── utils.cpython-310.pyc │ │ ├── utils.cpython-311.pyc │ │ └── utils.cpython-38.pyc │ ├── acdc_data_processing.py │ ├── acdc_pseudo_label_random_walker.py │ ├── dataset.py │ ├── dataset_s2l.py │ ├── dataset_semi.py │ └── utils.py ├── networks │ ├── VoxResNet.py │ ├── __pycache__ │ │ ├── attention.cpython-310.pyc │ │ ├── attention.cpython-38.pyc │ │ ├── config.cpython-310.pyc │ │ ├── config.cpython-38.pyc │ │ ├── discriminator.cpython-38.pyc │ │ ├── efficient_encoder.cpython-310.pyc │ │ ├── efficient_encoder.cpython-38.pyc │ │ ├── efficientunet.cpython-310.pyc │ │ ├── efficientunet.cpython-38.pyc │ │ ├── enet.cpython-310.pyc │ │ ├── enet.cpython-38.pyc │ │ ├── mamba_sys.cpython-310.pyc │ │ ├── net_factory.cpython-310.pyc │ │ ├── net_factory.cpython-38.pyc │ │ ├── neural_network.cpython-310.pyc │ │ ├── neural_network.cpython-38.pyc │ │ ├── nnunet.cpython-310.pyc │ │ ├── nnunet.cpython-38.pyc │ │ ├── pnet.cpython-310.pyc │ │ ├── pnet.cpython-38.pyc │ │ ├── swin_transformer_unet_skip_expand_decoder_sys.cpython-310.pyc │ │ ├── swin_transformer_unet_skip_expand_decoder_sys.cpython-38.pyc │ │ ├── unet.cpython-310.pyc │ │ ├── unet.cpython-38.pyc │ │ ├── vision_mamba.cpython-310.pyc │ │ ├── vision_transformer.cpython-310.pyc │ │ └── vision_transformer.cpython-38.pyc │ ├── attention.py │ ├── attention_unet.py │ ├── config.py │ ├── discriminator.py │ ├── efficient_encoder.py │ ├── efficientunet.py │ ├── encoder_tool.py │ ├── enet.py │ ├── grid_attention_layer.py │ ├── mamba_sys.py │ ├── net_factory.py │ ├── net_factory_3d.py │ ├── networks_other.py │ ├── neural_network.py │ ├── nnunet.py │ ├── pnet.py │ ├── segmamba.py │ ├── swin_transformer_unet_skip_expand_decoder_sys.py │ ├── unet.py │ ├── unet_3D.py │ ├── utils.py │ ├── vision_mamba.py │ ├── vision_transformer.py │ └── vnet.py ├── pretrained_ckpt │ └── readme.txt ├── scribbles_generator.py ├── test_2D.py ├── test_2D_ViT.py ├── train_weak_mamba_unet.py ├── train_weakly_supervised_pCE_2D.py ├── train_weakly_supervised_pCE_2D_ViT.py ├── train_weakly_supervised_ustm_2D_ViT.py ├── utils │ ├── __pycache__ │ │ ├── gate_crf_loss.cpython-310.pyc │ │ ├── gate_crf_loss.cpython-38.pyc │ │ ├── losses.cpython-310.pyc │ │ ├── losses.cpython-38.pyc │ │ ├── metrics.cpython-310.pyc │ │ ├── metrics.cpython-38.pyc │ │ ├── ramps.cpython-310.pyc │ │ └── ramps.cpython-38.pyc │ ├── gate_crf_loss.py │ ├── losses.py │ ├── metrics.py │ ├── ramps.py │ └── util.py └── val_2D.py ├── data └── readme.txt ├── img ├── results.png ├── wslframework.png └── wslintro.png └── mamba ├── .gitmodules ├── AUTHORS ├── LICENSE ├── README.md ├── assets └── selection.png ├── benchmarks └── benchmark_generation_mamba_simple.py ├── csrc └── selective_scan │ ├── 
reverse_scan.cuh │ ├── selective_scan.cpp │ ├── selective_scan.h │ ├── selective_scan_bwd_bf16_complex.cu │ ├── selective_scan_bwd_bf16_real.cu │ ├── selective_scan_bwd_fp16_complex.cu │ ├── selective_scan_bwd_fp16_real.cu │ ├── selective_scan_bwd_fp32_complex.cu │ ├── selective_scan_bwd_fp32_real.cu │ ├── selective_scan_bwd_kernel.cuh │ ├── selective_scan_common.h │ ├── selective_scan_fwd_bf16.cu │ ├── selective_scan_fwd_fp16.cu │ ├── selective_scan_fwd_fp32.cu │ ├── selective_scan_fwd_kernel.cuh │ ├── static_switch.h │ └── uninitialized_copy.cuh ├── evals └── lm_harness_eval.py ├── mamba_ssm ├── __init__.py ├── models │ ├── __init__.py │ └── mixer_seq_simple.py ├── modules │ ├── __init__.py │ └── mamba_simple.py ├── ops │ ├── __init__.py │ ├── selective_scan_interface.py │ └── triton │ │ ├── __init__.py │ │ ├── layernorm.py │ │ └── selective_state_update.py └── utils │ ├── __init__.py │ ├── generation.py │ └── hf.py ├── setup.py ├── test_mamba_module.py └── tests └── ops ├── test_selective_scan.py └── triton └── test_selective_state_update.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 ziyangwang007 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | # Weak-Mamba-UNet: Visual Mamba Makes CNN and ViT Work Better for Scribble-based Medical Image Segmentation
3 | 4 | [![arXiv](https://img.shields.io/badge/arXiv-2402.10887-b31b1b.svg)](https://arxiv.org/abs/2402.10887) 5 | 6 |
7 | 
8 | > This repo provides an implementation of the training and inference pipeline for [Weak-Mamba-UNet](https://arxiv.org/abs/2402.10887).
9 | 
10 | 
11 | ## Contents
12 | - [Graphical Abstract](#Graphical-Abstract)
13 | - [Results](#Results)
14 | - [Requirements](#Requirements)
15 | - [Usage](#Usage)
16 | - [Reference](#Reference)
17 | - [Contact](#Contact)
18 | 
19 | 
20 | 
21 | 
22 | ## Graphical Abstract
23 | 
24 | The introduction of scribble annotation:
25 | 
26 | 
27 | 
28 | The proposed framework:
29 | 
30 | 
31 | 
32 | 
33 | ## Results
34 | 
35 | 
36 | 
37 | 
38 | 
39 | 
40 | ## Requirements
41 | * PyTorch, MONAI
42 | * Some basic Python packages: TorchIO, NumPy, scikit-image, SimpleITK, SciPy, MedPy, nibabel, tqdm, etc.
43 | 
44 | ```shell
45 | cd causal-conv1d
46 | 
47 | python setup.py install
48 | ```
49 | 
50 | ```shell
51 | cd mamba
52 | 
53 | python setup.py install
54 | ```
55 | 
56 | A quick sanity check for the built extensions is shown below the Reference section.
57 | 
58 | ## Usage
59 | 
60 | 1. Clone the repo:
61 | ```shell
62 | git clone https://github.com/ziyangwang007/Weak-Mamba-UNet.git
63 | cd Weak-Mamba-UNet
64 | ```
65 | 
66 | 2. Download the pretrained models
67 | 
68 | Download the SwinUNet weights via [Google Drive](https://drive.google.com/file/d/14RzbbBDjbKbgr0ordKlWbb69EFkHuplr/view?usp=sharing) and the Mamba-UNet weights via [Google Drive](https://drive.google.com/file/d/1uUPsr7XeqayCxlspqBHbg5zIWx0JYtSX/view?usp=sharing), and save them in `../code/pretrained_ckpt`.
69 | 
70 | 3. Download the dataset
71 | 
72 | Download the ACDC dataset for weakly-supervised learning via [Google Drive](https://drive.google.com/file/d/1XR_Id0wdvXY9QeKtdOdgJHKVJ-nVr2j1/view?usp=sharing) or [Baidu Netdisk](https://pan.baidu.com/s/1dHkp9daqE3kLEbAP6zl7Jw) (passcode: 'rwv2'), and save it in the `../data/ACDC` folder.
73 | 
74 | 
75 | 4. Move into the training code folder
76 | 
77 | ```shell
78 | cd code
79 | ```
80 | 
81 | 5. Train 2D UNet with pCE
82 | 
83 | ```shell
84 | python train_weakly_supervised_pCE_2D.py
85 | ```
86 | 
87 | 6. Train 2D SwinUNet with pCE
88 | ```shell
89 | python train_weakly_supervised_pCE_2D_ViT.py
90 | ```
91 | 
92 | 7. Train 2D SwinUNet with MT and pCE
93 | ```shell
94 | python train_weakly_supervised_ustm_2D_ViT.py
95 | ```
96 | 
97 | 8. Train 2D Weak-Mamba-UNet with pCE
98 | ```shell
99 | python train_weak_mamba_unet.py
100 | ```
101 | 
102 | 9. Test
103 | 
104 | Test a CNN-based model:
105 | ```shell
106 | python test_2D.py --root_path ../data/XXX --exp ACDC/XXX
107 | ```
108 | Test a ViT/Mamba-based model:
109 | ```shell
110 | python test_2D_ViT.py --root_path ../data/XXX --exp ACDC/XXX
111 | ```
112 | 
113 | 
114 | ## Reference
115 | Wang, Ziyang, et al. "Mamba-UNet: UNet-like Pure Visual Mamba for Medical Image Segmentation." arXiv preprint arXiv:2402.05079 (2024).
116 | 
117 | Wang, Ziyang, and Chao Ma. "Weak-Mamba-UNet: Visual Mamba Makes CNN and ViT Work Better for Scribble-based Medical Image Segmentation." arXiv preprint arXiv:2402.10887 (2024).
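A minimal sanity check (a sketch, not part of the original repo) that the `causal-conv1d` extension from the Requirements section built correctly; it only uses `causal_conv1d_fn` with the shapes documented in `causal_conv1d/causal_conv1d_interface.py` later in this dump:

```python
import torch
from causal_conv1d import causal_conv1d_fn

# Shapes follow the causal_conv1d_fn docstring:
# x: (batch, dim, seqlen), weight: (dim, width), bias: (dim,)
x = torch.randn(2, 64, 128, device="cuda")
weight = torch.randn(64, 4, device="cuda")
bias = torch.randn(64, device="cuda")

out = causal_conv1d_fn(x, weight, bias, activation="silu")
assert out.shape == x.shape  # the causal conv preserves (batch, dim, seqlen)
```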
118 | 119 | 120 | ```bibtex 121 | @article{wang2024mamba, 122 | title={Mamba-unet: Unet-like pure visual mamba for medical image segmentation}, 123 | author={Wang, Ziyang and Zheng, Jian-Qing and Zhang, Yichi and Cui, Ge and Li, Lei}, 124 | journal={arXiv preprint arXiv:2402.05079}, 125 | year={2024} 126 | } 127 | 128 | @article{wang2024weakmamba, 129 | title={Weak-Mamba-UNet: Visual Mamba Makes CNN and ViT Work Better for Scribble-based Medical Image Segmentation}, 130 | author={Wang, Ziyang and Ma, Chao}, 131 | journal={arXiv preprint arXiv:2402.10887}, 132 | year={2024} 133 | } 134 | ``` 135 | ## Contact 136 | 137 | ziyang [dot] wang17 [at] gmail [dot] com 138 | -------------------------------------------------------------------------------- /causal-conv1d/AUTHORS: -------------------------------------------------------------------------------- 1 | Tri Dao, tri@tridao.me 2 | -------------------------------------------------------------------------------- /causal-conv1d/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file. 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | -------------------------------------------------------------------------------- /causal-conv1d/README.md: -------------------------------------------------------------------------------- 1 | # Causal depthwise conv1d in CUDA with a PyTorch interface 2 | -------------------------------------------------------------------------------- /causal-conv1d/causal_conv1d/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.0.0" 2 | 3 | from causal_conv1d.causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update 4 | -------------------------------------------------------------------------------- /causal-conv1d/causal_conv1d/causal_conv1d_interface.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Tri Dao. 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | import causal_conv1d_cuda 8 | 9 | 10 | class CausalConv1dFn(torch.autograd.Function): 11 | @staticmethod 12 | def forward(ctx, x, weight, bias=None, activation=None): 13 | if activation not in [None, "silu", "swish"]: 14 | raise NotImplementedError("activation must be None, silu, or swish") 15 | if x.stride(2) != 1 and x.stride(1) != 1: 16 | x = x.contiguous() 17 | bias = bias.contiguous() if bias is not None else None 18 | ctx.save_for_backward(x, weight, bias) 19 | ctx.activation = activation in ["silu", "swish"] 20 | out = causal_conv1d_cuda.causal_conv1d_fwd(x, weight, bias, ctx.activation) 21 | return out 22 | 23 | @staticmethod 24 | def backward(ctx, dout): 25 | x, weight, bias = ctx.saved_tensors 26 | if dout.stride(2) != 1 and dout.stride(1) != 1: 27 | dout = dout.contiguous() 28 | # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the 29 | # backward of conv1d with the backward of chunk). 30 | # Here we just pass in None and dx will be allocated in the C++ code. 
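        # NB: backward must return one gradient per forward input
        # (x, weight, bias, activation); activation is a bool flag rather
        # than a tensor, so its slot in the return below is None.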
31 | dx, dweight, dbias = causal_conv1d_cuda.causal_conv1d_bwd( 32 | x, weight, bias, dout, None, ctx.activation 33 | ) 34 | return dx, dweight, dbias if bias is not None else None, None 35 | 36 | 37 | def causal_conv1d_fn(x, weight, bias=None, activation=None): 38 | """ 39 | x: (batch, dim, seqlen) 40 | weight: (dim, width) 41 | bias: (dim,) 42 | activation: either None or "silu" or "swish" 43 | 44 | out: (batch, dim, seqlen) 45 | """ 46 | return CausalConv1dFn.apply(x, weight, bias, activation) 47 | 48 | 49 | def causal_conv1d_ref(x, weight, bias=None, activation=None): 50 | """ 51 | x: (batch, dim, seqlen) 52 | weight: (dim, width) 53 | bias: (dim,) 54 | 55 | out: (batch, dim, seqlen) 56 | """ 57 | if activation not in [None, "silu", "swish"]: 58 | raise NotImplementedError("activation must be None, silu, or swish") 59 | dtype_in = x.dtype 60 | x = x.to(weight.dtype) 61 | seqlen = x.shape[-1] 62 | dim, width = weight.shape 63 | out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) 64 | out = out[..., :seqlen] 65 | return (out if activation is None else F.silu(out)).to(dtype=dtype_in) 66 | 67 | 68 | def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None): 69 | """ 70 | x: (batch, dim) 71 | conv_state: (batch, dim, width) 72 | weight: (dim, width) 73 | bias: (dim,) 74 | 75 | out: (batch, dim) 76 | """ 77 | if activation not in [None, "silu", "swish"]: 78 | raise NotImplementedError("activation must be None, silu, or swish") 79 | activation = activation in ["silu", "swish"] 80 | return causal_conv1d_cuda.causal_conv1d_update(x, conv_state, weight, bias, activation) 81 | 82 | 83 | def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None): 84 | """ 85 | x: (batch, dim) 86 | conv_state: (batch, dim, width) 87 | weight: (dim, width) 88 | bias: (dim,) 89 | 90 | out: (batch, dim) 91 | """ 92 | if activation not in [None, "silu", "swish"]: 93 | raise NotImplementedError("activation must be None, silu, or swish") 94 | dtype_in = x.dtype 95 | batch, dim = x.shape 96 | width = weight.shape[1] 97 | assert conv_state.shape == (batch, dim, width) 98 | assert weight.shape == (dim, width) 99 | conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) 100 | conv_state[:, :, -1] = x 101 | out = torch.sum(conv_state * weight, dim=-1) # (B D) 102 | if bias is not None: 103 | out += bias 104 | return (out if activation is None else F.silu(out)).to(dtype=dtype_in) 105 | -------------------------------------------------------------------------------- /causal-conv1d/csrc/causal_conv1d.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | #pragma once 6 | 7 | //////////////////////////////////////////////////////////////////////////////////////////////////// 8 | 9 | struct ConvParamsBase { 10 | using index_t = uint32_t; 11 | 12 | int batch, dim, seqlen, width; 13 | bool silu_activation; 14 | 15 | index_t x_batch_stride; 16 | index_t x_c_stride; 17 | index_t x_l_stride; 18 | index_t weight_c_stride; 19 | index_t weight_width_stride; 20 | index_t out_batch_stride; 21 | index_t out_c_stride; 22 | index_t out_l_stride; 23 | 24 | index_t conv_state_batch_stride; 25 | index_t conv_state_c_stride; 26 | index_t conv_state_l_stride; 27 | 28 | // Common data pointers. 
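    // NB: x_ptr/weight_ptr/bias_ptr/out_ptr are indexed with the strides above;
    // conv_state_ptr is only touched by the single-step update kernel
    // (causal_conv1d_update.cu).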
29 |     void *__restrict__ x_ptr;
30 |     void *__restrict__ weight_ptr;
31 |     void *__restrict__ bias_ptr;
32 |     void *__restrict__ out_ptr;
33 | 
34 |     void *__restrict__ conv_state_ptr;
35 | };
36 | 
37 | struct ConvParamsBwd: public ConvParamsBase {
38 |     index_t dx_batch_stride;
39 |     index_t dx_c_stride;
40 |     index_t dx_l_stride;
41 |     index_t dweight_c_stride;
42 |     index_t dweight_width_stride;
43 |     index_t dout_batch_stride;
44 |     index_t dout_c_stride;
45 |     index_t dout_l_stride;
46 | 
47 |     // Common data pointers.
48 |     void *__restrict__ dx_ptr;
49 |     void *__restrict__ dweight_ptr;
50 |     void *__restrict__ dbias_ptr;
51 |     void *__restrict__ dout_ptr;
52 | };
53 | 
54 | 
--------------------------------------------------------------------------------
/causal-conv1d/csrc/causal_conv1d_common.h:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 |  * Copyright (c) 2023, Tri Dao.
3 |  ******************************************************************************/
4 | 
5 | #pragma once
6 | 
7 | #include <cuda_bf16.h>
8 | #include <cuda_fp16.h>
9 | 
10 | ////////////////////////////////////////////////////////////////////////////////////////////////////
11 | 
12 | template<int BYTES> struct BytesToType {};
13 | 
14 | template<> struct BytesToType<16> {
15 |     using Type = uint4;
16 |     static_assert(sizeof(Type) == 16);
17 | };
18 | 
19 | template<> struct BytesToType<8> {
20 |     using Type = uint64_t;
21 |     static_assert(sizeof(Type) == 8);
22 | };
23 | 
24 | template<> struct BytesToType<4> {
25 |     using Type = uint32_t;
26 |     static_assert(sizeof(Type) == 4);
27 | };
28 | 
29 | template<> struct BytesToType<2> {
30 |     using Type = uint16_t;
31 |     static_assert(sizeof(Type) == 2);
32 | };
33 | 
34 | template<> struct BytesToType<1> {
35 |     using Type = uint8_t;
36 |     static_assert(sizeof(Type) == 1);
37 | };
38 | 
39 | ////////////////////////////////////////////////////////////////////////////////////////////////////
40 | 
41 | template<typename T>
42 | struct SumOp {
43 |     __device__ inline T operator()(T const & x, T const & y) { return x + y; }
44 | };
45 | 
46 | template<int THREADS>
47 | struct Allreduce {
48 |     static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
49 |     template<typename T, typename Operator>
50 |     static __device__ inline T run(T x, Operator &op) {
51 |         constexpr int OFFSET = THREADS / 2;
52 |         x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
53 |         return Allreduce<OFFSET>::run(x, op);
54 |     }
55 | };
56 | 
57 | template<>
58 | struct Allreduce<2> {
59 |     template<typename T, typename Operator>
60 |     static __device__ inline T run(T x, Operator &op) {
61 |         x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
62 |         return x;
63 |     }
64 | };
--------------------------------------------------------------------------------
/causal-conv1d/csrc/causal_conv1d_update.cu:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 |  * Copyright (c) 2023, Tri Dao.
3 |  ******************************************************************************/
4 | 
5 | #include <c10/util/BFloat16.h>
6 | #include <c10/util/Half.h>
7 | #include <c10/cuda/CUDAException.h>  // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
8 | 
9 | #include <cub/block/block_load.cuh>
10 | #include <cub/block/block_store.cuh>
11 | 
12 | #include "causal_conv1d.h"
13 | #include "causal_conv1d_common.h"
14 | #include "static_switch.h"
15 | 
16 | template<int kNThreads_, int kWidth_, typename input_t_, typename weight_t_>
17 | struct Causal_conv1d_update_kernel_traits {
18 |     using input_t = input_t_;
19 |     using weight_t = weight_t_;
20 |     static constexpr int kNThreads = kNThreads_;
21 |     static constexpr int kWidth = kWidth_;
22 |     static constexpr int kNBytes = sizeof(input_t);
23 |     static_assert(kNBytes == 2 || kNBytes == 4);
24 | };
25 | 
26 | template<typename Ktraits>
27 | __global__ __launch_bounds__(Ktraits::kNThreads)
28 | void causal_conv1d_update_kernel(ConvParamsBase params) {
29 |     constexpr int kWidth = Ktraits::kWidth;
30 |     constexpr int kNThreads = Ktraits::kNThreads;
31 |     using input_t = typename Ktraits::input_t;
32 |     using weight_t = typename Ktraits::weight_t;
33 | 
34 |     const int tidx = threadIdx.x;
35 |     const int batch_id = blockIdx.x;
36 |     const int channel_id = blockIdx.y * kNThreads + tidx;
37 |     input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
38 |         + channel_id * params.x_c_stride;
39 |     input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr) + batch_id * params.conv_state_batch_stride
40 |         + channel_id * params.conv_state_c_stride;
41 |     weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
42 |     input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
43 |         + channel_id * params.out_c_stride;
44 |     float bias_val = params.bias_ptr == nullptr || channel_id >= params.dim ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
45 | 
46 |     float weight_vals[kWidth] = {0};
47 |     if (channel_id < params.dim) {
48 |         #pragma unroll
49 |         for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
50 |     }
51 | 
52 |     float x_vals[kWidth] = {0};
53 |     if (channel_id < params.dim) {
54 |         #pragma unroll
55 |         for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = float(conv_state[(i + 1) * params.conv_state_l_stride]); }
56 |         x_vals[kWidth - 1] = float(x[0]);
57 |         #pragma unroll
58 |         for (int i = 0; i < kWidth; ++i) { conv_state[i * params.conv_state_l_stride] = input_t(x_vals[i]); }
59 |     }
60 | 
61 |     float out_val = bias_val;
62 |     #pragma unroll
63 |     for (int i = 0; i < kWidth; ++i) { out_val += weight_vals[i] * x_vals[i]; }
64 |     if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); }
65 |     if (channel_id < params.dim) { out[0] = input_t(out_val); }
66 | }
67 | 
68 | template<int kNThreads, int kWidth, typename input_t, typename weight_t>
69 | void causal_conv1d_update_launch(ConvParamsBase &params, cudaStream_t stream) {
70 |     using Ktraits = Causal_conv1d_update_kernel_traits<kNThreads, kWidth, input_t, weight_t>;
71 |     dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads);
72 |     auto kernel = &causal_conv1d_update_kernel<Ktraits>;
73 |     kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
74 |     C10_CUDA_KERNEL_LAUNCH_CHECK();
75 | }
76 | 
77 | template<typename input_t, typename weight_t>
78 | void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream) {
79 |     if (params.width == 2) {
80 |         causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream);
81 |     } else if (params.width == 3) {
82 |         causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream);
83 |     } else if (params.width == 4) {
84 |         causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream);
85 |     }
86 | }
87 | 
88 | template void causal_conv1d_update_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
89 | template void causal_conv1d_update_cuda<at::Half, float>(ConvParamsBase &params, cudaStream_t stream);
90 | template void causal_conv1d_update_cuda<at::BFloat16, float>(ConvParamsBase &params, cudaStream_t stream);
91 | template void causal_conv1d_update_cuda<float, at::Half>(ConvParamsBase &params, cudaStream_t stream);
92 | template void causal_conv1d_update_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
93 | template void causal_conv1d_update_cuda<at::BFloat16, at::Half>(ConvParamsBase &params, cudaStream_t stream);
94 | template void causal_conv1d_update_cuda<float, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
95 | template void causal_conv1d_update_cuda<at::Half, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
96 | template void causal_conv1d_update_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
--------------------------------------------------------------------------------
/causal-conv1d/csrc/static_switch.h:
--------------------------------------------------------------------------------
1 | // Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
2 | // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
3 | 
4 | #pragma once
5 | 
6 | /// @param COND - a boolean expression to switch by
7 | /// @param CONST_NAME - a name given for the constexpr bool variable.
8 | /// @param ... - code to execute for true and false
9 | ///
10 | /// Usage:
11 | /// ```
12 | /// BOOL_SWITCH(flag, BoolConst, [&] {
13 | ///     some_function(...);
14 | /// });
15 | /// ```
16 | #define BOOL_SWITCH(COND, CONST_NAME, ...)                \
17 |     [&] {                                                 \
18 |         if (COND) {                                       \
19 |             static constexpr bool CONST_NAME = true;      \
20 |             return __VA_ARGS__();                         \
21 |         } else {                                          \
22 |             static constexpr bool CONST_NAME = false;     \
23 |             return __VA_ARGS__();                         \
24 |         }                                                 \
25 |     }()
26 | 
--------------------------------------------------------------------------------
/causal-conv1d/tests/test_causal_conv1d.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2023, Tri Dao.
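# NB: the tests below compare the fused CUDA kernels against the pure-PyTorch
# references (causal_conv1d_ref, causal_conv1d_update_ref) from
# causal_conv1d_interface.py, sweeping dtypes, kernel widths, memory layouts,
# and sequence lengths.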
2 | 3 | import math 4 | 5 | import torch 6 | import pytest 7 | 8 | from einops import rearrange 9 | 10 | from causal_conv1d.causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_ref 11 | from causal_conv1d.causal_conv1d_interface import causal_conv1d_update, causal_conv1d_update_ref 12 | 13 | 14 | @pytest.mark.parametrize("channel_last", [False, True]) 15 | # @pytest.mark.parametrize('channel_last', [True]) 16 | @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) 17 | # @pytest.mark.parametrize('itype', [torch.float16]) 18 | @pytest.mark.parametrize("silu_activation", [False, True]) 19 | # @pytest.mark.parametrize('silu_activation', [True]) 20 | @pytest.mark.parametrize("has_bias", [False, True]) 21 | # @pytest.mark.parametrize('has_bias', [True]) 22 | @pytest.mark.parametrize("width", [2, 3, 4]) 23 | # @pytest.mark.parametrize('width', [2]) 24 | @pytest.mark.parametrize( 25 | "seqlen", [8, 16, 32, 64, 128, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096] 26 | ) 27 | # @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096]) 28 | # @pytest.mark.parametrize('seqlen', [128]) 29 | def test_causal_conv1d(seqlen, width, has_bias, silu_activation, itype, channel_last): 30 | device = "cuda" 31 | rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) 32 | if itype == torch.bfloat16: 33 | rtol, atol = 1e-2, 5e-2 34 | rtolw, atolw = (1e-3, 1e-3) 35 | # set seed 36 | torch.random.manual_seed(0) 37 | batch_size = 2 38 | # batch_size = 1 39 | dim = 4096 + 32 # Try dim not divisible by 64 40 | # dim = 64 41 | if not channel_last: 42 | x = torch.randn(batch_size, 4096 + dim + 64, seqlen, device=device, dtype=itype)[:, 4096:4096 + dim, :].requires_grad_() 43 | else: 44 | x = rearrange( 45 | torch.randn(batch_size, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s" 46 | ).requires_grad_() 47 | weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True) 48 | if has_bias: 49 | bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) 50 | else: 51 | bias = None 52 | x_ref = x.detach().clone().requires_grad_() 53 | weight_ref = weight.detach().clone().requires_grad_() 54 | bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None 55 | activation = None if not silu_activation else "silu" 56 | out = causal_conv1d_fn(x, weight, bias, activation=activation) 57 | out_ref = causal_conv1d_ref(x_ref, weight_ref, bias_ref, activation=activation) 58 | 59 | print(f"Output max diff: {(out - out_ref).abs().max().item()}") 60 | print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") 61 | assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) 62 | 63 | g = torch.randn_like(out) 64 | out_ref.backward(g) 65 | out.backward(g) 66 | 67 | print(f"dx max diff: {(x.grad - x_ref.grad).abs().max().item()}") 68 | print(f"dweight max diff: {(weight.grad - weight_ref.grad).abs().max().item()}") 69 | if has_bias: 70 | print(f"dbias max diff: {(bias.grad - bias_ref.grad).abs().max().item()}") 71 | 72 | assert torch.allclose(x.grad, x_ref.grad.to(dtype=itype), rtol=rtol, atol=atol) 73 | assert torch.allclose(weight.grad, weight_ref.grad, rtol=rtolw, atol=atolw) 74 | if has_bias: 75 | assert torch.allclose(bias.grad, bias_ref.grad, rtol=rtolw, atol=atolw) 76 | 77 | 78 | @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) 79 | # @pytest.mark.parametrize('itype', [torch.float16]) 80 | 
@pytest.mark.parametrize("silu_activation", [False, True]) 81 | # @pytest.mark.parametrize('silu_activation', [False]) 82 | @pytest.mark.parametrize("has_bias", [False, True]) 83 | # @pytest.mark.parametrize('has_bias', [True]) 84 | @pytest.mark.parametrize("width", [2, 3, 4]) 85 | # @pytest.mark.parametrize('width', [2]) 86 | @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) 87 | # @pytest.mark.parametrize("dim", [2048]) 88 | def test_causal_conv1d_update(dim, width, has_bias, silu_activation, itype): 89 | device = "cuda" 90 | rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) 91 | if itype == torch.bfloat16: 92 | rtol, atol = 1e-2, 5e-2 93 | rtolw, atolw = (1e-3, 1e-3) 94 | # set seed 95 | torch.random.manual_seed(0) 96 | batch_size = 2 97 | # batch_size = 1 98 | # dim = 64 99 | x = torch.randn(batch_size, dim, device=device, dtype=itype) 100 | conv_state = torch.randn(batch_size, dim, width, device=device, dtype=itype) 101 | weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True) 102 | if has_bias: 103 | bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) 104 | else: 105 | bias = None 106 | conv_state_ref = conv_state.detach().clone() 107 | activation = None if not silu_activation else "silu" 108 | out = causal_conv1d_update(x, conv_state, weight, bias, activation=activation) 109 | out_ref = causal_conv1d_update_ref(x, conv_state_ref, weight, bias, activation=activation) 110 | 111 | print(f"Output max diff: {(out - out_ref).abs().max().item()}") 112 | print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") 113 | assert torch.equal(conv_state, conv_state_ref) 114 | assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) 115 | 116 | 117 | # @pytest.mark.parametrize("channel_last", [False, True]) 118 | @pytest.mark.parametrize('channel_last', [True]) 119 | # @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) 120 | @pytest.mark.parametrize('itype', [torch.bfloat16]) 121 | # @pytest.mark.parametrize("silu_activation", [False, True]) 122 | @pytest.mark.parametrize('silu_activation', [True]) 123 | # @pytest.mark.parametrize("has_bias", [False, True]) 124 | @pytest.mark.parametrize('has_bias', [True]) 125 | # @pytest.mark.parametrize("width", [2, 3, 4]) 126 | @pytest.mark.parametrize('width', [4]) 127 | @pytest.mark.parametrize( 128 | # "seqlen", [8, 16, 32, 64, 128, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096] 129 | "seqlen", [2048] 130 | ) 131 | # @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096]) 132 | # @pytest.mark.parametrize('seqlen', [128]) 133 | def test_causal_conv1d_race_condition(seqlen, width, has_bias, silu_activation, itype, channel_last): 134 | device = "cuda" 135 | # set seed 136 | torch.random.manual_seed(0) 137 | batch_size = 2 138 | # batch_size = 1 139 | dim = 4096 + 32 # Try dim not divisible by 64 140 | # dim = 64 141 | if not channel_last: 142 | x = torch.randn(batch_size, 4096 + dim + 64, seqlen, device=device, dtype=itype)[:, 4096:4096 + dim, :].requires_grad_() 143 | else: 144 | x = rearrange( 145 | torch.randn(batch_size, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s" 146 | ).requires_grad_() 147 | weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True) 148 | if has_bias: 149 | bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) 150 | else: 151 | bias = None 152 | activation = None if not 
silu_activation else "silu"
153 |     out0 = causal_conv1d_fn(x, weight, bias, activation=activation)
154 |     g = torch.randn_like(out0)
155 |     dx0, dw0, db0 = torch.autograd.grad(out0, (x, weight, bias), g)
156 |     dw_atol = 1e-4
157 |     db_atol = 1e-4
158 | 
159 |     for i in range(10000):
160 |         out = causal_conv1d_fn(x, weight, bias, activation=activation)
161 |         dx, dw, db = torch.autograd.grad(out, (x, weight, bias), g)
162 |         dw_equal = torch.allclose(dw, dw0, atol=dw_atol)
163 |         # if not dw_equal:
164 |         #     breakpoint()
165 |         if has_bias:
166 |             db_equal = torch.allclose(db, db0, atol=db_atol)
167 |             # if not db_equal:
168 |             #     breakpoint()
169 |         assert torch.equal(out, out0)
170 |         assert torch.equal(dx, dx0)
171 |         assert dw_equal
172 |         if has_bias:
173 |             assert db_equal
174 | 
--------------------------------------------------------------------------------
/code/augmentations/__init__.py:
--------------------------------------------------------------------------------
1 | import json
2 | from collections import OrderedDict
3 | 
4 | from augmentations.ctaugment import *
5 | 
6 | 
7 | class StorableCTAugment(CTAugment):
8 |     def load_state_dict(self, state):
9 |         for k in ["decay", "depth", "th", "rates"]:
10 |             assert k in state, "{} not in {}".format(k, state.keys())
11 |             setattr(self, k, state[k])
12 | 
13 |     def state_dict(self):
14 |         return OrderedDict(
15 |             [(k, getattr(self, k)) for k in ["decay", "depth", "th", "rates"]]
16 |         )
17 | 
18 | 
19 | def get_default_cta():
20 |     return StorableCTAugment()
21 | 
22 | 
23 | def cta_apply(pil_img, ops):
24 |     if ops is None:
25 |         return pil_img
26 |     for op, args in ops:
27 |         pil_img = OPS[op].f(pil_img, *args)
28 |     return pil_img
29 | 
30 | 
31 | def deserialize(policy_str):
32 |     return [OP(f=x[0], bins=x[1]) for x in json.loads(policy_str)]
33 | 
34 | 
35 | def stats(cta):
36 |     return "\n".join(
37 |         "%-16s %s"
38 |         % (
39 |             k,
40 |             " / ".join(
41 |                 " ".join("%.2f" % x for x in cta.rate_to_p(rate))
42 |                 for rate in cta.rates[k]
43 |             ),
44 |         )
45 |         for k in sorted(OPS.keys())
46 |     )
47 | 
48 | 
49 | def interleave(x, batch, inverse=False):
50 |     """
51 |     TF code
52 |     def interleave(x, batch):
53 |         s = x.get_shape().as_list()
54 |         return tf.reshape(tf.transpose(tf.reshape(x, [-1, batch] + s[1:]), [1, 0] + list(range(2, 1+len(s)))), [-1] + s[1:])
55 |     """
56 |     shape = x.shape
57 |     axes = [batch, -1] if inverse else [-1, batch]
58 |     return x.reshape(*axes, *shape[1:]).transpose(0, 1).reshape(-1, *shape[1:])
59 | 
60 | 
61 | def deinterleave(x, batch):
62 |     return interleave(x, batch, inverse=True)
--------------------------------------------------------------------------------
/code/augmentations/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/code/augmentations/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/code/augmentations/__pycache__/ctaugment.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/code/augmentations/__pycache__/ctaugment.cpython-38.pyc
--------------------------------------------------------------------------------
/code/augmentations/ctaugment.py:
--------------------------------------------------------------------------------
1 | #
https://raw.githubusercontent.com/google-research/fixmatch/master/libml/ctaugment.py 2 | # 3 | # Copyright 2019 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """Control Theory based self-augmentation, modified from https://github.com/vfdev-5/FixMatch-pytorch""" 17 | import random 18 | import torch 19 | from collections import namedtuple 20 | 21 | import numpy as np 22 | from scipy.ndimage.interpolation import zoom 23 | from PIL import Image, ImageOps, ImageEnhance, ImageFilter 24 | 25 | 26 | OPS = {} 27 | OP = namedtuple("OP", ("f", "bins")) 28 | Sample = namedtuple("Sample", ("train", "probe")) 29 | 30 | 31 | def register(*bins): 32 | def wrap(f): 33 | OPS[f.__name__] = OP(f, bins) 34 | return f 35 | 36 | return wrap 37 | 38 | 39 | class CTAugment(object): 40 | def __init__(self, depth=2, th=0.85, decay=0.99): 41 | self.decay = decay 42 | self.depth = depth 43 | self.th = th 44 | self.rates = {} 45 | for k, op in OPS.items(): 46 | self.rates[k] = tuple([np.ones(x, "f") for x in op.bins]) 47 | 48 | def rate_to_p(self, rate): 49 | p = rate + (1 - self.decay) # Avoid to have all zero. 50 | p = p / p.max() 51 | p[p < self.th] = 0 52 | return p 53 | 54 | def policy(self, probe, weak): 55 | num_strong_ops = 11 56 | kl_weak = list(OPS.keys())[num_strong_ops:] 57 | kl_strong = list(OPS.keys())[:num_strong_ops] 58 | 59 | if weak: 60 | kl = kl_weak 61 | else: 62 | kl = kl_strong 63 | 64 | v = [] 65 | if probe: 66 | for _ in range(self.depth): 67 | k = random.choice(kl) 68 | bins = self.rates[k] 69 | rnd = np.random.uniform(0, 1, len(bins)) 70 | v.append(OP(k, rnd.tolist())) 71 | return v 72 | for _ in range(self.depth): 73 | vt = [] 74 | k = random.choice(kl) 75 | bins = self.rates[k] 76 | rnd = np.random.uniform(0, 1, len(bins)) 77 | for r, bin in zip(rnd, bins): 78 | p = self.rate_to_p(bin) 79 | value = np.random.choice(p.shape[0], p=p / p.sum()) 80 | vt.append((value + r) / p.shape[0]) 81 | v.append(OP(k, vt)) 82 | return v 83 | 84 | def update_rates(self, policy, proximity): 85 | for k, bins in policy: 86 | for p, rate in zip(bins, self.rates[k]): 87 | p = int(p * len(rate) * 0.999) 88 | rate[p] = rate[p] * self.decay + proximity * (1 - self.decay) 89 | print(f"\t {k} weights updated") 90 | 91 | def stats(self): 92 | return "\n".join( 93 | "%-16s %s" 94 | % ( 95 | k, 96 | " / ".join( 97 | " ".join("%.2f" % x for x in self.rate_to_p(rate)) 98 | for rate in self.rates[k] 99 | ), 100 | ) 101 | for k in sorted(OPS.keys()) 102 | ) 103 | 104 | 105 | def _enhance(x, op, level): 106 | return op(x).enhance(0.1 + 1.9 * level) 107 | 108 | 109 | def _imageop(x, op, level): 110 | return Image.blend(x, op(x), level) 111 | 112 | 113 | def _filter(x, op, level): 114 | return Image.blend(x, x.filter(op), level) 115 | 116 | 117 | @register(17) 118 | def autocontrast(x, level): 119 | return _imageop(x, ImageOps.autocontrast, level) 120 | 121 | 122 | @register(17) 123 | def brightness(x, brightness): 124 | return _enhance(x, ImageEnhance.Brightness, 
brightness) 125 | 126 | 127 | @register(17) 128 | def color(x, color): 129 | return _enhance(x, ImageEnhance.Color, color) 130 | 131 | 132 | @register(17) 133 | def contrast(x, contrast): 134 | return _enhance(x, ImageEnhance.Contrast, contrast) 135 | 136 | 137 | @register(17) 138 | def equalize(x, level): 139 | return _imageop(x, ImageOps.equalize, level) 140 | 141 | 142 | @register(17) 143 | def invert(x, level): 144 | return _imageop(x, ImageOps.invert, level) 145 | 146 | 147 | @register(8) 148 | def posterize(x, level): 149 | level = 1 + int(level * 7.999) 150 | return ImageOps.posterize(x, level) 151 | 152 | 153 | @register(17) 154 | def solarize(x, th): 155 | th = int(th * 255.999) 156 | return ImageOps.solarize(x, th) 157 | 158 | 159 | @register(17) 160 | def smooth(x, level): 161 | return _filter(x, ImageFilter.SMOOTH, level) 162 | 163 | 164 | @register(17) 165 | def blur(x, level): 166 | return _filter(x, ImageFilter.BLUR, level) 167 | 168 | 169 | @register(17) 170 | def sharpness(x, sharpness): 171 | return _enhance(x, ImageEnhance.Sharpness, sharpness) 172 | 173 | 174 | # weak after here 175 | 176 | 177 | @register(17) 178 | def cutout(x, level): 179 | """Apply cutout to pil_img at the specified level.""" 180 | size = 1 + int(level * min(x.size) * 0.499) 181 | img_height, img_width = x.size 182 | height_loc = np.random.randint(low=img_height // 2, high=img_height) 183 | width_loc = np.random.randint(low=img_height // 2, high=img_width) 184 | upper_coord = (max(0, height_loc - size // 2), max(0, width_loc - size // 2)) 185 | lower_coord = ( 186 | min(img_height, height_loc + size // 2), 187 | min(img_width, width_loc + size // 2), 188 | ) 189 | pixels = x.load() # create the pixel map 190 | for i in range(upper_coord[0], lower_coord[0]): # for every col: 191 | for j in range(upper_coord[1], lower_coord[1]): # For every row 192 | x.putpixel((i, j), 0) # set the color accordingly 193 | return x 194 | 195 | 196 | @register() 197 | def identity(x): 198 | return x 199 | 200 | 201 | @register(17, 6) 202 | def rescale(x, scale, method): 203 | s = x.size 204 | scale *= 0.25 205 | crop = (scale * s[0], scale * s[1], s[0] * (1 - scale), s[1] * (1 - scale)) 206 | methods = ( 207 | Image.ANTIALIAS, 208 | Image.BICUBIC, 209 | Image.BILINEAR, 210 | Image.BOX, 211 | Image.HAMMING, 212 | Image.NEAREST, 213 | ) 214 | method = methods[int(method * 5.99)] 215 | return x.crop(crop).resize(x.size, method) 216 | 217 | 218 | @register(17) 219 | def rotate(x, angle): 220 | angle = int(np.round((2 * angle - 1) * 45)) 221 | return x.rotate(angle) 222 | 223 | 224 | @register(17) 225 | def shear_x(x, shear): 226 | shear = (2 * shear - 1) * 0.3 227 | return x.transform(x.size, Image.AFFINE, (1, shear, 0, 0, 1, 0)) 228 | 229 | 230 | @register(17) 231 | def shear_y(x, shear): 232 | shear = (2 * shear - 1) * 0.3 233 | return x.transform(x.size, Image.AFFINE, (1, 0, 0, shear, 1, 0)) 234 | 235 | 236 | @register(17) 237 | def translate_x(x, delta): 238 | delta = (2 * delta - 1) * 0.3 239 | return x.transform(x.size, Image.AFFINE, (1, 0, delta, 0, 1, 0)) 240 | 241 | 242 | @register(17) 243 | def translate_y(x, delta): 244 | delta = (2 * delta - 1) * 0.3 245 | return x.transform(x.size, Image.AFFINE, (1, 0, 0, 0, 1, delta)) -------------------------------------------------------------------------------- /code/config.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 
Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # --------------------------------------------------------' 7 | 8 | import os 9 | import yaml 10 | from yacs.config import CfgNode as CN 11 | 12 | _C = CN() 13 | 14 | # Base config files 15 | _C.BASE = [''] 16 | 17 | # ----------------------------------------------------------------------------- 18 | # Data settings 19 | # ----------------------------------------------------------------------------- 20 | _C.DATA = CN() 21 | # Batch size for a single GPU, could be overwritten by command line argument 22 | _C.DATA.BATCH_SIZE = 128 23 | # Path to dataset, could be overwritten by command line argument 24 | _C.DATA.DATA_PATH = '' 25 | # Dataset name 26 | _C.DATA.DATASET = 'imagenet' 27 | # Input image size 28 | _C.DATA.IMG_SIZE = 224 29 | # Interpolation to resize image (random, bilinear, bicubic) 30 | _C.DATA.INTERPOLATION = 'bicubic' 31 | # Use zipped dataset instead of folder dataset 32 | # could be overwritten by command line argument 33 | _C.DATA.ZIP_MODE = False 34 | # Cache Data in Memory, could be overwritten by command line argument 35 | _C.DATA.CACHE_MODE = 'part' 36 | # Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU. 37 | _C.DATA.PIN_MEMORY = True 38 | # Number of data loading threads 39 | _C.DATA.NUM_WORKERS = 8 40 | 41 | # ----------------------------------------------------------------------------- 42 | # Model settings 43 | # ----------------------------------------------------------------------------- 44 | _C.MODEL = CN() 45 | # Model type 46 | _C.MODEL.TYPE = 'swin' 47 | # Model name 48 | _C.MODEL.NAME = 'swin_tiny_patch4_window7_224' 49 | # Checkpoint to resume, could be overwritten by command line argument 50 | # _C.MODEL.PRETRAIN_CKPT = './pretrained_ckpt/swin_tiny_patch4_window7_224.pth' 51 | _C.MODEL.PRETRAIN_CKPT = './pretrained_ckpt/vmamba_tiny_e292.pth' 52 | _C.MODEL.RESUME = '' 53 | # Number of classes, overwritten in data preparation 54 | _C.MODEL.NUM_CLASSES = 1000 55 | # Dropout rate 56 | _C.MODEL.DROP_RATE = 0.0 57 | # Drop path rate 58 | _C.MODEL.DROP_PATH_RATE = 0.1 59 | # Label Smoothing 60 | _C.MODEL.LABEL_SMOOTHING = 0.1 61 | 62 | # VSSM parameters 63 | _C.MODEL.VSSM = CN() 64 | _C.MODEL.VSSM.PATCH_SIZE = 4 65 | _C.MODEL.VSSM.IN_CHANS = 3 66 | _C.MODEL.VSSM.EMBED_DIM = 96 67 | _C.MODEL.VSSM.DEPTHS = [2, 2, 9, 2] 68 | _C.MODEL.VSSM.MLP_RATIO = 4. 69 | _C.MODEL.VSSM.PATCH_NORM = True 70 | 71 | # Swin Transformer parameters 72 | _C.MODEL.SWIN = CN() 73 | _C.MODEL.SWIN.PATCH_SIZE = 4 74 | _C.MODEL.SWIN.IN_CHANS = 3 75 | _C.MODEL.SWIN.EMBED_DIM = 96 76 | _C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2] 77 | _C.MODEL.SWIN.DECODER_DEPTHS = [2, 2, 6, 2] 78 | _C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24] 79 | _C.MODEL.SWIN.WINDOW_SIZE = 7 80 | _C.MODEL.SWIN.MLP_RATIO = 4. 
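# Whether to add a learnable bias to the query/key/value projections (QKV_BIAS),
# and an override for the default qk scale of head_dim ** -0.5 (QK_SCALE),
# following the upstream Swin Transformer config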
81 | _C.MODEL.SWIN.QKV_BIAS = True 82 | _C.MODEL.SWIN.QK_SCALE = False 83 | _C.MODEL.SWIN.APE = False 84 | _C.MODEL.SWIN.PATCH_NORM = True 85 | _C.MODEL.SWIN.FINAL_UPSAMPLE= "expand_first" 86 | 87 | # ----------------------------------------------------------------------------- 88 | # Training settings 89 | # ----------------------------------------------------------------------------- 90 | _C.TRAIN = CN() 91 | _C.TRAIN.START_EPOCH = 0 92 | _C.TRAIN.EPOCHS = 300 93 | _C.TRAIN.WARMUP_EPOCHS = 20 94 | _C.TRAIN.WEIGHT_DECAY = 0.05 95 | _C.TRAIN.BASE_LR = 5e-4 96 | _C.TRAIN.WARMUP_LR = 5e-7 97 | _C.TRAIN.MIN_LR = 5e-6 98 | # Clip gradient norm 99 | _C.TRAIN.CLIP_GRAD = 5.0 100 | # Auto resume from latest checkpoint 101 | _C.TRAIN.AUTO_RESUME = True 102 | # Gradient accumulation steps 103 | # could be overwritten by command line argument 104 | _C.TRAIN.ACCUMULATION_STEPS = 0 105 | # Whether to use gradient checkpointing to save memory 106 | # could be overwritten by command line argument 107 | _C.TRAIN.USE_CHECKPOINT = False 108 | 109 | # LR scheduler 110 | _C.TRAIN.LR_SCHEDULER = CN() 111 | _C.TRAIN.LR_SCHEDULER.NAME = 'cosine' 112 | # Epoch interval to decay LR, used in StepLRScheduler 113 | _C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30 114 | # LR decay rate, used in StepLRScheduler 115 | _C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1 116 | 117 | # Optimizer 118 | _C.TRAIN.OPTIMIZER = CN() 119 | _C.TRAIN.OPTIMIZER.NAME = 'adamw' 120 | # Optimizer Epsilon 121 | _C.TRAIN.OPTIMIZER.EPS = 1e-8 122 | # Optimizer Betas 123 | _C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999) 124 | # SGD momentum 125 | _C.TRAIN.OPTIMIZER.MOMENTUM = 0.9 126 | 127 | # ----------------------------------------------------------------------------- 128 | # Augmentation settings 129 | # ----------------------------------------------------------------------------- 130 | _C.AUG = CN() 131 | # Color jitter factor 132 | _C.AUG.COLOR_JITTER = 0.4 133 | # Use AutoAugment policy. "v0" or "original" 134 | _C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1' 135 | # Random erase prob 136 | _C.AUG.REPROB = 0.25 137 | # Random erase mode 138 | _C.AUG.REMODE = 'pixel' 139 | # Random erase count 140 | _C.AUG.RECOUNT = 1 141 | # Mixup alpha, mixup enabled if > 0 142 | _C.AUG.MIXUP = 0.8 143 | # Cutmix alpha, cutmix enabled if > 0 144 | _C.AUG.CUTMIX = 1.0 145 | # Cutmix min/max ratio, overrides alpha and enables cutmix if set 146 | _C.AUG.CUTMIX_MINMAX = False 147 | # Probability of performing mixup or cutmix when either/both is enabled 148 | _C.AUG.MIXUP_PROB = 1.0 149 | # Probability of switching to cutmix when both mixup and cutmix enabled 150 | _C.AUG.MIXUP_SWITCH_PROB = 0.5 151 | # How to apply mixup/cutmix params. 
Per "batch", "pair", or "elem" 152 | _C.AUG.MIXUP_MODE = 'batch' 153 | 154 | # ----------------------------------------------------------------------------- 155 | # Testing settings 156 | # ----------------------------------------------------------------------------- 157 | _C.TEST = CN() 158 | # Whether to use center crop when testing 159 | _C.TEST.CROP = True 160 | 161 | # ----------------------------------------------------------------------------- 162 | # Misc 163 | # ----------------------------------------------------------------------------- 164 | # Mixed precision opt level, if O0, no amp is used ('O0', 'O1', 'O2') 165 | # overwritten by command line argument 166 | _C.AMP_OPT_LEVEL = '' 167 | # Path to output folder, overwritten by command line argument 168 | _C.OUTPUT = '' 169 | # Tag of experiment, overwritten by command line argument 170 | _C.TAG = 'default' 171 | # Frequency to save checkpoint 172 | _C.SAVE_FREQ = 1 173 | # Frequency to logging info 174 | _C.PRINT_FREQ = 10 175 | # Fixed random seed 176 | _C.SEED = 0 177 | # Perform evaluation only, overwritten by command line argument 178 | _C.EVAL_MODE = False 179 | # Test throughput only, overwritten by command line argument 180 | _C.THROUGHPUT_MODE = False 181 | # local rank for DistributedDataParallel, given by command line argument 182 | _C.LOCAL_RANK = 0 183 | 184 | 185 | def _update_config_from_file(config, cfg_file): 186 | config.defrost() 187 | with open(cfg_file, 'r') as f: 188 | yaml_cfg = yaml.load(f, Loader=yaml.FullLoader) 189 | 190 | for cfg in yaml_cfg.setdefault('BASE', ['']): 191 | if cfg: 192 | _update_config_from_file( 193 | config, os.path.join(os.path.dirname(cfg_file), cfg) 194 | ) 195 | print('=> merge config from {}'.format(cfg_file)) 196 | config.merge_from_file(cfg_file) 197 | config.freeze() 198 | 199 | 200 | def update_config(config, args): 201 | _update_config_from_file(config, args.cfg) 202 | 203 | config.defrost() 204 | if args.opts: 205 | config.merge_from_list(args.opts) 206 | 207 | # merge from specific arguments 208 | if args.batch_size: 209 | config.DATA.BATCH_SIZE = args.batch_size 210 | if args.zip: 211 | config.DATA.ZIP_MODE = True 212 | if args.cache_mode: 213 | config.DATA.CACHE_MODE = args.cache_mode 214 | if args.resume: 215 | config.MODEL.RESUME = args.resume 216 | if args.accumulation_steps: 217 | config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps 218 | if args.use_checkpoint: 219 | config.TRAIN.USE_CHECKPOINT = True 220 | if args.amp_opt_level: 221 | config.AMP_OPT_LEVEL = args.amp_opt_level 222 | if args.tag: 223 | config.TAG = args.tag 224 | if args.eval: 225 | config.EVAL_MODE = True 226 | if args.throughput: 227 | config.THROUGHPUT_MODE = True 228 | 229 | config.freeze() 230 | 231 | 232 | def get_config(args): 233 | """Get a yacs CfgNode object with default values.""" 234 | # Return a clone so that the defaults will not be altered 235 | # This is for the "local variable" use pattern 236 | config = _C.clone() 237 | update_config(config, args) 238 | 239 | return config 240 | -------------------------------------------------------------------------------- /code/configs/swin_tiny_patch4_window7_224_lite.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: swin 3 | NAME: swin_tiny_patch4_window7_224 4 | DROP_PATH_RATE: 0.2 5 | PRETRAIN_CKPT: "../code/pretrained_ckpt/swin_tiny_patch4_window7_224.pth" 6 | SWIN: 7 | FINAL_UPSAMPLE: "expand_first" 8 | EMBED_DIM: 96 9 | DEPTHS: [ 2, 2, 2, 2 ] 10 | DECODER_DEPTHS: [ 2, 2, 2, 1] 
11 | NUM_HEADS: [ 3, 6, 12, 24 ] 12 | WINDOW_SIZE: 7 -------------------------------------------------------------------------------- /code/configs/vmamba_tiny.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny 4 | DROP_PATH_RATE: 0.2 5 | PRETRAIN_CKPT: "../code/pretrained_ckpt/vmamba_tiny_e292.pth" 6 | VSSM: 7 | EMBED_DIM: 96 8 | DEPTHS: [ 2, 2, 2, 2 ] -------------------------------------------------------------------------------- /code/dataloaders/__pycache__/dataset.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/code/dataloaders/__pycache__/dataset.cpython-310.pyc -------------------------------------------------------------------------------- /code/dataloaders/__pycache__/dataset.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/code/dataloaders/__pycache__/dataset.cpython-311.pyc -------------------------------------------------------------------------------- /code/dataloaders/__pycache__/dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/code/dataloaders/__pycache__/dataset.cpython-38.pyc -------------------------------------------------------------------------------- /code/dataloaders/__pycache__/dataset_s2l.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/code/dataloaders/__pycache__/dataset_s2l.cpython-38.pyc -------------------------------------------------------------------------------- /code/dataloaders/__pycache__/dataset_semi.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/code/dataloaders/__pycache__/dataset_semi.cpython-38.pyc -------------------------------------------------------------------------------- /code/dataloaders/__pycache__/utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/code/dataloaders/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- /code/dataloaders/__pycache__/utils.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/code/dataloaders/__pycache__/utils.cpython-311.pyc -------------------------------------------------------------------------------- /code/dataloaders/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/code/dataloaders/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- 
/code/dataloaders/acdc_data_processing.py: -------------------------------------------------------------------------------- 1 | # save images in slice level 2 | import glob 3 | import os 4 | 5 | import h5py 6 | import numpy as np 7 | import SimpleITK as sitk 8 | 9 | 10 | class MedicalImageDeal(object): 11 | def __init__(self, img, percent=1): 12 | self.img = img 13 | self.percent = percent 14 | 15 | @property 16 | def valid_img(self): 17 | from skimage import exposure 18 | cdf = exposure.cumulative_distribution(self.img) 19 | watershed = cdf[1][cdf[0] >= self.percent][0] 20 | return np.clip(self.img, self.img.min(), watershed) 21 | 22 | @property 23 | def norm_img(self): 24 | return (self.img - self.img.min()) / (self.img.max() - self.img.min()) 25 | 26 | # saving images in slice level 27 | 28 | 29 | slice_num = 0 30 | mask_path = sorted( 31 | glob.glob("../data/ACDC_training/*_gt.nii.gz")) 32 | for case in mask_path: 33 | label_itk = sitk.ReadImage(case) 34 | label = sitk.GetArrayFromImage(label_itk) 35 | 36 | image_path = case.replace("_gt", "") 37 | image_itk = sitk.ReadImage(image_path) 38 | image = sitk.GetArrayFromImage(image_itk) 39 | 40 | scribble_path = case.replace("_gt", "_scribble") 41 | scribble_itk = sitk.ReadImage(scribble_path) 42 | scribble = sitk.GetArrayFromImage(scribble_itk) 43 | 44 | image = MedicalImageDeal(image, percent=0.99).valid_img 45 | image = (image - image.min()) / (image.max() - image.min()) 46 | print(image.shape) 47 | image = image.astype(np.float32) 48 | item = case.split("/")[-1].split(".")[0].replace("_gt", "") 49 | if image.shape != label.shape: 50 | print("Error") 51 | print(item) 52 | for slice_ind in range(image.shape[0]): 53 | f = h5py.File( 54 | '../data/ACDC_training_slices/{}_slice_{}.h5'.format(item, slice_ind), 'w') 55 | f.create_dataset( 56 | 'image', data=image[slice_ind], compression="gzip") 57 | f.create_dataset('label', data=label[slice_ind], compression="gzip") 58 | f.create_dataset( 59 | 'scribble', data=scribble[slice_ind], compression="gzip") 60 | f.close() 61 | slice_num += 1 62 | print("Converted all ACDC volumes to 2D slices") 63 | print("Total {} slices".format(slice_num)) 64 | 65 | # saving images in volume level 66 | 67 | 68 | class MedicalImageDeal(object): 69 | def __init__(self, img, percent=1): 70 | self.img = img 71 | self.percent = percent 72 | 73 | @property 74 | def valid_img(self): 75 | from skimage import exposure 76 | cdf = exposure.cumulative_distribution(self.img) 77 | watershed = cdf[1][cdf[0] >= self.percent][0] 78 | return np.clip(self.img, self.img.min(), watershed) 79 | 80 | @property 81 | def norm_img(self): 82 | return (self.img - self.img.min()) / (self.img.max() - self.img.min()) 83 | 84 | 85 | slice_num = 0 86 | mask_path = sorted( 87 | glob.glob("../data/ACDC_training/*_gt.nii.gz")) 88 | print(mask_path) 89 | 90 | for case in mask_path: 91 | label_itk = sitk.ReadImage(case) 92 | label = sitk.GetArrayFromImage(label_itk) 93 | 94 | image_path = case.replace("_gt", "") 95 | image_itk = sitk.ReadImage(image_path) 96 | image = sitk.GetArrayFromImage(image_itk) 97 | 98 | scribble_path = case.replace("_gt", "_scribble") 99 | scribble_itk = sitk.ReadImage(scribble_path) 100 | scribble = sitk.GetArrayFromImage(scribble_itk) 101 | 102 | image = MedicalImageDeal(image, percent=0.99).valid_img 103 | image = (image - image.min()) / (image.max() - image.min()) 104 | print(image.shape) 105 | image = image.astype(np.float32) 106 | item = case.split("/")[-1].split(".")[0].replace("_gt", "") 107 | if image.shape != 
label.shape: 108 | print("Error: image/label shape mismatch") 109 | print(item) 110 | f = h5py.File( 111 | '../data/ACDC_training_volumes/{}.h5'.format(item), 'w') 112 | f.create_dataset( 113 | 'image', data=image, compression="gzip") 114 | f.create_dataset('label', data=label, compression="gzip") 115 | f.create_dataset('scribble', data=scribble, compression="gzip") 116 | f.close() 117 | volume_num += 1 118 | print("Converted all ACDC volumes to volume-level h5 files") 119 | print("Total {} volumes".format(volume_num)) 120 | -------------------------------------------------------------------------------- /code/dataloaders/acdc_pseudo_label_random_walker.py: -------------------------------------------------------------------------------- 1 | import SimpleITK as sitk 2 | import glob 3 | import os 4 | import numpy as np 5 | from skimage.exposure import rescale_intensity 6 | from skimage.segmentation import random_walker 7 | 8 | 9 | def pseudo_label_generator_acdc(data, seed): 10 | from skimage.exposure import rescale_intensity 11 | from skimage.segmentation import random_walker 12 | if 1 not in np.unique(seed) or 2 not in np.unique(seed) or 3 not in np.unique(seed): 13 | pseudo_label = np.zeros_like(seed) 14 | else: 15 | markers = np.ones_like(seed) 16 | markers[seed == 4] = 0 17 | markers[seed == 0] = 1 18 | markers[seed == 1] = 2 19 | markers[seed == 2] = 3 20 | markers[seed == 3] = 4 21 | sigma = 0.35 22 | data = rescale_intensity(data, in_range=(-sigma, 1 + sigma), 23 | out_range=(-1, 1)) 24 | segmentation = random_walker(data, markers, beta=100, mode='bf') 25 | pseudo_label = segmentation - 1 26 | return pseudo_label 27 | 28 | 29 | def pseudo_label_generator(data, seed): 30 | # in the seed array: 0 means background, 1 to 3 mean class 1 to 3, 4 means unknown region 31 | markers = np.ones_like(seed) 32 | markers[seed == 4] = 0 33 | markers[seed == 0] = 1 34 | markers[seed == 1] = 2 35 | markers[seed == 2] = 3 36 | markers[seed == 3] = 4 37 | sigma = 0.35 38 | data = rescale_intensity(data, in_range=(-sigma, 1 + sigma), 39 | out_range=(-1, 1)) 40 | pseudo_label = random_walker(data, markers, beta=100, mode='bf') 41 | return pseudo_label-1 42 | 43 | 44 | for i in sorted(glob.glob("../data/ACDC_training/*_scribble.nii.gz"))[2:]: 45 | print(i.replace("_scribble.nii.gz", ".nii.gz")) 46 | img_itk = sitk.ReadImage(i.replace("_scribble.nii.gz", ".nii.gz")) 47 | image = sitk.GetArrayFromImage(img_itk) 48 | scribble = sitk.GetArrayFromImage(sitk.ReadImage(i)) 49 | pseudo_volumes = np.zeros_like(image) 50 | for ind, slice_ind in enumerate(range(image.shape[0])): 51 | if 1 not in np.unique(scribble[ind, ...]) or 2 not in np.unique(scribble[ind, ...]) or 3 not in np.unique(scribble[ind, ...]): 52 | pass 53 | else: 54 | pseudo_volumes[ind, ...] 
= pseudo_label_generator( 55 | image[ind, ...], scribble[ind, ...]) 56 | pseudo_volumes_itk = sitk.GetImageFromArray(pseudo_volumes) 57 | pseudo_volumes_itk.CopyInformation(img_itk) 58 | sitk.WriteImage(pseudo_volumes_itk, i.replace( 59 | "_scribble.nii.gz", "_random_walker.nii.gz")) -------------------------------------------------------------------------------- /code/dataloaders/dataset.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import os 3 | import random 4 | import re 5 | from glob import glob 6 | 7 | import cv2 8 | import h5py 9 | import numpy as np 10 | import torch 11 | from scipy import ndimage 12 | from scipy.ndimage.interpolation import zoom 13 | from torch.utils.data import Dataset 14 | from torch.utils.data.sampler import Sampler 15 | 16 | 17 | def pseudo_label_generator_acdc(data, seed, beta=100, mode='bf'): 18 | from skimage.exposure import rescale_intensity 19 | from skimage.segmentation import random_walker 20 | if 1 not in np.unique(seed) or 2 not in np.unique(seed) or 3 not in np.unique(seed): 21 | pseudo_label = np.zeros_like(seed) 22 | else: 23 | markers = np.ones_like(seed) 24 | markers[seed == 4] = 0 25 | markers[seed == 0] = 1 26 | markers[seed == 1] = 2 27 | markers[seed == 2] = 3 28 | markers[seed == 3] = 4 29 | sigma = 0.35 30 | data = rescale_intensity(data, in_range=(-sigma, 1 + sigma), 31 | out_range=(-1, 1)) 32 | segmentation = random_walker(data, markers, beta, mode) 33 | pseudo_label = segmentation - 1 34 | return pseudo_label 35 | 36 | 37 | class BaseDataSets(Dataset): 38 | def __init__(self, base_dir=None, split='train', transform=None, fold="fold1", sup_type="label"): 39 | self._base_dir = base_dir 40 | self.sample_list = [] 41 | self.split = split 42 | self.sup_type = sup_type 43 | self.transform = transform 44 | train_ids, test_ids = self._get_fold_ids(fold) 45 | if self.split == 'train': 46 | self.all_slices = os.listdir( 47 | self._base_dir + "/ACDC_training_slices") 48 | self.sample_list = [] 49 | for ids in train_ids: 50 | new_data_list = list(filter(lambda x: re.match( 51 | '{}.*'.format(ids), x) != None, self.all_slices)) 52 | self.sample_list.extend(new_data_list) 53 | 54 | elif self.split == 'val': 55 | self.all_volumes = os.listdir( 56 | self._base_dir + "/ACDC_training_volumes") 57 | self.sample_list = [] 58 | for ids in test_ids: 59 | new_data_list = list(filter(lambda x: re.match( 60 | '{}.*'.format(ids), x) != None, self.all_volumes)) 61 | self.sample_list.extend(new_data_list) 62 | 63 | # if num is not None and self.split == "train": 64 | # self.sample_list = self.sample_list[:num] 65 | print("total {} samples".format(len(self.sample_list))) 66 | 67 | def _get_fold_ids(self, fold): 68 | all_cases_set = ["patient{:0>3}".format(i) for i in range(1, 101)] 69 | fold1_testing_set = [ 70 | "patient{:0>3}".format(i) for i in range(1, 21)] 71 | fold1_training_set = [ 72 | i for i in all_cases_set if i not in fold1_testing_set] 73 | 74 | fold2_testing_set = [ 75 | "patient{:0>3}".format(i) for i in range(21, 41)] 76 | fold2_training_set = [ 77 | i for i in all_cases_set if i not in fold2_testing_set] 78 | 79 | fold3_testing_set = [ 80 | "patient{:0>3}".format(i) for i in range(41, 61)] 81 | fold3_training_set = [ 82 | i for i in all_cases_set if i not in fold3_testing_set] 83 | 84 | fold4_testing_set = [ 85 | "patient{:0>3}".format(i) for i in range(61, 81)] 86 | fold4_training_set = [ 87 | i for i in all_cases_set if i not in fold4_testing_set] 88 | 89 | fold5_testing_set = [ 90 | 
"patient{:0>3}".format(i) for i in range(81, 101)] 91 | fold5_training_set = [ 92 | i for i in all_cases_set if i not in fold5_testing_set] 93 | if fold == "fold1": 94 | return [fold1_training_set, fold1_testing_set] 95 | elif fold == "fold2": 96 | return [fold2_training_set, fold2_testing_set] 97 | elif fold == "fold3": 98 | return [fold3_training_set, fold3_testing_set] 99 | elif fold == "fold4": 100 | return [fold4_training_set, fold4_testing_set] 101 | elif fold == "fold5": 102 | return [fold5_training_set, fold5_testing_set] 103 | else: 104 | return "ERROR KEY" 105 | 106 | def __len__(self): 107 | return len(self.sample_list) 108 | 109 | def __getitem__(self, idx): 110 | case = self.sample_list[idx] 111 | if self.split == "train": 112 | h5f = h5py.File(self._base_dir + 113 | "/ACDC_training_slices/{}".format(case), 'r') 114 | else: 115 | h5f = h5py.File(self._base_dir + 116 | "/ACDC_training_volumes/{}".format(case), 'r') 117 | image = h5f['image'][:] 118 | label = h5f['label'][:] 119 | sample = {'image': image, 'label': label} 120 | if self.split == "train": 121 | image = h5f['image'][:] 122 | if self.sup_type == "random_walker": 123 | label = pseudo_label_generator_acdc(image, h5f["scribble"][:]) 124 | else: 125 | label = h5f[self.sup_type][:] 126 | sample = {'image': image, 'label': label} 127 | sample = self.transform(sample) 128 | else: 129 | image = h5f['image'][:] 130 | label = h5f['label'][:] 131 | sample = {'image': image, 'label': label} 132 | sample["idx"] = idx 133 | return sample 134 | 135 | 136 | def random_rot_flip(image, label): 137 | k = np.random.randint(0, 4) 138 | image = np.rot90(image, k) 139 | label = np.rot90(label, k) 140 | axis = np.random.randint(0, 2) 141 | image = np.flip(image, axis=axis).copy() 142 | label = np.flip(label, axis=axis).copy() 143 | return image, label 144 | 145 | 146 | def random_rotate(image, label, cval): 147 | angle = np.random.randint(-20, 20) 148 | image = ndimage.rotate(image, angle, order=0, reshape=False) 149 | label = ndimage.rotate(label, angle, order=0, 150 | reshape=False, mode="constant", cval=cval) 151 | return image, label 152 | 153 | 154 | class RandomGenerator(object): 155 | def __init__(self, output_size): 156 | self.output_size = output_size 157 | 158 | def __call__(self, sample): 159 | image, label = sample['image'], sample['label'] 160 | # ind = random.randrange(0, img.shape[0]) 161 | # image = img[ind, ...] 162 | # label = lab[ind, ...] 163 | if random.random() > 0.5: 164 | image, label = random_rot_flip(image, label) 165 | elif random.random() > 0.5: 166 | if 4 in np.unique(label): 167 | image, label = random_rotate(image, label, cval=4) 168 | else: 169 | image, label = random_rotate(image, label, cval=0) 170 | x, y = image.shape 171 | image = zoom( 172 | image, (self.output_size[0] / x, self.output_size[1] / y), order=0) 173 | label = zoom( 174 | label, (self.output_size[0] / x, self.output_size[1] / y), order=0) 175 | image = torch.from_numpy( 176 | image.astype(np.float32)).unsqueeze(0) 177 | label = torch.from_numpy(label.astype(np.uint8)) 178 | sample = {'image': image, 'label': label} 179 | return sample 180 | 181 | 182 | class TwoStreamBatchSampler(Sampler): 183 | """Iterate two sets of indices 184 | 185 | An 'epoch' is one iteration through the primary indices. 186 | During the epoch, the secondary indices are iterated through 187 | as many times as needed. 
188 | """ 189 | 190 | def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size): 191 | self.primary_indices = primary_indices 192 | self.secondary_indices = secondary_indices 193 | self.secondary_batch_size = secondary_batch_size 194 | self.primary_batch_size = batch_size - secondary_batch_size 195 | 196 | assert len(self.primary_indices) >= self.primary_batch_size > 0 197 | assert len(self.secondary_indices) >= self.secondary_batch_size > 0 198 | 199 | def __iter__(self): 200 | primary_iter = iterate_once(self.primary_indices) 201 | secondary_iter = iterate_eternally(self.secondary_indices) 202 | return ( 203 | primary_batch + secondary_batch 204 | for (primary_batch, secondary_batch) 205 | in zip(grouper(primary_iter, self.primary_batch_size), 206 | grouper(secondary_iter, self.secondary_batch_size)) 207 | ) 208 | 209 | def __len__(self): 210 | return len(self.primary_indices) // self.primary_batch_size 211 | 212 | 213 | def iterate_once(iterable): 214 | return np.random.permutation(iterable) 215 | 216 | 217 | def iterate_eternally(indices): 218 | def infinite_shuffles(): 219 | while True: 220 | yield np.random.permutation(indices) 221 | return itertools.chain.from_iterable(infinite_shuffles()) 222 | 223 | 224 | def grouper(iterable, n): 225 | "Collect data into fixed-length chunks or blocks" 226 | # grouper('ABCDEFG', 3) --> ABC DEF" 227 | args = [iter(iterable)] * n 228 | return zip(*args) 229 | -------------------------------------------------------------------------------- /code/dataloaders/dataset_semi.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import os 3 | import random 4 | import re 5 | from glob import glob 6 | 7 | import cv2 8 | import h5py 9 | import numpy as np 10 | import torch 11 | from scipy import ndimage 12 | from scipy.ndimage.interpolation import zoom 13 | from torch.utils.data import Dataset 14 | from torch.utils.data.sampler import Sampler 15 | 16 | 17 | class BaseDataSets(Dataset): 18 | def __init__(self, base_dir=None, num=4, labeled_type="labeled", split='train', transform=None, fold="fold1", sup_type="label"): 19 | self._base_dir = base_dir 20 | self.sample_list = [] 21 | self.split = split 22 | self.sup_type = sup_type 23 | self.transform = transform 24 | self.num = num 25 | self.labeled_type = labeled_type 26 | train_ids, test_ids = self._get_fold_ids(fold) 27 | all_labeled_ids = ["patient{:0>3}".format( 28 | 10 * i) for i in range(1, 11)] 29 | if self.split == 'train': 30 | self.all_slices = os.listdir( 31 | self._base_dir + "/ACDC_training_slices") 32 | self.sample_list = [] 33 | labeled_ids = [i for i in all_labeled_ids if i in train_ids] 34 | unlabeled_ids = [i for i in train_ids if i not in labeled_ids] 35 | if self.labeled_type == "labeled": 36 | print("Labeled patients IDs", labeled_ids) 37 | for ids in labeled_ids: 38 | new_data_list = list(filter(lambda x: re.match( 39 | '{}.*'.format(ids), x) != None, self.all_slices)) 40 | self.sample_list.extend(new_data_list) 41 | print("total labeled {} samples".format(len(self.sample_list))) 42 | else: 43 | print("Unlabeled patients IDs", unlabeled_ids) 44 | for ids in unlabeled_ids: 45 | new_data_list = list(filter(lambda x: re.match( 46 | '{}.*'.format(ids), x) != None, self.all_slices)) 47 | self.sample_list.extend(new_data_list) 48 | print("total unlabeled {} samples".format(len(self.sample_list))) 49 | 50 | elif self.split == 'val': 51 | self.all_volumes = os.listdir( 52 | self._base_dir + 
"/ACDC_training_volumes") 53 | self.sample_list = [] 54 | for ids in test_ids: 55 | new_data_list = list(filter(lambda x: re.match( 56 | '{}.*'.format(ids), x) != None, self.all_volumes)) 57 | self.sample_list.extend(new_data_list) 58 | 59 | # if num is not None and self.split == "train": 60 | # self.sample_list = self.sample_list[:num] 61 | 62 | def _get_fold_ids(self, fold): 63 | all_cases_set = ["patient{:0>3}".format(i) for i in range(1, 101)] 64 | fold1_testing_set = [ 65 | "patient{:0>3}".format(i) for i in range(1, 21)] 66 | fold1_training_set = [ 67 | i for i in all_cases_set if i not in fold1_testing_set] 68 | 69 | fold2_testing_set = [ 70 | "patient{:0>3}".format(i) for i in range(21, 41)] 71 | fold2_training_set = [ 72 | i for i in all_cases_set if i not in fold2_testing_set] 73 | 74 | fold3_testing_set = [ 75 | "patient{:0>3}".format(i) for i in range(41, 61)] 76 | fold3_training_set = [ 77 | i for i in all_cases_set if i not in fold3_testing_set] 78 | 79 | fold4_testing_set = [ 80 | "patient{:0>3}".format(i) for i in range(61, 81)] 81 | fold4_training_set = [ 82 | i for i in all_cases_set if i not in fold4_testing_set] 83 | 84 | fold5_testing_set = [ 85 | "patient{:0>3}".format(i) for i in range(81, 101)] 86 | fold5_training_set = [ 87 | i for i in all_cases_set if i not in fold5_testing_set] 88 | if fold == "fold1": 89 | return [fold1_training_set, fold1_testing_set] 90 | elif fold == "fold2": 91 | return [fold2_training_set, fold2_testing_set] 92 | elif fold == "fold3": 93 | return [fold3_training_set, fold3_testing_set] 94 | elif fold == "fold4": 95 | return [fold4_training_set, fold4_testing_set] 96 | elif fold == "fold5": 97 | return [fold5_training_set, fold5_testing_set] 98 | else: 99 | return "ERROR KEY" 100 | 101 | def __len__(self): 102 | return len(self.sample_list) 103 | 104 | def __getitem__(self, idx): 105 | case = self.sample_list[idx] 106 | if self.split == "train": 107 | h5f = h5py.File(self._base_dir + 108 | "/ACDC_training_slices/{}".format(case), 'r') 109 | else: 110 | h5f = h5py.File(self._base_dir + 111 | "/ACDC_training_volumes/{}".format(case), 'r') 112 | image = h5f['image'][:] 113 | label = h5f['label'][:] 114 | sample = {'image': image, 'label': label} 115 | if self.split == "train": 116 | image = h5f['image'][:] 117 | label = h5f[self.sup_type][:] 118 | sample = {'image': image, 'label': label} 119 | sample = self.transform(sample) 120 | else: 121 | image = h5f['image'][:] 122 | label = h5f['label'][:] 123 | sample = {'image': image, 'label': label} 124 | sample["idx"] = case.split("_")[0] 125 | return sample 126 | 127 | 128 | def random_rot_flip(image, label): 129 | k = np.random.randint(0, 4) 130 | image = np.rot90(image, k) 131 | label = np.rot90(label, k) 132 | axis = np.random.randint(0, 2) 133 | image = np.flip(image, axis=axis).copy() 134 | label = np.flip(label, axis=axis).copy() 135 | return image, label 136 | 137 | 138 | def random_rotate(image, label, cval): 139 | angle = np.random.randint(-20, 20) 140 | image = ndimage.rotate(image, angle, order=0, reshape=False) 141 | label = ndimage.rotate(label, angle, order=0, 142 | reshape=False, mode="constant", cval=cval) 143 | return image, label 144 | 145 | 146 | class RandomGenerator(object): 147 | def __init__(self, output_size): 148 | self.output_size = output_size 149 | 150 | def __call__(self, sample): 151 | image, label = sample['image'], sample['label'] 152 | # ind = random.randrange(0, img.shape[0]) 153 | # image = img[ind, ...] 154 | # label = lab[ind, ...] 
155 | if random.random() > 0.5: 156 | image, label = random_rot_flip(image, label) 157 | elif random.random() > 0.5: 158 | if 4 in np.unique(label): 159 | image, label = random_rotate(image, label, cval=4) 160 | else: 161 | image, label = random_rotate(image, label, cval=0) 162 | x, y = image.shape 163 | image = zoom( 164 | image, (self.output_size[0] / x, self.output_size[1] / y), order=0) 165 | label = zoom( 166 | label, (self.output_size[0] / x, self.output_size[1] / y), order=0) 167 | image = torch.from_numpy( 168 | image.astype(np.float32)).unsqueeze(0) 169 | label = torch.from_numpy(label.astype(np.uint8)) 170 | sample = {'image': image, 'label': label} 171 | return sample 172 | 173 | 174 | class TwoStreamBatchSampler(Sampler): 175 | """Iterate two sets of indices 176 | 177 | An 'epoch' is one iteration through the primary indices. 178 | During the epoch, the secondary indices are iterated through 179 | as many times as needed. 180 | """ 181 | 182 | def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size): 183 | self.primary_indices = primary_indices 184 | self.secondary_indices = secondary_indices 185 | self.secondary_batch_size = secondary_batch_size 186 | self.primary_batch_size = batch_size - secondary_batch_size 187 | 188 | assert len(self.primary_indices) >= self.primary_batch_size > 0 189 | assert len(self.secondary_indices) >= self.secondary_batch_size > 0 190 | 191 | def __iter__(self): 192 | primary_iter = iterate_once(self.primary_indices) 193 | secondary_iter = iterate_eternally(self.secondary_indices) 194 | return ( 195 | primary_batch + secondary_batch 196 | for (primary_batch, secondary_batch) 197 | in zip(grouper(primary_iter, self.primary_batch_size), 198 | grouper(secondary_iter, self.secondary_batch_size)) 199 | ) 200 | 201 | def __len__(self): 202 | return len(self.primary_indices) // self.primary_batch_size 203 | 204 | 205 | def iterate_once(iterable): 206 | return np.random.permutation(iterable) 207 | 208 | 209 | def iterate_eternally(indices): 210 | def infinite_shuffles(): 211 | while True: 212 | yield np.random.permutation(indices) 213 | return itertools.chain.from_iterable(infinite_shuffles()) 214 | 215 | 216 | def grouper(iterable, n): 217 | "Collect data into fixed-length chunks or blocks" 218 | # grouper('ABCDEFG', 3) --> ABC DEF 219 | args = [iter(iterable)] * n 220 | return zip(*args) 221 | -------------------------------------------------------------------------------- /code/dataloaders/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import scipy.ndimage as nd 5 | import torch 6 | import torch.nn as nn 7 | # import matplotlib.pyplot as plt 8 | from skimage import measure 9 | 10 | 11 | def recursive_glob(rootdir='.', suffix=''): 12 | """Performs recursive glob with given suffix and rootdir 13 | :param rootdir is the root directory 14 | :param suffix is the suffix to be searched 15 | """ 16 | return [os.path.join(looproot, filename) 17 | for looproot, _, filenames in os.walk(rootdir) 18 | for filename in filenames if filename.endswith(suffix)] 19 | 20 | 21 | def get_cityscapes_labels(): 22 | return np.array([ 23 | # [ 0, 0, 0], 24 | [128, 64, 128], 25 | [244, 35, 232], 26 | [70, 70, 70], 27 | [102, 102, 156], 28 | [190, 153, 153], 29 | [153, 153, 153], 30 | [250, 170, 30], 31 | [220, 220, 0], 32 | [107, 142, 35], 33 | [152, 251, 152], 34 | [0, 130, 180], 35 | [220, 20, 60], 36 | [255, 0, 0], 37 | [0, 0, 142], 38 | [0, 0, 70], 39 | [0, 
60, 100], 40 | [0, 80, 100], 41 | [0, 0, 230], 42 | [119, 11, 32]]) 43 | 44 | 45 | def get_pascal_labels(): 46 | """Load the mapping that associates pascal classes with label colors 47 | Returns: 48 | np.ndarray with dimensions (21, 3) 49 | """ 50 | return np.asarray([[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], 51 | [0, 0, 128], [128, 0, 128], [ 52 | 0, 128, 128], [128, 128, 128], 53 | [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0], 54 | [64, 0, 128], [192, 0, 128], [ 55 | 64, 128, 128], [192, 128, 128], 56 | [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0], 57 | [0, 64, 128]]) 58 | 59 | 60 | def encode_segmap(mask): 61 | """Encode segmentation label images as pascal classes 62 | Args: 63 | mask (np.ndarray): raw segmentation label image of dimension 64 | (M, N, 3), in which the Pascal classes are encoded as colours. 65 | Returns: 66 | (np.ndarray): class map with dimensions (M,N), where the value at 67 | a given location is the integer denoting the class index. 68 | """ 69 | mask = mask.astype(int) 70 | label_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int16) 71 | for ii, label in enumerate(get_pascal_labels()): 72 | label_mask[np.where(np.all(mask == label, axis=-1))[:2]] = ii 73 | label_mask = label_mask.astype(int) 74 | return label_mask 75 | 76 | 77 | def decode_seg_map_sequence(label_masks, dataset='pascal'): 78 | rgb_masks = [] 79 | for label_mask in label_masks: 80 | rgb_mask = decode_segmap(label_mask, dataset) 81 | rgb_masks.append(rgb_mask) 82 | rgb_masks = torch.from_numpy(np.array(rgb_masks).transpose([0, 3, 1, 2])) 83 | return rgb_masks 84 | 85 | 86 | def decode_segmap(label_mask, dataset, plot=False): 87 | """Decode segmentation class labels into a color image 88 | Args: 89 | label_mask (np.ndarray): an (M,N) array of integer values denoting 90 | the class label at each spatial location. 91 | plot (bool, optional): whether to show the resulting color image 92 | in a figure. 93 | Returns: 94 | (np.ndarray, optional): the resulting decoded color image. 
95 | """ 96 | if dataset == 'pascal': 97 | n_classes = 21 98 | label_colours = get_pascal_labels() 99 | elif dataset == 'cityscapes': 100 | n_classes = 19 101 | label_colours = get_cityscapes_labels() 102 | else: 103 | raise NotImplementedError 104 | 105 | r = label_mask.copy() 106 | g = label_mask.copy() 107 | b = label_mask.copy() 108 | for ll in range(0, n_classes): 109 | r[label_mask == ll] = label_colours[ll, 0] 110 | g[label_mask == ll] = label_colours[ll, 1] 111 | b[label_mask == ll] = label_colours[ll, 2] 112 | rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3)) 113 | rgb[:, :, 0] = r / 255.0 114 | rgb[:, :, 1] = g / 255.0 115 | rgb[:, :, 2] = b / 255.0 116 | if plot: 117 | plt.imshow(rgb) 118 | plt.show() 119 | else: 120 | return rgb 121 | 122 | 123 | def generate_param_report(logfile, param): 124 | log_file = open(logfile, 'w') 125 | # for key, val in param.items(): 126 | # log_file.write(key + ':' + str(val) + '\n') 127 | log_file.write(str(param)) 128 | log_file.close() 129 | 130 | 131 | def cross_entropy2d(logit, target, ignore_index=255, weight=None, size_average=True, batch_average=True): 132 | n, c, h, w = logit.size() 133 | # logit = logit.permute(0, 2, 3, 1) 134 | target = target.squeeze(1) 135 | if weight is None: 136 | criterion = nn.CrossEntropyLoss( 137 | weight=weight, ignore_index=ignore_index, size_average=False) 138 | else: 139 | criterion = nn.CrossEntropyLoss(weight=torch.from_numpy(np.array( 140 | weight)).float().cuda(), ignore_index=ignore_index, size_average=False) 141 | loss = criterion(logit, target.long()) 142 | 143 | if size_average: 144 | loss /= (h * w) 145 | 146 | if batch_average: 147 | loss /= n 148 | 149 | return loss 150 | 151 | 152 | def lr_poly(base_lr, iter_, max_iter=100, power=0.9): 153 | return base_lr * ((1 - float(iter_) / max_iter) ** power) 154 | 155 | 156 | def get_iou(pred, gt, n_classes=21): 157 | total_iou = 0.0 158 | for i in range(len(pred)): 159 | pred_tmp = pred[i] 160 | gt_tmp = gt[i] 161 | 162 | intersect = [0] * n_classes 163 | union = [0] * n_classes 164 | for j in range(n_classes): 165 | match = (pred_tmp == j) + (gt_tmp == j) 166 | 167 | it = torch.sum(match == 2).item() 168 | un = torch.sum(match > 0).item() 169 | 170 | intersect[j] += it 171 | union[j] += un 172 | 173 | iou = [] 174 | for k in range(n_classes): 175 | if union[k] == 0: 176 | continue 177 | iou.append(intersect[k] / union[k]) 178 | 179 | img_iou = (sum(iou) / len(iou)) 180 | total_iou += img_iou 181 | 182 | return total_iou 183 | 184 | 185 | def get_dice(pred, gt): 186 | total_dice = 0.0 187 | pred = pred.long() 188 | gt = gt.long() 189 | for i in range(len(pred)): 190 | pred_tmp = pred[i] 191 | gt_tmp = gt[i] 192 | dice = 2.0*torch.sum(pred_tmp*gt_tmp).item()/(1.0 + 193 | torch.sum(pred_tmp**2)+torch.sum(gt_tmp**2)).item() 194 | print(dice) 195 | total_dice += dice 196 | 197 | return total_dice 198 | 199 | 200 | def get_mc_dice(pred, gt, num=2): 201 | # num is the total number of classes, include the background 202 | total_dice = np.zeros(num-1) 203 | pred = pred.long() 204 | gt = gt.long() 205 | for i in range(len(pred)): 206 | for j in range(1, num): 207 | pred_tmp = (pred[i] == j) 208 | gt_tmp = (gt[i] == j) 209 | dice = 2.0*torch.sum(pred_tmp*gt_tmp).item()/(1.0 + 210 | torch.sum(pred_tmp**2)+torch.sum(gt_tmp**2)).item() 211 | total_dice[j-1] += dice 212 | return total_dice 213 | 214 | 215 | def post_processing(prediction): 216 | prediction = nd.binary_fill_holes(prediction) 217 | label_cc, num_cc = measure.label(prediction, return_num=True) 
218 | total_cc = np.sum(prediction) 219 | measure.regionprops(label_cc) 220 | for cc in range(1, num_cc+1): 221 | single_cc = (label_cc == cc) 222 | single_vol = np.sum(single_cc) 223 | if single_vol/total_cc < 0.2: 224 | prediction[single_cc] = 0 225 | 226 | return prediction 227 | -------------------------------------------------------------------------------- /code/networks/VoxResNet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import print_function, division 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | class SEBlock(nn.Module): 10 | def __init__(self, in_channels, r): 11 | super(SEBlock, self).__init__() 12 | 13 | redu_chns = int(in_channels / r) 14 | self.se_layers = nn.Sequential( 15 | nn.AdaptiveAvgPool3d(1), 16 | nn.Conv3d(in_channels, redu_chns, kernel_size=1, padding=0), 17 | nn.ReLU(), 18 | nn.Conv3d(redu_chns, in_channels, kernel_size=1, padding=0), 19 | nn.ReLU()) 20 | 21 | def forward(self, x): 22 | f = self.se_layers(x) 23 | return f * x + x 24 | 25 | 26 | class VoxRex(nn.Module): 27 | def __init__(self, in_channels): 28 | super(VoxRex, self).__init__() 29 | self.block = nn.Sequential( 30 | nn.InstanceNorm3d(in_channels), 31 | nn.ReLU(inplace=True), 32 | nn.Conv3d(in_channels, in_channels, 33 | kernel_size=3, padding=1, bias=False), 34 | nn.InstanceNorm3d(in_channels), 35 | nn.ReLU(inplace=True), 36 | nn.Conv3d(in_channels, in_channels, 37 | kernel_size=3, padding=1, bias=False) 38 | ) 39 | 40 | def forward(self, x): 41 | return self.block(x)+x 42 | 43 | 44 | class ConvBlock(nn.Module): 45 | """two convolution layers with instance norm and ReLU""" 46 | 47 | def __init__(self, in_channels, out_channels): 48 | super(ConvBlock, self).__init__() 49 | self.conv_conv = nn.Sequential( 50 | nn.InstanceNorm3d(in_channels), 51 | nn.ReLU(inplace=True), 52 | nn.Conv3d(in_channels, out_channels, 53 | kernel_size=3, padding=1, bias=False), 54 | nn.InstanceNorm3d(out_channels), 55 | nn.ReLU(inplace=True), 56 | nn.Conv3d(out_channels, out_channels, 57 | kernel_size=3, padding=1, bias=False) 58 | ) 59 | 60 | def forward(self, x): 61 | return self.conv_conv(x) 62 | 63 | 64 | class UpBlock(nn.Module): 65 | """Upsampling followed by ConvBlock""" 66 | 67 | def __init__(self, in_channels, out_channels): 68 | super(UpBlock, self).__init__() 69 | self.up = nn.Upsample( 70 | scale_factor=2, mode='trilinear', align_corners=True) 71 | self.conv = ConvBlock(in_channels, out_channels) 72 | 73 | def forward(self, x1, x2): 74 | x1 = self.up(x1) 75 | x = torch.cat([x2, x1], dim=1) 76 | return self.conv(x) 77 | 78 | 79 | class VoxResNet(nn.Module): 80 | def __init__(self, in_chns=1, feature_chns=64, class_num=2): 81 | super(VoxResNet, self).__init__() 82 | self.in_chns = in_chns 83 | self.ft_chns = feature_chns 84 | self.n_class = class_num 85 | 86 | self.conv1 = nn.Conv3d(in_chns, feature_chns, kernel_size=3, padding=1) 87 | self.res1 = VoxRex(feature_chns) 88 | self.res2 = VoxRex(feature_chns) 89 | self.res3 = VoxRex(feature_chns) 90 | self.res4 = VoxRex(feature_chns) 91 | self.res5 = VoxRex(feature_chns) 92 | self.res6 = VoxRex(feature_chns) 93 | 94 | self.up1 = UpBlock(feature_chns * 2, feature_chns) 95 | self.up2 = UpBlock(feature_chns * 2, feature_chns) 96 | 97 | self.out = nn.Conv3d(feature_chns, self.n_class, kernel_size=1) 98 | 99 | self.maxpool = nn.MaxPool3d(2) 100 | self.upsample = nn.Upsample( 101 | scale_factor=2, mode='trilinear', align_corners=True) 102 | 103 | def forward(self, x):
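# Forward pass (spatial size relative to the input): conv1 + maxpool -> 1/2,
# res1-res2 at 1/2, maxpool -> 1/4, res3 at 1/4, res4 + maxpool -> 1/8,
# res5-res6 at 1/8; up1 fuses the 1/4 skip, up2 the 1/2 skip, and a final
# trilinear upsample restores full resolution (so D/H/W must be divisible by 8).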
104 | x = self.maxpool(self.conv1(x)) 105 | x1 = self.res1(x) 106 | x2 = self.res2(x1) 107 | x2_pool = self.maxpool(x2) 108 | x3 = self.res3(x2_pool) 109 | x4 = self.maxpool(self.res4(x3)) 110 | x5 = self.res5(x4) 111 | x6 = self.res6(x5) 112 | up1 = self.up1(x6, x2_pool) 113 | up2 = self.up2(up1, x) 114 | up = self.upsample(up2) 115 | out = self.out(up) 116 | return out 117 | --------------------------------------------------------------------------------
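A quick shape check for the VoxResNet above (a sketch; class_num=4 assumes the ACDC setting of three cardiac structures plus background, and any input whose D/H/W is divisible by 8 works):

import torch
from networks.VoxResNet import VoxResNet

net = VoxResNet(in_chns=1, feature_chns=64, class_num=4)
x = torch.randn(1, 1, 64, 64, 64)  # (batch, channels, D, H, W)
with torch.no_grad():
    y = net(x)
print(y.shape)  # expected: torch.Size([1, 4, 64, 64, 64])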
/code/networks/attention.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | try: 4 | from inplace_abn import InPlaceABN 5 | except ImportError: 6 | InPlaceABN = None 7 | 8 | 9 | class Conv2dReLU(nn.Sequential): 10 | def __init__( 11 | self, 12 | in_channels, 13 | out_channels, 14 | kernel_size, 15 | padding=0, 16 | stride=1, 17 | use_batchnorm=True, 18 | ): 19 | 20 | if use_batchnorm == "inplace" and InPlaceABN is None: 21 | raise RuntimeError( 22 | "In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. " 23 | + "To install see: https://github.com/mapillary/inplace_abn" 24 | ) 25 | 26 | super().__init__() 27 | 28 | conv = nn.Conv2d( 29 | in_channels, 30 | out_channels, 31 | kernel_size, 32 | stride=stride, 33 | padding=padding, 34 | bias=not (use_batchnorm), 35 | ) 36 | relu = nn.ReLU(inplace=True) 37 | 38 | if use_batchnorm == "inplace": 39 | bn = InPlaceABN(out_channels, activation="leaky_relu", activation_param=0.0) 40 | relu = nn.Identity() 41 | 42 | elif use_batchnorm and use_batchnorm != "inplace": 43 | bn = nn.BatchNorm2d(out_channels) 44 | 45 | else: 46 | bn = nn.Identity() 47 | 48 | super(Conv2dReLU, self).__init__(conv, bn, relu) 49 | 50 | 51 | class SCSEModule(nn.Module): 52 | def __init__(self, in_channels, reduction=16): 53 | super().__init__() 54 | self.cSE = nn.Sequential( 55 | nn.AdaptiveAvgPool2d(1), 56 | nn.Conv2d(in_channels, in_channels // reduction, 1), 57 | nn.ReLU(inplace=True), 58 | nn.Conv2d(in_channels // reduction, in_channels, 1), 59 | nn.Sigmoid(), 60 | ) 61 | self.sSE = nn.Sequential(nn.Conv2d(in_channels, 1, 1), nn.Sigmoid()) 62 | 63 | def forward(self, x): 64 | return x * self.cSE(x) + x * self.sSE(x) 65 | 66 | 67 | class Activation(nn.Module): 68 | 69 | def __init__(self, name, **params): 70 | 71 | super().__init__() 72 | 73 | if name is None or name == 'identity': 74 | self.activation = nn.Identity(**params) 75 | elif name == 'sigmoid': 76 | self.activation = nn.Sigmoid() 77 | elif name == 'softmax2d': 78 | self.activation = nn.Softmax(dim=1, **params) 79 | elif name == 'softmax': 80 | self.activation = nn.Softmax(**params) 81 | elif name == 'logsoftmax': 82 | self.activation = nn.LogSoftmax(**params) 83 | elif callable(name): 84 | self.activation = name(**params) 85 | else: 86 | raise ValueError('Activation should be callable/sigmoid/softmax/logsoftmax/None; got {}'.format(name)) 87 | 88 | def forward(self, x): 89 | return self.activation(x) 90 | 91 | 92 | class Attention(nn.Module): 93 | 94 | def __init__(self, name, **params): 95 | super().__init__() 96 | 97 | if name is None: 98 | self.attention = nn.Identity(**params) 99 | elif name == 'scse': 100 | self.attention = SCSEModule(**params) 101 | else: 102 | raise ValueError("Attention {} is not implemented".format(name)) 103 | 104 | def forward(self, x): 105 | return self.attention(x) 106 | 107 | 108 | class Flatten(nn.Module): 109 | def forward(self, x): 110 | return x.view(x.shape[0], -1) 111 | -------------------------------------------------------------------------------- /code/networks/attention_unet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from networks.utils import 
UnetConv3, UnetUp3_CT, UnetGridGatingSignal3, UnetDsv3 4 | import torch.nn.functional as F 5 | from networks.networks_other import init_weights 6 | from networks.grid_attention_layer import GridAttentionBlock3D 7 | 8 | 9 | class Attention_UNet(nn.Module): 10 | 11 | def __init__(self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3, 12 | nonlocal_mode='concatenation', attention_dsample=(2,2,2), is_batchnorm=True): 13 | super(Attention_UNet, self).__init__() 14 | self.is_deconv = is_deconv 15 | self.in_channels = in_channels 16 | self.is_batchnorm = is_batchnorm 17 | self.feature_scale = feature_scale 18 | 19 | filters = [64, 128, 256, 512, 1024] 20 | filters = [int(x / self.feature_scale) for x in filters] 21 | 22 | # downsampling 23 | self.conv1 = UnetConv3(self.in_channels, filters[0], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1)) 24 | self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 25 | 26 | self.conv2 = UnetConv3(filters[0], filters[1], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1)) 27 | self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 28 | 29 | self.conv3 = UnetConv3(filters[1], filters[2], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1)) 30 | self.maxpool3 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 31 | 32 | self.conv4 = UnetConv3(filters[2], filters[3], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1)) 33 | self.maxpool4 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 34 | 35 | self.center = UnetConv3(filters[3], filters[4], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1)) 36 | self.gating = UnetGridGatingSignal3(filters[4], filters[4], kernel_size=(1, 1, 1), is_batchnorm=self.is_batchnorm) 37 | 38 | # attention blocks 39 | self.attentionblock2 = MultiAttentionBlock(in_size=filters[1], gate_size=filters[2], inter_size=filters[1], 40 | nonlocal_mode=nonlocal_mode, sub_sample_factor= attention_dsample) 41 | self.attentionblock3 = MultiAttentionBlock(in_size=filters[2], gate_size=filters[3], inter_size=filters[2], 42 | nonlocal_mode=nonlocal_mode, sub_sample_factor= attention_dsample) 43 | self.attentionblock4 = MultiAttentionBlock(in_size=filters[3], gate_size=filters[4], inter_size=filters[3], 44 | nonlocal_mode=nonlocal_mode, sub_sample_factor= attention_dsample) 45 | 46 | # upsampling 47 | self.up_concat4 = UnetUp3_CT(filters[4], filters[3], is_batchnorm) 48 | self.up_concat3 = UnetUp3_CT(filters[3], filters[2], is_batchnorm) 49 | self.up_concat2 = UnetUp3_CT(filters[2], filters[1], is_batchnorm) 50 | self.up_concat1 = UnetUp3_CT(filters[1], filters[0], is_batchnorm) 51 | 52 | # deep supervision 53 | self.dsv4 = UnetDsv3(in_size=filters[3], out_size=n_classes, scale_factor=8) 54 | self.dsv3 = UnetDsv3(in_size=filters[2], out_size=n_classes, scale_factor=4) 55 | self.dsv2 = UnetDsv3(in_size=filters[1], out_size=n_classes, scale_factor=2) 56 | self.dsv1 = nn.Conv3d(in_channels=filters[0], out_channels=n_classes, kernel_size=1) 57 | 58 | # final conv (without any concat) 59 | self.final = nn.Conv3d(n_classes*4, n_classes, 1) 60 | 61 | # initialise weights 62 | for m in self.modules(): 63 | if isinstance(m, nn.Conv3d): 64 | init_weights(m, init_type='kaiming') 65 | elif isinstance(m, nn.BatchNorm3d): 66 | init_weights(m, init_type='kaiming') 67 | 68 | def forward(self, inputs): 69 | # Feature Extraction 70 | conv1 = self.conv1(inputs) 71 | maxpool1 = self.maxpool1(conv1) 72 | 73 | conv2 = self.conv2(maxpool1) 74 | maxpool2 = self.maxpool2(conv2) 75 | 76 | conv3 = self.conv3(maxpool2) 77 | maxpool3 = 
self.maxpool3(conv3) 78 | 79 | conv4 = self.conv4(maxpool3) 80 | maxpool4 = self.maxpool4(conv4) 81 | 82 | # Gating Signal Generation 83 | center = self.center(maxpool4) 84 | gating = self.gating(center) 85 | 86 | # Attention Mechanism 87 | # Upscaling Part (Decoder) 88 | g_conv4, att4 = self.attentionblock4(conv4, gating) 89 | up4 = self.up_concat4(g_conv4, center) 90 | g_conv3, att3 = self.attentionblock3(conv3, up4) 91 | up3 = self.up_concat3(g_conv3, up4) 92 | g_conv2, att2 = self.attentionblock2(conv2, up3) 93 | up2 = self.up_concat2(g_conv2, up3) 94 | up1 = self.up_concat1(conv1, up2) 95 | 96 | # Deep Supervision 97 | dsv4 = self.dsv4(up4) 98 | dsv3 = self.dsv3(up3) 99 | dsv2 = self.dsv2(up2) 100 | dsv1 = self.dsv1(up1) 101 | final = self.final(torch.cat([dsv1,dsv2,dsv3,dsv4], dim=1)) 102 | 103 | return final 104 | 105 | 106 | @staticmethod 107 | def apply_argmax_softmax(pred): 108 | log_p = F.softmax(pred, dim=1) 109 | 110 | return log_p 111 | 112 | 113 | class MultiAttentionBlock(nn.Module): 114 | def __init__(self, in_size, gate_size, inter_size, nonlocal_mode, sub_sample_factor): 115 | super(MultiAttentionBlock, self).__init__() 116 | self.gate_block_1 = GridAttentionBlock3D(in_channels=in_size, gating_channels=gate_size, 117 | inter_channels=inter_size, mode=nonlocal_mode, 118 | sub_sample_factor= sub_sample_factor) 119 | self.gate_block_2 = GridAttentionBlock3D(in_channels=in_size, gating_channels=gate_size, 120 | inter_channels=inter_size, mode=nonlocal_mode, 121 | sub_sample_factor=sub_sample_factor) 122 | self.combine_gates = nn.Sequential(nn.Conv3d(in_size*2, in_size, kernel_size=1, stride=1, padding=0), 123 | nn.BatchNorm3d(in_size), 124 | nn.ReLU(inplace=True) 125 | ) 126 | 127 | # initialise the blocks 128 | for m in self.children(): 129 | if m.__class__.__name__.find('GridAttentionBlock3D') != -1: continue 130 | init_weights(m, init_type='kaiming') 131 | 132 | def forward(self, input, gating_signal): 133 | gate_1, attention_1 = self.gate_block_1(input, gating_signal) 134 | gate_2, attention_2 = self.gate_block_2(input, gating_signal) 135 | 136 | return self.combine_gates(torch.cat([gate_1, gate_2], 1)), torch.cat([attention_1, attention_2], 1) -------------------------------------------------------------------------------- /code/networks/config.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # --------------------------------------------------------' 7 | 8 | import os 9 | import yaml 10 | from yacs.config import CfgNode as CN 11 | 12 | _C = CN() 13 | 14 | # Base config files 15 | _C.BASE = [''] 16 | 17 | # ----------------------------------------------------------------------------- 18 | # Data settings 19 | # ----------------------------------------------------------------------------- 20 | _C.DATA = CN() 21 | # Batch size for a single GPU, could be overwritten by command line argument 22 | _C.DATA.BATCH_SIZE = 128 23 | # Path to dataset, could be overwritten by command line argument 24 | _C.DATA.DATA_PATH = '' 25 | # Dataset name 26 | _C.DATA.DATASET = 'imagenet' 27 | # Input image size 28 | _C.DATA.IMG_SIZE = 224 29 | # Interpolation to resize image (random, bilinear, bicubic) 30 | _C.DATA.INTERPOLATION = 'bicubic' 31 | # Use zipped dataset instead of folder dataset 32 | # could be overwritten by command line argument 33 | 
_C.DATA.ZIP_MODE = False 34 | # Cache Data in Memory, could be overwritten by command line argument 35 | _C.DATA.CACHE_MODE = 'part' 36 | # Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU. 37 | _C.DATA.PIN_MEMORY = True 38 | # Number of data loading threads 39 | _C.DATA.NUM_WORKERS = 8 40 | 41 | # ----------------------------------------------------------------------------- 42 | # Model settings 43 | # ----------------------------------------------------------------------------- 44 | _C.MODEL = CN() 45 | # Model type 46 | _C.MODEL.TYPE = 'swin' 47 | # Model name 48 | _C.MODEL.NAME = 'swin_tiny_patch4_window7_224' 49 | # Checkpoint to resume, could be overwritten by command line argument 50 | _C.MODEL.PRETRAIN_CKPT = './pretrained_ckpt/swin_tiny_patch4_window7_224.pth' 51 | _C.MODEL.RESUME = '' 52 | # Number of classes, overwritten in data preparation 53 | _C.MODEL.NUM_CLASSES = 1000 54 | # Dropout rate 55 | _C.MODEL.DROP_RATE = 0.0 56 | # Drop path rate 57 | _C.MODEL.DROP_PATH_RATE = 0.1 58 | # Label Smoothing 59 | _C.MODEL.LABEL_SMOOTHING = 0.1 60 | 61 | # Swin Transformer parameters 62 | _C.MODEL.SWIN = CN() 63 | _C.MODEL.SWIN.PATCH_SIZE = 4 64 | _C.MODEL.SWIN.IN_CHANS = 3 65 | _C.MODEL.SWIN.EMBED_DIM = 96 66 | _C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2] 67 | _C.MODEL.SWIN.DECODER_DEPTHS = [2, 2, 6, 2] 68 | _C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24] 69 | _C.MODEL.SWIN.WINDOW_SIZE = 7 70 | _C.MODEL.SWIN.MLP_RATIO = 4. 71 | _C.MODEL.SWIN.QKV_BIAS = True 72 | _C.MODEL.SWIN.QK_SCALE = None 73 | _C.MODEL.SWIN.APE = False 74 | _C.MODEL.SWIN.PATCH_NORM = True 75 | _C.MODEL.SWIN.FINAL_UPSAMPLE= "expand_first" 76 | 77 | # ----------------------------------------------------------------------------- 78 | # Training settings 79 | # ----------------------------------------------------------------------------- 80 | _C.TRAIN = CN() 81 | _C.TRAIN.START_EPOCH = 0 82 | _C.TRAIN.EPOCHS = 300 83 | _C.TRAIN.WARMUP_EPOCHS = 20 84 | _C.TRAIN.WEIGHT_DECAY = 0.05 85 | _C.TRAIN.BASE_LR = 5e-4 86 | _C.TRAIN.WARMUP_LR = 5e-7 87 | _C.TRAIN.MIN_LR = 5e-6 88 | # Clip gradient norm 89 | _C.TRAIN.CLIP_GRAD = 5.0 90 | # Auto resume from latest checkpoint 91 | _C.TRAIN.AUTO_RESUME = True 92 | # Gradient accumulation steps 93 | # could be overwritten by command line argument 94 | _C.TRAIN.ACCUMULATION_STEPS = 0 95 | # Whether to use gradient checkpointing to save memory 96 | # could be overwritten by command line argument 97 | _C.TRAIN.USE_CHECKPOINT = False 98 | 99 | # LR scheduler 100 | _C.TRAIN.LR_SCHEDULER = CN() 101 | _C.TRAIN.LR_SCHEDULER.NAME = 'cosine' 102 | # Epoch interval to decay LR, used in StepLRScheduler 103 | _C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30 104 | # LR decay rate, used in StepLRScheduler 105 | _C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1 106 | 107 | # Optimizer 108 | _C.TRAIN.OPTIMIZER = CN() 109 | _C.TRAIN.OPTIMIZER.NAME = 'adamw' 110 | # Optimizer Epsilon 111 | _C.TRAIN.OPTIMIZER.EPS = 1e-8 112 | # Optimizer Betas 113 | _C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999) 114 | # SGD momentum 115 | _C.TRAIN.OPTIMIZER.MOMENTUM = 0.9 116 | 117 | # ----------------------------------------------------------------------------- 118 | # Augmentation settings 119 | # ----------------------------------------------------------------------------- 120 | _C.AUG = CN() 121 | # Color jitter factor 122 | _C.AUG.COLOR_JITTER = 0.4 123 | # Use AutoAugment policy. 
"v0" or "original" 124 | _C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1' 125 | # Random erase prob 126 | _C.AUG.REPROB = 0.25 127 | # Random erase mode 128 | _C.AUG.REMODE = 'pixel' 129 | # Random erase count 130 | _C.AUG.RECOUNT = 1 131 | # Mixup alpha, mixup enabled if > 0 132 | _C.AUG.MIXUP = 0.8 133 | # Cutmix alpha, cutmix enabled if > 0 134 | _C.AUG.CUTMIX = 1.0 135 | # Cutmix min/max ratio, overrides alpha and enables cutmix if set 136 | _C.AUG.CUTMIX_MINMAX = None 137 | # Probability of performing mixup or cutmix when either/both is enabled 138 | _C.AUG.MIXUP_PROB = 1.0 139 | # Probability of switching to cutmix when both mixup and cutmix enabled 140 | _C.AUG.MIXUP_SWITCH_PROB = 0.5 141 | # How to apply mixup/cutmix params. Per "batch", "pair", or "elem" 142 | _C.AUG.MIXUP_MODE = 'batch' 143 | 144 | # ----------------------------------------------------------------------------- 145 | # Testing settings 146 | # ----------------------------------------------------------------------------- 147 | _C.TEST = CN() 148 | # Whether to use center crop when testing 149 | _C.TEST.CROP = True 150 | 151 | # ----------------------------------------------------------------------------- 152 | # Misc 153 | # ----------------------------------------------------------------------------- 154 | # Mixed precision opt level, if O0, no amp is used ('O0', 'O1', 'O2') 155 | # overwritten by command line argument 156 | _C.AMP_OPT_LEVEL = '' 157 | # Path to output folder, overwritten by command line argument 158 | _C.OUTPUT = '' 159 | # Tag of experiment, overwritten by command line argument 160 | _C.TAG = 'default' 161 | # Frequency to save checkpoint 162 | _C.SAVE_FREQ = 1 163 | # Frequency to logging info 164 | _C.PRINT_FREQ = 10 165 | # Fixed random seed 166 | _C.SEED = 0 167 | # Perform evaluation only, overwritten by command line argument 168 | _C.EVAL_MODE = False 169 | # Test throughput only, overwritten by command line argument 170 | _C.THROUGHPUT_MODE = False 171 | # local rank for DistributedDataParallel, given by command line argument 172 | _C.LOCAL_RANK = 0 173 | 174 | 175 | def _update_config_from_file(config, cfg_file): 176 | config.defrost() 177 | with open(cfg_file, 'r') as f: 178 | yaml_cfg = yaml.load(f, Loader=yaml.FullLoader) 179 | 180 | for cfg in yaml_cfg.setdefault('BASE', ['']): 181 | if cfg: 182 | _update_config_from_file( 183 | config, os.path.join(os.path.dirname(cfg_file), cfg) 184 | ) 185 | print('=> merge config from {}'.format(cfg_file)) 186 | config.merge_from_file(cfg_file) 187 | config.freeze() 188 | 189 | 190 | def update_config(config, args): 191 | _update_config_from_file(config, args.cfg) 192 | 193 | config.defrost() 194 | if args.opts: 195 | config.merge_from_list(args.opts) 196 | 197 | # merge from specific arguments 198 | if args.batch_size: 199 | config.DATA.BATCH_SIZE = args.batch_size 200 | if args.zip: 201 | config.DATA.ZIP_MODE = True 202 | if args.cache_mode: 203 | config.DATA.CACHE_MODE = args.cache_mode 204 | if args.resume: 205 | config.MODEL.RESUME = args.resume 206 | if args.accumulation_steps: 207 | config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps 208 | if args.use_checkpoint: 209 | config.TRAIN.USE_CHECKPOINT = True 210 | if args.amp_opt_level: 211 | config.AMP_OPT_LEVEL = args.amp_opt_level 212 | if args.tag: 213 | config.TAG = args.tag 214 | if args.eval: 215 | config.EVAL_MODE = True 216 | if args.throughput: 217 | config.THROUGHPUT_MODE = True 218 | 219 | config.freeze() 220 | 221 | 222 | def get_config(args): 223 | """Get a yacs CfgNode object 
with default values.""" 224 | # Return a clone so that the defaults will not be altered 225 | # This is for the "local variable" use pattern 226 | config = _C.clone() 227 | update_config(config, args) 228 | 229 | return config 230 | -------------------------------------------------------------------------------- /code/networks/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FC3DDiscriminator(nn.Module): 7 | 8 | def __init__(self, num_classes, ndf=64, n_channel=1): 9 | super(FC3DDiscriminator, self).__init__() 10 | # downsample 16 11 | self.conv0 = nn.Conv3d( 12 | num_classes, ndf, kernel_size=4, stride=2, padding=1) 13 | self.conv1 = nn.Conv3d( 14 | n_channel, ndf, kernel_size=4, stride=2, padding=1) 15 | 16 | self.conv2 = nn.Conv3d(ndf, ndf*2, kernel_size=4, stride=2, padding=1) 17 | self.conv3 = nn.Conv3d( 18 | ndf*2, ndf*4, kernel_size=4, stride=2, padding=1) 19 | self.conv4 = nn.Conv3d( 20 | ndf*4, ndf*8, kernel_size=4, stride=2, padding=1) 21 | self.avgpool = nn.AvgPool3d((6, 6, 6)) # (D/16, W/16, H/16) 22 | self.classifier = nn.Linear(ndf*8, 2) 23 | 24 | self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 25 | self.dropout = nn.Dropout3d(0.5) 26 | self.Softmax = nn.Softmax() 27 | 28 | def forward(self, map, image): 29 | batch_size = map.shape[0] 30 | map_feature = self.conv0(map) 31 | image_feature = self.conv1(image) 32 | x = torch.add(map_feature, image_feature) 33 | x = self.leaky_relu(x) 34 | x = self.dropout(x) 35 | 36 | x = self.conv2(x) 37 | x = self.leaky_relu(x) 38 | x = self.dropout(x) 39 | 40 | x = self.conv3(x) 41 | x = self.leaky_relu(x) 42 | x = self.dropout(x) 43 | 44 | x = self.conv4(x) 45 | x = self.leaky_relu(x) 46 | 47 | x = self.avgpool(x) 48 | 49 | x = x.view(batch_size, -1) 50 | 51 | x = self.classifier(x) 52 | x = x.reshape((batch_size, 2)) 53 | # x = self.Softmax(x) 54 | 55 | return x 56 | 57 | 58 | class FCDiscriminator(nn.Module): 59 | 60 | def __init__(self, num_classes, ndf=64, n_channel=1): 61 | super(FCDiscriminator, self).__init__() 62 | self.conv0 = nn.Conv2d( 63 | num_classes, ndf, kernel_size=4, stride=2, padding=1) 64 | self.conv1 = nn.Conv2d( 65 | n_channel, ndf, kernel_size=4, stride=2, padding=1) 66 | self.conv2 = nn.Conv2d(ndf, ndf*2, kernel_size=4, stride=2, padding=1) 67 | self.conv3 = nn.Conv2d( 68 | ndf*2, ndf*4, kernel_size=4, stride=2, padding=1) 69 | self.conv4 = nn.Conv2d( 70 | ndf*4, ndf*8, kernel_size=4, stride=2, padding=1) 71 | self.classifier = nn.Linear(ndf*32, 2) 72 | self.avgpool = nn.AvgPool2d((7, 7)) 73 | 74 | self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 75 | self.dropout = nn.Dropout2d(0.5) 76 | # self.up_sample = nn.Upsample(scale_factor=32, mode='bilinear') 77 | # self.sigmoid = nn.Sigmoid() 78 | 79 | def forward(self, map, feature): 80 | map_feature = self.conv0(map) 81 | image_feature = self.conv1(feature) 82 | 83 | x = torch.add(map_feature, image_feature) 84 | 85 | x = self.conv2(x) 86 | x = self.leaky_relu(x) 87 | x = self.dropout(x) 88 | 89 | x = self.conv3(x) 90 | x = self.leaky_relu(x) 91 | x = self.dropout(x) 92 | 93 | x = self.conv4(x) 94 | x = self.leaky_relu(x) 95 | x = self.avgpool(x) 96 | x = x.view(x.size(0), -1) 97 | x = self.classifier(x) 98 | # x = self.up_sample(x) 99 | # x = self.sigmoid(x) 100 | 101 | return x 102 | -------------------------------------------------------------------------------- 
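
A hypothetical driver for the yacs helpers in config.py above (hedged: the import path mirrors `from networks.config import get_config` as used later in net_factory.py, and `update_config()` reads every attribute below unconditionally, so the namespace must define them all):

    from types import SimpleNamespace
    from networks.config import get_config

    args = SimpleNamespace(
        cfg='../code/configs/swin_tiny_patch4_window7_224_lite.yaml',
        opts=None, batch_size=None, zip=False, cache_mode=None,
        resume=None, accumulation_steps=None, use_checkpoint=False,
        amp_opt_level='', tag=None, eval=False, throughput=False,
    )
    config = get_config(args)   # frozen clone: defaults -> YAML -> CLI overrides
    print(config.MODEL.TYPE, config.MODEL.SWIN.EMBED_DIM)
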
/code/networks/efficientunet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from networks.attention import * 6 | from networks.efficient_encoder import get_encoder 7 | 8 | 9 | def initialize_decoder(module): 10 | for m in module.modules(): 11 | 12 | if isinstance(m, nn.Conv2d): 13 | nn.init.kaiming_uniform_(m.weight, mode="fan_in", nonlinearity="relu") 14 | if m.bias is not None: 15 | nn.init.constant_(m.bias, 0) 16 | 17 | elif isinstance(m, nn.BatchNorm2d): 18 | nn.init.constant_(m.weight, 1) 19 | nn.init.constant_(m.bias, 0) 20 | 21 | elif isinstance(m, nn.Linear): 22 | nn.init.xavier_uniform_(m.weight) 23 | if m.bias is not None: 24 | nn.init.constant_(m.bias, 0) 25 | 26 | 27 | class DecoderBlock(nn.Module): 28 | def __init__( 29 | self, 30 | in_channels, 31 | skip_channels, 32 | out_channels, 33 | use_batchnorm=True, 34 | attention_type=None, 35 | ): 36 | super().__init__() 37 | self.conv1 = Conv2dReLU( 38 | in_channels + skip_channels, 39 | out_channels, 40 | kernel_size=3, 41 | padding=1, 42 | use_batchnorm=use_batchnorm, 43 | ) 44 | self.attention1 = Attention(attention_type, in_channels=in_channels + skip_channels) 45 | self.conv2 = Conv2dReLU( 46 | out_channels, 47 | out_channels, 48 | kernel_size=3, 49 | padding=1, 50 | use_batchnorm=use_batchnorm, 51 | ) 52 | self.attention2 = Attention(attention_type, in_channels=out_channels) 53 | 54 | def forward(self, x, skip=None): 55 | x = F.interpolate(x, scale_factor=2, mode="nearest") 56 | if skip is not None: 57 | x = torch.cat([x, skip], dim=1) 58 | x = self.attention1(x) 59 | x = self.conv1(x) 60 | x = self.conv2(x) 61 | x = self.attention2(x) 62 | return x 63 | 64 | 65 | class CenterBlock(nn.Sequential): 66 | def __init__(self, in_channels, out_channels, use_batchnorm=True): 67 | conv1 = Conv2dReLU( 68 | in_channels, 69 | out_channels, 70 | kernel_size=3, 71 | padding=1, 72 | use_batchnorm=use_batchnorm, 73 | ) 74 | conv2 = Conv2dReLU( 75 | out_channels, 76 | out_channels, 77 | kernel_size=3, 78 | padding=1, 79 | use_batchnorm=use_batchnorm, 80 | ) 81 | super().__init__(conv1, conv2) 82 | 83 | 84 | class UnetDecoder(nn.Module): 85 | def __init__( 86 | self, 87 | encoder_channels, 88 | decoder_channels, 89 | n_blocks=5, 90 | use_batchnorm=True, 91 | attention_type=None, 92 | center=False, 93 | ): 94 | super().__init__() 95 | 96 | if n_blocks != len(decoder_channels): 97 | raise ValueError( 98 | "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format( 99 | n_blocks, len(decoder_channels) 100 | ) 101 | ) 102 | 103 | encoder_channels = encoder_channels[1:] # remove first skip with same spatial resolution 104 | encoder_channels = encoder_channels[::-1] # reverse channels to start from head of encoder 105 | 106 | # computing blocks input and output channels 107 | head_channels = encoder_channels[0] 108 | in_channels = [head_channels] + list(decoder_channels[:-1]) 109 | skip_channels = list(encoder_channels[1:]) + [0] 110 | out_channels = decoder_channels 111 | 112 | if center: 113 | self.center = CenterBlock( 114 | head_channels, head_channels, use_batchnorm=use_batchnorm 115 | ) 116 | else: 117 | self.center = nn.Identity() 118 | 119 | # combine decoder keyword arguments 120 | kwargs = dict(use_batchnorm=use_batchnorm, attention_type=attention_type) 121 | blocks = [ 122 | DecoderBlock(in_ch, skip_ch, out_ch, **kwargs) 123 | for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, 
out_channels) 124 | ] 125 | self.blocks = nn.ModuleList(blocks) 126 | 127 | def forward(self, *features): 128 | 129 | features = features[1:] # remove first skip with same spatial resolution 130 | features = features[::-1] # reverse channels to start from head of encoder 131 | 132 | head = features[0] 133 | skips = features[1:] 134 | 135 | x = self.center(head) 136 | for i, decoder_block in enumerate(self.blocks): 137 | skip = skips[i] if i < len(skips) else None 138 | x = decoder_block(x, skip) 139 | 140 | return x 141 | 142 | 143 | class Effi_UNet(nn.Module): 144 | """Unet_ is a fully convolution neural network for image semantic segmentation 145 | 146 | Args: 147 | encoder_name: name of classification model (without last dense layers) used as feature 148 | extractor to build segmentation model. 149 | encoder_depth (int): number of stages used in decoder, larger depth - more features are generated. 150 | e.g. for depth=3 encoder will generate list of features with following spatial shapes 151 | [(H,W), (H/2, W/2), (H/4, W/4), (H/8, W/8)], so in general the deepest feature tensor will have 152 | spatial resolution (H/(2^depth), W/(2^depth)] 153 | encoder_weights: one of ``None`` (random initialization), ``imagenet`` (pre-training on ImageNet). 154 | decoder_channels: list of numbers of ``Conv2D`` layer filters in decoder blocks 155 | decoder_use_batchnorm: if ``True``, ``BatchNormalisation`` layer between ``Conv2D`` and ``Activation`` layers 156 | is used. If 'inplace' InplaceABN will be used, allows to decrease memory consumption. 157 | One of [True, False, 'inplace'] 158 | decoder_attention_type: attention module used in decoder of the model 159 | One of [``None``, ``scse``] 160 | in_channels: number of input channels for model, default is 3. 161 | classes: a number of classes for output (output shape - ``(batch, classes, h, w)``). 162 | activation: activation function to apply after final convolution; 163 | One of [``sigmoid``, ``softmax``, ``logsoftmax``, ``identity``, callable, None] 164 | aux_params: if specified model will have additional classification auxiliary output 165 | build on top of encoder, supported params: 166 | - classes (int): number of classes 167 | - pooling (str): one of 'max', 'avg'. Default is 'avg'. 168 | - dropout (float): dropout factor in [0, 1) 169 | - activation (str): activation function to apply "sigmoid"/"softmax" (could be None to return logits) 170 | 171 | Returns: 172 | ``torch.nn.Module``: **Unet** 173 | 174 | .. 
_Unet:
175 |         https://arxiv.org/pdf/1505.04597
176 | 
177 |     """
178 | 
179 |     def __init__(
180 |             self,
181 |             encoder_name: str = "resnet34",
182 |             encoder_depth: int = 5,
183 |             encoder_weights: str = "imagenet",
184 |             decoder_use_batchnorm=True,
185 |             decoder_channels=(256, 128, 64, 32, 16),
186 |             decoder_attention_type=None,
187 |             in_channels: int = 3,
188 |             classes: int = 1):
189 |         super().__init__()
190 | 
191 |         self.encoder = get_encoder(
192 |             encoder_name,
193 |             in_channels=in_channels,
194 |             depth=encoder_depth,
195 |             weights=encoder_weights,
196 |         )
197 | 
198 |         self.decoder = UnetDecoder(
199 |             encoder_channels=self.encoder.out_channels,
200 |             decoder_channels=decoder_channels,
201 |             n_blocks=encoder_depth,
202 |             use_batchnorm=decoder_use_batchnorm,
203 |             center=True if encoder_name.startswith("vgg") else False,
204 |             attention_type=decoder_attention_type,
205 |         )
206 |         initialize_decoder(self.decoder)
207 |         self.classifier = nn.Conv2d(decoder_channels[-1], classes, 1)
208 | 
209 |     def forward(self, x):
210 |         """Sequentially pass `x` through the model's encoder, decoder and segmentation head"""
211 |         features = self.encoder(x)
212 |         decoder_output = self.decoder(*features)
213 |         output = self.classifier(decoder_output)
214 | 
215 |         return output
216 | 
217 | 
218 | # unet = Effi_UNet('efficientnet-b3', encoder_weights='imagenet', in_channels=1, classes=1, decoder_attention_type="scse")
219 | # t = torch.rand(2, 1, 224, 224)
220 | # print(unet)
221 | # print(unet(t).shape)
222 | 
--------------------------------------------------------------------------------
/code/networks/encoder_tool.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | 
3 | import torch
4 | import torch.nn as nn
5 | import torch.utils.model_zoo as model_zoo
6 | from efficientnet_pytorch import EfficientNet
7 | from efficientnet_pytorch.utils import get_model_params, url_map
8 | 
9 | 
10 | class EncoderMixin:
11 |     """Add encoder functionality such as:
12 |     - output channels specification of feature tensors (produced by encoder)
13 |     - patching first convolution for arbitrary input channels
14 |     """
15 | 
16 |     @property
17 |     def out_channels(self) -> List:
18 |         """Return the channel dimension of each tensor in the encoder's forward output"""
19 |         return self._out_channels[: self._depth + 1]
20 | 
21 |     def set_in_channels(self, in_channels):
22 |         """Change the number of input channels of the first convolution"""
23 |         if in_channels == 3:
24 |             return
25 | 
26 |         self._in_channels = in_channels
27 |         if self._out_channels[0] == 3:
28 |             self._out_channels = tuple([in_channels] + list(self._out_channels)[1:])
29 | 
30 |         patch_first_conv(model=self, in_channels=in_channels)
31 | 
32 | 
33 | def patch_first_conv(model, in_channels):
34 |     """Change first convolution layer input channels.
35 | In case: 36 | in_channels == 1 or in_channels == 2 -> reuse original weights 37 | in_channels > 3 -> make random kaiming normal initialization 38 | """ 39 | 40 | # get first conv 41 | for module in model.modules(): 42 | if isinstance(module, nn.Conv2d): 43 | break 44 | 45 | # change input channels for first conv 46 | module.in_channels = in_channels 47 | weight = module.weight.detach() 48 | reset = False 49 | 50 | if in_channels == 1: 51 | weight = weight.sum(1, keepdim=True) 52 | elif in_channels == 2: 53 | weight = weight[:, :2] * (3.0 / 2.0) 54 | else: 55 | reset = True 56 | weight = torch.Tensor( 57 | module.out_channels, 58 | module.in_channels // module.groups, 59 | *module.kernel_size 60 | ) 61 | 62 | module.weight = nn.parameter.Parameter(weight) 63 | if reset: 64 | module.reset_parameters() 65 | 66 | 67 | class EfficientNetEncoder(EfficientNet, EncoderMixin): 68 | def __init__(self, stage_idxs, out_channels, model_name, depth=5): 69 | 70 | blocks_args, global_params = get_model_params(model_name, override_params=None) 71 | super().__init__(blocks_args, global_params) 72 | 73 | self._stage_idxs = list(stage_idxs) + [len(self._blocks)] 74 | self._out_channels = out_channels 75 | self._depth = depth 76 | self._in_channels = 3 77 | 78 | del self._fc 79 | 80 | def forward(self, x): 81 | 82 | features = [x] 83 | 84 | if self._depth > 0: 85 | x = self._swish(self._bn0(self._conv_stem(x))) 86 | features.append(x) 87 | 88 | if self._depth > 1: 89 | skip_connection_idx = 0 90 | for idx, block in enumerate(self._blocks): 91 | drop_connect_rate = self._global_params.drop_connect_rate 92 | if drop_connect_rate: 93 | drop_connect_rate *= float(idx) / len(self._blocks) 94 | x = block(x, drop_connect_rate=drop_connect_rate) 95 | if idx == self._stage_idxs[skip_connection_idx] - 1: 96 | skip_connection_idx += 1 97 | features.append(x) 98 | if skip_connection_idx + 1 == self._depth: 99 | break 100 | return features 101 | 102 | def load_state_dict(self, state_dict, **kwargs): 103 | state_dict.pop("_fc.bias") 104 | state_dict.pop("_fc.weight") 105 | super().load_state_dict(state_dict, **kwargs) 106 | 107 | 108 | def _get_pretrained_settings(encoder): 109 | pretrained_settings = { 110 | "imagenet": { 111 | "mean": [0.485, 0.456, 0.406], 112 | "std": [0.229, 0.224, 0.225], 113 | "url": url_map[encoder], 114 | "input_space": "RGB", 115 | "input_range": [0, 1], 116 | } 117 | } 118 | return pretrained_settings 119 | 120 | 121 | efficient_net_encoders = { 122 | "efficientnet-b0": { 123 | "encoder": EfficientNetEncoder, 124 | "pretrained_settings": _get_pretrained_settings("efficientnet-b0"), 125 | "params": { 126 | "out_channels": (3, 32, 24, 40, 112, 320), 127 | "stage_idxs": (3, 5, 9), 128 | "model_name": "efficientnet-b0", 129 | }, 130 | }, 131 | "efficientnet-b1": { 132 | "encoder": EfficientNetEncoder, 133 | "pretrained_settings": _get_pretrained_settings("efficientnet-b1"), 134 | "params": { 135 | "out_channels": (3, 32, 24, 40, 112, 320), 136 | "stage_idxs": (5, 8, 16), 137 | "model_name": "efficientnet-b1", 138 | }, 139 | }, 140 | "efficientnet-b2": { 141 | "encoder": EfficientNetEncoder, 142 | "pretrained_settings": _get_pretrained_settings("efficientnet-b2"), 143 | "params": { 144 | "out_channels": (3, 32, 24, 48, 120, 352), 145 | "stage_idxs": (5, 8, 16), 146 | "model_name": "efficientnet-b2", 147 | }, 148 | }, 149 | "efficientnet-b3": { 150 | "encoder": EfficientNetEncoder, 151 | "pretrained_settings": _get_pretrained_settings("efficientnet-b3"), 152 | "params": { 153 | "out_channels": (3, 
40, 32, 48, 136, 384), 154 | "stage_idxs": (5, 8, 18), 155 | "model_name": "efficientnet-b3", 156 | }, 157 | }, 158 | "efficientnet-b4": { 159 | "encoder": EfficientNetEncoder, 160 | "pretrained_settings": _get_pretrained_settings("efficientnet-b4"), 161 | "params": { 162 | "out_channels": (3, 48, 32, 56, 160, 448), 163 | "stage_idxs": (6, 10, 22), 164 | "model_name": "efficientnet-b4", 165 | }, 166 | }, 167 | "efficientnet-b5": { 168 | "encoder": EfficientNetEncoder, 169 | "pretrained_settings": _get_pretrained_settings("efficientnet-b5"), 170 | "params": { 171 | "out_channels": (3, 48, 40, 64, 176, 512), 172 | "stage_idxs": (8, 13, 27), 173 | "model_name": "efficientnet-b5", 174 | }, 175 | }, 176 | "efficientnet-b6": { 177 | "encoder": EfficientNetEncoder, 178 | "pretrained_settings": _get_pretrained_settings("efficientnet-b6"), 179 | "params": { 180 | "out_channels": (3, 56, 40, 72, 200, 576), 181 | "stage_idxs": (9, 15, 31), 182 | "model_name": "efficientnet-b6", 183 | }, 184 | }, 185 | "efficientnet-b7": { 186 | "encoder": EfficientNetEncoder, 187 | "pretrained_settings": _get_pretrained_settings("efficientnet-b7"), 188 | "params": { 189 | "out_channels": (3, 64, 48, 80, 224, 640), 190 | "stage_idxs": (11, 18, 38), 191 | "model_name": "efficientnet-b7", 192 | }, 193 | }, 194 | } 195 | 196 | encoders = {} 197 | encoders.update(efficient_net_encoders) 198 | 199 | 200 | def get_encoder(name, in_channels=3, depth=5, weights=None): 201 | Encoder = encoders[name]["encoder"] 202 | params = encoders[name]["params"] 203 | params.update(depth=depth) 204 | encoder = Encoder(**params) 205 | 206 | if weights is not None: 207 | settings = encoders[name]["pretrained_settings"][weights] 208 | encoder.load_state_dict(model_zoo.load_url(settings["url"])) 209 | 210 | encoder.set_in_channels(in_channels) 211 | 212 | return encoder 213 | -------------------------------------------------------------------------------- /code/networks/net_factory.py: -------------------------------------------------------------------------------- 1 | from networks.efficientunet import Effi_UNet 2 | from networks.enet import ENet 3 | from networks.pnet import PNet2D 4 | from networks.unet import UNet, UNet_DS, UNet_URPC, UNet_CCT 5 | import argparse 6 | from networks.vision_transformer import SwinUnet as ViT_seg 7 | from networks.config import get_config 8 | from networks.nnunet import initialize_network 9 | 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--root_path', type=str, 13 | default='../data/ACDC', help='Name of Experiment') 14 | parser.add_argument('--exp', type=str, 15 | default='ACDC/Cross_Supervision_CNN_Trans2D', help='experiment_name') 16 | parser.add_argument('--fold', type=str, 17 | default='fold1', help='cross validation') 18 | parser.add_argument('--sup_type', type=str, 19 | default='scribble', help='supervision type') 20 | 21 | parser.add_argument('--model', type=str, 22 | default='unet', help='model_name') 23 | parser.add_argument('--max_iterations', type=int, 24 | default=30000, help='maximum epoch number to train') 25 | parser.add_argument('--batch_size', type=int, default=8, 26 | help='batch_size per gpu') 27 | parser.add_argument('--deterministic', type=int, default=1, 28 | help='whether use deterministic training') 29 | parser.add_argument('--base_lr', type=float, default=0.01, 30 | help='segmentation network learning rate') 31 | parser.add_argument('--patch_size', type=list, default=[224, 224], 32 | help='patch size of network input') 33 | parser.add_argument('--seed', 
type=int, default=1337, help='random seed')
34 | parser.add_argument('--num_classes', type=int, default=4,
35 |                     help='number of output channels/classes of the network')
36 | parser.add_argument(
37 |     '--cfg', type=str, default="../code/configs/swin_tiny_patch4_window7_224_lite.yaml", help='path to config file', )
38 | parser.add_argument(
39 |     "--opts",
40 |     help="Modify config options by adding 'KEY VALUE' pairs. ",
41 |     default=None,
42 |     nargs='+',
43 | )
44 | parser.add_argument('--zip', action='store_true',
45 |                     help='use zipped dataset instead of folder dataset')
46 | parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'],
47 |                     help='no: no cache, '
48 |                     'full: cache all data, '
49 |                     'part: shard the dataset into non-overlapping pieces and only cache one piece')
50 | parser.add_argument('--resume', help='resume from checkpoint')
51 | parser.add_argument('--accumulation-steps', type=int,
52 |                     help="gradient accumulation steps")
53 | parser.add_argument('--use-checkpoint', action='store_true',
54 |                     help="whether to use gradient checkpointing to save memory")
55 | parser.add_argument('--amp-opt-level', type=str, default='O1', choices=['O0', 'O1', 'O2'],
56 |                     help='mixed precision opt level; if O0, no amp is used')
57 | parser.add_argument('--tag', help='tag of experiment')
58 | parser.add_argument('--eval', action='store_true',
59 |                     help='Perform evaluation only')
60 | parser.add_argument('--throughput', action='store_true',
61 |                     help='Test throughput only')
62 | 
63 | # costs
64 | parser.add_argument('--ema_decay', type=float, default=0.99, help='ema_decay')
65 | parser.add_argument('--consistency_type', type=str,
66 |                     default="mse", help='consistency_type')
67 | parser.add_argument('--consistency', type=float,
68 |                     default=0.1, help='consistency')
69 | parser.add_argument('--consistency_rampup', type=float,
70 |                     default=200.0, help='consistency_rampup')
71 | args = parser.parse_args()
72 | config = get_config(args)
73 | 
74 | 
75 | def net_factory(net_type="unet", in_chns=1, class_num=3):
76 |     if net_type == "unet":
77 |         net = UNet(in_chns=in_chns, class_num=class_num).cuda()
78 |     elif net_type == "enet":
79 |         net = ENet(in_channels=in_chns, num_classes=class_num).cuda()
80 |     elif net_type == "vnet":
81 |         # NOTE: there is no 2D VNet in this codebase, so "vnet" falls back to ENet
82 |         net = ENet(in_channels=in_chns, num_classes=class_num).cuda()
83 |     elif net_type == "unet_ds":
84 |         net = UNet_DS(in_chns=in_chns, class_num=class_num).cuda()
85 |     elif net_type == "unet_cct":
86 |         net = UNet_CCT(in_chns=in_chns, class_num=class_num).cuda()
87 |     elif net_type == "unet_urpc":
88 |         net = UNet_URPC(in_chns=in_chns, class_num=class_num).cuda()
89 |     elif net_type == "efficient_unet":
90 |         net = Effi_UNet('efficientnet-b3', encoder_weights='imagenet',
91 |                         in_channels=in_chns, classes=class_num).cuda()
92 |     elif net_type == "ViT_Seg":
93 |         net = ViT_seg(config, img_size=args.patch_size,
94 |                       num_classes=args.num_classes).cuda()
95 |     elif net_type == "pnet":
96 |         net = PNet2D(in_chns, class_num, 64, [1, 2, 4, 8, 16]).cuda()
97 |     elif net_type == "nnUNet":
98 |         net = initialize_network(num_classes=class_num).cuda()
99 |     else:
100 |         net = None
101 |     return net
--------------------------------------------------------------------------------
/code/networks/net_factory_3d.py:
--------------------------------------------------------------------------------
1 | from networks.unet_3D import unet_3D
2 | from networks.vnet import VNet
3 | from networks.VoxResNet import VoxResNet
4 | from networks.attention_unet import Attention_UNet
5 | 
6 | 
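A hypothetical smoke test for the two factory helpers, `net_factory` from net_factory.py above and `net_factory_3d` defined just below (hedged: the dummy tensor shapes are illustrative, a CUDA device is required because every branch calls `.cuda()`, and note that importing net_factory parses command-line arguments at import time because of its module-level `parser.parse_args()`):

    import torch
    from networks.net_factory import net_factory
    from networks.net_factory_3d import net_factory_3d

    net2d = net_factory(net_type="unet", in_chns=1, class_num=4)
    x2d = torch.rand(2, 1, 224, 224).cuda()      # (batch, channel, H, W)
    print(net2d(x2d).shape)                      # expected: torch.Size([2, 4, 224, 224])

    net3d = net_factory_3d(net_type="unet_3D", in_chns=1, class_num=2)
    x3d = torch.rand(1, 1, 96, 96, 96).cuda()    # (batch, channel, D, H, W)
    print(net3d(x3d).shape)                      # expected: torch.Size([1, 2, 96, 96, 96])
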
7 | def net_factory_3d(net_type="unet_3D", in_chns=1, class_num=2):
8 |     if net_type == "unet_3D":
9 |         net = unet_3D(n_classes=class_num, in_channels=in_chns).cuda()
10 |     elif net_type == "attention_unet":
11 |         net = Attention_UNet(n_classes=class_num, in_channels=in_chns).cuda()
12 |     elif net_type == "voxresnet":
13 |         net = VoxResNet(in_chns=in_chns, feature_chns=64,
14 |                         class_num=class_num).cuda()
15 |     elif net_type == "vnet":
16 |         net = VNet(n_channels=in_chns, n_classes=class_num,
17 |                    normalization='batchnorm', has_dropout=True).cuda()
18 |     else:
19 |         net = None
20 |     return net
21 | 
--------------------------------------------------------------------------------
/code/networks/pnet.py:
--------------------------------------------------------------------------------
1 | 
2 | # -*- coding: utf-8 -*-
3 | """
4 | A PyTorch implementation of the DeepIGeoS paper:
5 | Wang, Guotai and Zuluaga, Maria A and Li, Wenqi and Pratt, Rosalind and Patel, Premal A and Aertsen, Michael and Doel, Tom and David, Anna L and Deprest, Jan and Ourselin, Sébastien and others:
6 | DeepIGeoS: a deep interactive geodesic framework for medical image segmentation.
7 | TPAMI (7) 2018: 1559--1572
8 | Note that there are some modifications from the original paper, such as
9 | the use of leaky relu here.
10 | """
11 | from __future__ import division, print_function
12 | 
13 | import torch
14 | import torch.nn as nn
15 | 
16 | 
17 | class PNetBlock(nn.Module):
18 |     def __init__(self, in_channels, out_channels, dilation, padding):
19 |         super(PNetBlock, self).__init__()
20 | 
21 |         self.in_chns = in_channels
22 |         self.out_chns = out_channels
23 |         self.dilation = dilation
24 |         self.padding = padding
25 | 
26 |         self.conv1 = nn.Conv2d(self.in_chns, self.out_chns, kernel_size=3,
27 |                                padding=self.padding, dilation=self.dilation, groups=1, bias=True)
28 |         self.conv2 = nn.Conv2d(self.out_chns, self.out_chns, kernel_size=3,
29 |                                padding=self.padding, dilation=self.dilation, groups=1, bias=True)
30 |         self.in1 = nn.BatchNorm2d(self.out_chns)
31 |         self.in2 = nn.BatchNorm2d(self.out_chns)
32 |         self.ac1 = nn.LeakyReLU()
33 |         self.ac2 = nn.LeakyReLU()
34 | 
35 |     def forward(self, x):
36 |         x = self.conv1(x)
37 |         x = self.in1(x)
38 |         x = self.ac1(x)
39 |         x = self.conv2(x)
40 |         x = self.in2(x)
41 |         x = self.ac2(x)
42 |         return x
43 | 
44 | 
45 | class ConcatBlock(nn.Module):
46 |     def __init__(self, in_channels, out_channels):
47 |         super(ConcatBlock, self).__init__()
48 |         self.in_chns = in_channels
49 |         self.out_chns = out_channels
50 |         self.conv1 = nn.Conv2d(
51 |             self.in_chns, self.in_chns, kernel_size=1, padding=0)
52 |         self.conv2 = nn.Conv2d(
53 |             self.in_chns, self.out_chns, kernel_size=1, padding=0)
54 |         self.ac1 = nn.LeakyReLU()
55 |         self.ac2 = nn.LeakyReLU()
56 | 
57 |     def forward(self, x):
58 |         x = self.conv1(x)
59 |         x = self.ac1(x)
60 |         x = self.conv2(x)
61 |         x = self.ac2(x)
62 |         return x
63 | 
64 | 
65 | class OutPutBlock(nn.Module):
66 |     def __init__(self, in_channels, out_channels):
67 |         super(OutPutBlock, self).__init__()
68 |         self.in_chns = in_channels
69 |         self.out_chns = out_channels
70 |         self.conv1 = nn.Conv2d(
71 |             self.in_chns, self.in_chns // 2, kernel_size=1, padding=0)
72 |         self.conv2 = nn.Conv2d(
73 |             self.in_chns // 2, self.out_chns, kernel_size=1, padding=0)
74 |         self.drop1 = nn.Dropout2d(0.3)
75 |         self.drop2 = nn.Dropout2d(0.3)
76 |         self.ac1 = nn.LeakyReLU()
77 | 
78 |     def forward(self, x):
79 |         x = self.drop1(x)
80 |         x = self.conv1(x)
81 |         x = self.ac1(x)
82 |         x = self.drop2(x)
83 |         x = self.conv2(x)
84 |         return x
85 | 
86 | 
87
| class PNet2D(nn.Module): 88 | def __init__(self, in_chns, out_chns, num_filters, ratios): 89 | super(PNet2D, self).__init__() 90 | 91 | self.in_chns = in_chns 92 | self.out_chns = out_chns 93 | self.ratios = ratios 94 | self.num_filters = num_filters 95 | 96 | self.block1 = PNetBlock( 97 | self.in_chns, self.num_filters, self.ratios[0], padding=self.ratios[0]) 98 | 99 | self.block2 = PNetBlock( 100 | self.num_filters, self.num_filters, self.ratios[1], padding=self.ratios[1]) 101 | 102 | self.block3 = PNetBlock( 103 | self.num_filters, self.num_filters, self.ratios[2], padding=self.ratios[2]) 104 | 105 | self.block4 = PNetBlock( 106 | self.num_filters, self.num_filters, self.ratios[3], padding=self.ratios[3]) 107 | 108 | self.block5 = PNetBlock( 109 | self.num_filters, self.num_filters, self.ratios[4], padding=self.ratios[4]) 110 | self.catblock = ConcatBlock(self.num_filters * 5, self.num_filters * 2) 111 | self.out = OutPutBlock(self.num_filters * 2, self.out_chns) 112 | 113 | def forward(self, x): 114 | x1 = self.block1(x) 115 | x2 = self.block2(x1) 116 | x3 = self.block3(x2) 117 | x4 = self.block4(x3) 118 | x5 = self.block5(x4) 119 | conx = torch.cat([x1, x2, x3, x4, x5], dim=1) 120 | conx = self.catblock(conx) 121 | out = self.out(conx) 122 | return out 123 | -------------------------------------------------------------------------------- /code/networks/unet_3D.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | An implementation of the 3D U-Net paper: 4 | Özgün Çiçek, Ahmed Abdulkadir, Soeren S. Lienkamp, Thomas Brox, Olaf Ronneberger: 5 | 3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation. 6 | MICCAI (2) 2016: 424-432 7 | Note that there are some modifications from the original paper, such as 8 | the use of batch normalization, dropout, and leaky relu here. 
9 | The implementation is borrowed from: https://github.com/ozan-oktay/Attention-Gated-Networks 10 | """ 11 | import math 12 | 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | 16 | from networks.networks_other import init_weights 17 | from networks.utils import UnetConv3, UnetUp3, UnetUp3_CT 18 | 19 | 20 | class unet_3D(nn.Module): 21 | 22 | def __init__(self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3, is_batchnorm=True): 23 | super(unet_3D, self).__init__() 24 | self.is_deconv = is_deconv 25 | self.in_channels = in_channels 26 | self.is_batchnorm = is_batchnorm 27 | self.feature_scale = feature_scale 28 | 29 | filters = [64, 128, 256, 512, 1024] 30 | filters = [int(x / self.feature_scale) for x in filters] 31 | 32 | # downsampling 33 | self.conv1 = UnetConv3(self.in_channels, filters[0], self.is_batchnorm, kernel_size=( 34 | 3, 3, 3), padding_size=(1, 1, 1)) 35 | self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 36 | 37 | self.conv2 = UnetConv3(filters[0], filters[1], self.is_batchnorm, kernel_size=( 38 | 3, 3, 3), padding_size=(1, 1, 1)) 39 | self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 40 | 41 | self.conv3 = UnetConv3(filters[1], filters[2], self.is_batchnorm, kernel_size=( 42 | 3, 3, 3), padding_size=(1, 1, 1)) 43 | self.maxpool3 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 44 | 45 | self.conv4 = UnetConv3(filters[2], filters[3], self.is_batchnorm, kernel_size=( 46 | 3, 3, 3), padding_size=(1, 1, 1)) 47 | self.maxpool4 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 48 | 49 | self.center = UnetConv3(filters[3], filters[4], self.is_batchnorm, kernel_size=( 50 | 3, 3, 3), padding_size=(1, 1, 1)) 51 | 52 | # upsampling 53 | self.up_concat4 = UnetUp3_CT(filters[4], filters[3], is_batchnorm) 54 | self.up_concat3 = UnetUp3_CT(filters[3], filters[2], is_batchnorm) 55 | self.up_concat2 = UnetUp3_CT(filters[2], filters[1], is_batchnorm) 56 | self.up_concat1 = UnetUp3_CT(filters[1], filters[0], is_batchnorm) 57 | 58 | # final conv (without any concat) 59 | self.final = nn.Conv3d(filters[0], n_classes, 1) 60 | 61 | self.dropout1 = nn.Dropout(p=0.3) 62 | self.dropout2 = nn.Dropout(p=0.3) 63 | 64 | # initialise weights 65 | for m in self.modules(): 66 | if isinstance(m, nn.Conv3d): 67 | init_weights(m, init_type='kaiming') 68 | elif isinstance(m, nn.BatchNorm3d): 69 | init_weights(m, init_type='kaiming') 70 | 71 | def forward(self, inputs): 72 | conv1 = self.conv1(inputs) 73 | maxpool1 = self.maxpool1(conv1) 74 | 75 | conv2 = self.conv2(maxpool1) 76 | maxpool2 = self.maxpool2(conv2) 77 | 78 | conv3 = self.conv3(maxpool2) 79 | maxpool3 = self.maxpool3(conv3) 80 | 81 | conv4 = self.conv4(maxpool3) 82 | maxpool4 = self.maxpool4(conv4) 83 | 84 | center = self.center(maxpool4) 85 | center = self.dropout1(center) 86 | up4 = self.up_concat4(conv4, center) 87 | up3 = self.up_concat3(conv3, up4) 88 | up2 = self.up_concat2(conv2, up3) 89 | up1 = self.up_concat1(conv1, up2) 90 | up1 = self.dropout2(up1) 91 | 92 | final = self.final(up1) 93 | 94 | return final 95 | 96 | @staticmethod 97 | def apply_argmax_softmax(pred): 98 | log_p = F.softmax(pred, dim=1) 99 | 100 | return log_p 101 | -------------------------------------------------------------------------------- /code/networks/vision_mamba.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import copy 7 | import logging 8 | import math 9 | 10 | from 
os.path import join as pjoin
11 | 
12 | import torch
13 | import torch.nn as nn
14 | import numpy as np
15 | 
16 | from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
17 | from torch.nn.modules.utils import _pair
18 | from scipy import ndimage
19 | from .mamba_sys import VSSM
20 | 
21 | logger = logging.getLogger(__name__)
22 | 
23 | class MambaUnet(nn.Module):
24 |     def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False):
25 |         super(MambaUnet, self).__init__()
26 |         self.num_classes = num_classes
27 |         self.zero_head = zero_head
28 |         self.config = config
29 | 
30 |         self.mamba_unet = VSSM(
31 |             patch_size=config.MODEL.VSSM.PATCH_SIZE,
32 |             in_chans=config.MODEL.VSSM.IN_CHANS,
33 |             num_classes=self.num_classes,
34 |             embed_dim=config.MODEL.VSSM.EMBED_DIM,
35 |             depths=config.MODEL.VSSM.DEPTHS,
36 |             mlp_ratio=config.MODEL.SWIN.MLP_RATIO,
37 |             drop_rate=config.MODEL.DROP_RATE,
38 |             drop_path_rate=config.MODEL.DROP_PATH_RATE,
39 |             patch_norm=config.MODEL.SWIN.PATCH_NORM,
40 |             use_checkpoint=config.TRAIN.USE_CHECKPOINT)
41 | 
42 |     def forward(self, x):
43 |         if x.size()[1] == 1:
44 |             x = x.repeat(1, 3, 1, 1)
45 |         logits = self.mamba_unet(x)
46 |         return logits
47 | 
48 |     def load_from(self, config):
49 |         pretrained_path = config.MODEL.PRETRAIN_CKPT
50 |         if pretrained_path is not None:
51 |             print("pretrained_path:{}".format(pretrained_path))
52 |             device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
53 |             pretrained_dict = torch.load(pretrained_path, map_location=device)
54 |             if "model" not in pretrained_dict:
55 |                 print("---start to load pretrained model by splitting---")
56 |                 pretrained_dict = {k[17:]: v for k, v in pretrained_dict.items()}
57 |                 for k in list(pretrained_dict.keys()):
58 |                     if "output" in k:
59 |                         print("delete key:{}".format(k))
60 |                         del pretrained_dict[k]
61 |                 msg = self.mamba_unet.load_state_dict(pretrained_dict, strict=False)
62 |                 # print(msg)
63 |                 return
64 |             pretrained_dict = pretrained_dict['model']
65 |             print("---start to load pretrained model of swin encoder---")
66 | 
67 |             model_dict = self.mamba_unet.state_dict()
68 |             full_dict = copy.deepcopy(pretrained_dict)
69 |             for k, v in pretrained_dict.items():
70 |                 if "layers." in k:
71 |                     current_layer_num = 3 - int(k[7:8])
72 |                     current_k = "layers_up." + str(current_layer_num) + k[8:]
73 |                     full_dict.update({current_k: v})
74 |             for k in list(full_dict.keys()):
75 |                 if k in model_dict:
76 |                     if full_dict[k].shape != model_dict[k].shape:
77 |                         # drop pretrained weights whose shape does not match the model
78 |                         print("delete:{};shape pretrain:{};shape model:{}".format(k, full_dict[k].shape, model_dict[k].shape))
79 |                         del full_dict[k]
80 | 
81 |             msg = self.mamba_unet.load_state_dict(full_dict, strict=False)
82 |             # print(msg)
83 |         else:
84 |             print("no pretrained checkpoint given")
--------------------------------------------------------------------------------
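A minimal usage sketch for MambaUnet (hedged: it assumes `args.cfg` points at a config such as configs/vmamba_tiny.yaml whose MODEL.VSSM section defines the fields read above; the shapes and class count are illustrative, and a CUDA device is assumed):

    import torch
    from networks.config import get_config
    from networks.net_factory import args      # shared argparse namespace (parsed at import time)
    from networks.vision_mamba import MambaUnet

    args.cfg = '../code/configs/vmamba_tiny.yaml'
    config = get_config(args)
    model = MambaUnet(config, img_size=224, num_classes=4).cuda()
    model.load_from(config)                    # loads MODEL.PRETRAIN_CKPT when it is set

    x = torch.rand(2, 1, 224, 224).cuda()      # grayscale input; forward() repeats it to 3 channels
    print(model(x).shape)                      # expected: torch.Size([2, 4, 224, 224])
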
/code/networks/vision_transformer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # This file is borrowed from Swin-UNet: https://github.com/HuCaoFighting/Swin-Unet
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 | 
7 | import copy
8 | import logging
9 | import math
10 | 
11 | from os.path import join as pjoin
12 | 
13 | import torch
14 | import torch.nn as nn
15 | import numpy as np
16 | 
17 | from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
18 | from torch.nn.modules.utils import _pair
19 | from scipy import ndimage
20 | from networks.swin_transformer_unet_skip_expand_decoder_sys import SwinTransformerSys
21 | 
22 | logger = logging.getLogger(__name__)
23 | 
24 | class SwinUnet(nn.Module):
25 |     def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False):
26 |         super(SwinUnet, self).__init__()
27 |         self.num_classes = num_classes
28 |         self.zero_head = zero_head
29 |         self.config = config
30 | 
31 |         self.swin_unet = SwinTransformerSys(img_size=config.DATA.IMG_SIZE,
32 |                                             patch_size=config.MODEL.SWIN.PATCH_SIZE,
33 |                                             in_chans=config.MODEL.SWIN.IN_CHANS,
34 |                                             num_classes=self.num_classes,
35 |                                             embed_dim=config.MODEL.SWIN.EMBED_DIM,
36 |                                             depths=config.MODEL.SWIN.DEPTHS,
37 |                                             num_heads=config.MODEL.SWIN.NUM_HEADS,
38 |                                             window_size=config.MODEL.SWIN.WINDOW_SIZE,
39 |                                             mlp_ratio=config.MODEL.SWIN.MLP_RATIO,
40 |                                             qkv_bias=config.MODEL.SWIN.QKV_BIAS,
41 |                                             qk_scale=config.MODEL.SWIN.QK_SCALE,
42 |                                             drop_rate=config.MODEL.DROP_RATE,
43 |                                             drop_path_rate=config.MODEL.DROP_PATH_RATE,
44 |                                             ape=config.MODEL.SWIN.APE,
45 |                                             patch_norm=config.MODEL.SWIN.PATCH_NORM,
46 |                                             use_checkpoint=config.TRAIN.USE_CHECKPOINT)
47 | 
48 |     def forward(self, x):
49 |         if x.size()[1] == 1:
50 |             x = x.repeat(1, 3, 1, 1)
51 |         logits = self.swin_unet(x)
52 |         return logits
53 | 
54 |     def load_from(self, config):
55 |         pretrained_path = config.MODEL.PRETRAIN_CKPT
56 |         if pretrained_path is not None:
57 |             print("pretrained_path:{}".format(pretrained_path))
58 |             device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
59 |             pretrained_dict = torch.load(pretrained_path, map_location=device)
60 |             if "model" not in pretrained_dict:
61 |                 print("---start to load pretrained model by splitting---")
62 |                 pretrained_dict = {k[17:]: v for k, v in pretrained_dict.items()}
63 |                 for k in list(pretrained_dict.keys()):
64 |                     if "output" in k:
65 |                         print("delete key:{}".format(k))
66 |                         del pretrained_dict[k]
67 |                 msg = self.swin_unet.load_state_dict(pretrained_dict, strict=False)
68 |                 # print(msg)
69 |                 return
70 |             pretrained_dict = pretrained_dict['model']
71 |             print("---start to load pretrained model of swin encoder---")
72 | 
73 |             model_dict = self.swin_unet.state_dict()
74 |             full_dict = copy.deepcopy(pretrained_dict)
75 |             for k, v in pretrained_dict.items():
76 |                 if "layers." in k:
77 |                     current_layer_num = 3 - int(k[7:8])
78 |                     current_k = "layers_up."
+ str(current_layer_num) + k[8:] 79 | full_dict.update({current_k:v}) 80 | for k in list(full_dict.keys()): 81 | if k in model_dict: 82 | if full_dict[k].shape != model_dict[k].shape: 83 | print("delete:{};shape pretrain:{};shape model:{}".format(k,v.shape,model_dict[k].shape)) 84 | del full_dict[k] 85 | 86 | msg = self.swin_unet.load_state_dict(full_dict, strict=False) 87 | # print(msg) 88 | else: 89 | print("none pretrain") 90 | -------------------------------------------------------------------------------- /code/pretrained_ckpt/readme.txt: -------------------------------------------------------------------------------- 1 | download pre-trained model to this folder, 2 | 3 | 4 | https://drive.google.com/file/d/14RzbbBDjbKbgr0ordKlWbb69EFkHuplr/view?usp=sharing 5 | 6 | https://drive.google.com/file/d/1uUPsr7XeqayCxlspqBHbg5zIWx0JYtSX/view?usp=sharing 7 | -------------------------------------------------------------------------------- /code/test_2D.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import re 4 | import shutil 5 | 6 | import h5py 7 | import nibabel as nib 8 | import numpy as np 9 | import SimpleITK as sitk 10 | import torch 11 | from medpy import metric 12 | from scipy.ndimage import zoom 13 | from scipy.ndimage.interpolation import zoom 14 | from tqdm import tqdm 15 | 16 | # from networks.efficientunet import UNet 17 | from networks.net_factory import net_factory, config, args 18 | 19 | from config import get_config 20 | from networks.vision_transformer import SwinUnet as ViT_seg 21 | 22 | 23 | 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument('--root_path', type=str, 26 | default='../data/ACDC', help='Name of Experiment') 27 | parser.add_argument('--exp', type=str, 28 | default='semi_mamba_unet_3unet', help='experiment_name') 29 | parser.add_argument('--model', type=str, 30 | default='unet', help='model_name') 31 | parser.add_argument('--fold', type=str, 32 | default='fold1', help='fold') 33 | parser.add_argument('--num_classes', type=int, default=4, 34 | help='output channel of network') 35 | # parser.add_argument('--sup_type', type=str, default="label",help='label') 36 | parser.add_argument('--sup_type', type=str, default="scribble",help='label/scribble') 37 | 38 | def get_fold_ids(fold): 39 | all_cases_set = ["patient{:0>3}".format(i) for i in range(1, 101)] 40 | fold1_testing_set = [ 41 | "patient{:0>3}".format(i) for i in range(1, 21)] 42 | fold1_training_set = [ 43 | i for i in all_cases_set if i not in fold1_testing_set] 44 | 45 | fold2_testing_set = [ 46 | "patient{:0>3}".format(i) for i in range(21, 41)] 47 | fold2_training_set = [ 48 | i for i in all_cases_set if i not in fold2_testing_set] 49 | 50 | fold3_testing_set = [ 51 | "patient{:0>3}".format(i) for i in range(41, 61)] 52 | fold3_training_set = [ 53 | i for i in all_cases_set if i not in fold3_testing_set] 54 | 55 | fold4_testing_set = [ 56 | "patient{:0>3}".format(i) for i in range(61, 81)] 57 | fold4_training_set = [ 58 | i for i in all_cases_set if i not in fold4_testing_set] 59 | 60 | fold5_testing_set = [ 61 | "patient{:0>3}".format(i) for i in range(81, 101)] 62 | fold5_training_set = [ 63 | i for i in all_cases_set if i not in fold5_testing_set] 64 | if fold == "fold1": 65 | return [fold1_training_set, fold1_testing_set] 66 | elif fold == "fold2": 67 | return [fold2_training_set, fold2_testing_set] 68 | elif fold == "fold3": 69 | return [fold3_training_set, fold3_testing_set] 70 | elif fold == "fold4": 71 | 
return [fold4_training_set, fold4_testing_set] 72 | elif fold == "fold5": 73 | return [fold5_training_set, fold5_testing_set] 74 | else: 75 | return "ERROR KEY" 76 | 77 | 78 | def calculate_metric_percase(pred, gt, spacing): 79 | pred[pred > 0] = 1 80 | gt[gt > 0] = 1 81 | dice = metric.binary.dc(pred, gt) 82 | asd = metric.binary.asd(pred, gt, voxelspacing=spacing) 83 | hd95 = metric.binary.hd95(pred, gt, voxelspacing=spacing) 84 | return dice, hd95, asd 85 | 86 | 87 | def test_single_volume(case, net, test_save_path, FLAGS): 88 | h5f = h5py.File(FLAGS.root_path + 89 | "/ACDC_training_volumes/{}".format(case), 'r') 90 | image = h5f['image'][:] 91 | label = h5f['label'][:] 92 | prediction = np.zeros_like(label) 93 | for ind in range(image.shape[0]): 94 | slice = image[ind, :, :] 95 | x, y = slice.shape[0], slice.shape[1] 96 | slice = zoom(slice, (224 / x, 224 / y), order=0) 97 | input = torch.from_numpy(slice).unsqueeze( 98 | 0).unsqueeze(0).float().cuda() 99 | net.eval() 100 | with torch.no_grad(): 101 | out_main = net(input) 102 | out = torch.argmax(torch.softmax( 103 | out_main, dim=1), dim=1).squeeze(0) 104 | out = out.cpu().detach().numpy() 105 | pred = zoom(out, (x / 224, y / 224), order=0) 106 | prediction[ind] = pred 107 | case = case.replace(".h5", "") 108 | org_img_path = "../data/ACDC_training/{}.nii.gz".format(case) 109 | org_img_itk = sitk.ReadImage(org_img_path) 110 | spacing = org_img_itk.GetSpacing() 111 | 112 | first_metric = calculate_metric_percase( 113 | prediction == 1, label == 1, (spacing[2], spacing[0], spacing[1])) 114 | second_metric = calculate_metric_percase( 115 | prediction == 2, label == 2, (spacing[2], spacing[0], spacing[1])) 116 | third_metric = calculate_metric_percase( 117 | prediction == 3, label == 3, (spacing[2], spacing[0], spacing[1])) 118 | 119 | img_itk = sitk.GetImageFromArray(image.astype(np.float32)) 120 | img_itk.CopyInformation(org_img_itk) 121 | prd_itk = sitk.GetImageFromArray(prediction.astype(np.float32)) 122 | prd_itk.CopyInformation(org_img_itk) 123 | lab_itk = sitk.GetImageFromArray(label.astype(np.float32)) 124 | lab_itk.CopyInformation(org_img_itk) 125 | sitk.WriteImage(prd_itk, test_save_path + case + "_pred.nii.gz") 126 | sitk.WriteImage(img_itk, test_save_path + case + "_img.nii.gz") 127 | sitk.WriteImage(lab_itk, test_save_path + case + "_gt.nii.gz") 128 | return first_metric, second_metric, third_metric 129 | 130 | 131 | def Inference(FLAGS): 132 | train_ids, test_ids = get_fold_ids(FLAGS.fold) 133 | all_volumes = os.listdir( 134 | FLAGS.root_path + "/ACDC_training_volumes") 135 | image_list = [] 136 | for ids in test_ids: 137 | new_data_list = list(filter(lambda x: re.match( 138 | '{}.*'.format(ids), x) != None, all_volumes)) 139 | image_list.extend(new_data_list) 140 | snapshot_path = "../model/{}_{}/{}".format( 141 | FLAGS.exp, FLAGS.fold, FLAGS.sup_type) 142 | test_save_path = "../model/{}_{}/{}/{}_predictions/".format( 143 | FLAGS.exp, FLAGS.fold, FLAGS.sup_type, FLAGS.model) 144 | if os.path.exists(test_save_path): 145 | shutil.rmtree(test_save_path) 146 | os.makedirs(test_save_path) 147 | 148 | # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #v 149 | net = net_factory(net_type=FLAGS.model, in_chns=1, 150 | class_num=FLAGS.num_classes) 151 | 152 | # net = ViT_seg(config, img_size=[224, 224], num_classes=args.num_classes).cuda() 153 | # net.load_from(config) 154 | 155 | # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # 
# # # # # # # # # # # # # # # 156 | save_mode_path = os.path.join( 157 | snapshot_path, 'ViT_best_model.pth') 158 | net.load_state_dict(torch.load(save_mode_path)) 159 | print("init weight from {}".format(save_mode_path)) 160 | net.eval() 161 | 162 | first_total = 0.0 163 | second_total = 0.0 164 | third_total = 0.0 165 | for case in tqdm(image_list): 166 | print(case) 167 | first_metric, second_metric, third_metric = test_single_volume( 168 | case, net, test_save_path, FLAGS) 169 | first_total += np.asarray(first_metric) 170 | second_total += np.asarray(second_metric) 171 | third_total += np.asarray(third_metric) 172 | avg_metric = [first_total / len(image_list), second_total / 173 | len(image_list), third_total / len(image_list)] 174 | print(avg_metric) 175 | print((avg_metric[0] + avg_metric[1] + avg_metric[2]) / 3) 176 | return ((avg_metric[0] + avg_metric[1] + avg_metric[2]) / 3)[0] 177 | 178 | 179 | if __name__ == '__main__': 180 | FLAGS = parser.parse_args() 181 | total = 0.0 182 | # for i in [5]: 183 | # for i in [5]: 184 | # FLAGS.fold = "fold{}".format(i) 185 | # print("Inference fold{}".format(i)) 186 | mean_dice = Inference(FLAGS) 187 | total += mean_dice 188 | # print(total/5.0) 189 | -------------------------------------------------------------------------------- /code/utils/__pycache__/gate_crf_loss.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/code/utils/__pycache__/gate_crf_loss.cpython-310.pyc -------------------------------------------------------------------------------- /code/utils/__pycache__/gate_crf_loss.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/code/utils/__pycache__/gate_crf_loss.cpython-38.pyc -------------------------------------------------------------------------------- /code/utils/__pycache__/losses.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/code/utils/__pycache__/losses.cpython-310.pyc -------------------------------------------------------------------------------- /code/utils/__pycache__/losses.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/code/utils/__pycache__/losses.cpython-38.pyc -------------------------------------------------------------------------------- /code/utils/__pycache__/metrics.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/code/utils/__pycache__/metrics.cpython-310.pyc -------------------------------------------------------------------------------- /code/utils/__pycache__/metrics.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/code/utils/__pycache__/metrics.cpython-38.pyc -------------------------------------------------------------------------------- /code/utils/__pycache__/ramps.cpython-310.pyc: 
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/code/utils/__pycache__/ramps.cpython-310.pyc
--------------------------------------------------------------------------------
/code/utils/__pycache__/ramps.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/code/utils/__pycache__/ramps.cpython-38.pyc
--------------------------------------------------------------------------------
/code/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from medpy import metric
3 | 
4 | 
5 | def cal_dice(prediction, label, num=2):
6 |     total_dice = np.zeros(num - 1)
7 |     for i in range(1, num):
8 |         prediction_tmp = (prediction == i)
9 |         label_tmp = (label == i)
10 |         prediction_tmp = prediction_tmp.astype(np.float64)  # np.float was removed in NumPy >= 1.24
11 |         label_tmp = label_tmp.astype(np.float64)
12 | 
13 |         dice = 2 * np.sum(prediction_tmp * label_tmp) / (np.sum(prediction_tmp) + np.sum(label_tmp))
14 |         total_dice[i - 1] += dice
15 | 
16 |     return total_dice
17 | 
18 | 
19 | def calculate_metric_percase(pred, gt):
20 |     dc = metric.binary.dc(pred, gt)
21 |     jc = metric.binary.jc(pred, gt)
22 |     hd95 = metric.binary.hd95(pred, gt)
23 |     asd = metric.binary.asd(pred, gt)
24 | 
25 |     return dc, jc, hd95, asd
26 | 
27 | 
28 | def dice(input, target, ignore_index=None):
29 |     smooth = 1.
30 |     # clone first so that masking ignored voxels does not modify the original tensors
31 |     iflat = input.clone().view(-1)
32 |     tflat = target.clone().view(-1)
33 |     if ignore_index is not None:
34 |         mask = tflat == ignore_index
35 |         tflat[mask] = 0
36 |         iflat[mask] = 0
37 |     intersection = (iflat * tflat).sum()
38 | 
39 |     return (2. * intersection + smooth) / (iflat.sum() + tflat.sum() + smooth)
--------------------------------------------------------------------------------
/code/utils/ramps.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018, Curious AI Ltd. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons Attribution-NonCommercial
4 | # 4.0 International License. To view a copy of this license, visit
5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7 | 
8 | """Functions for ramping hyperparameters up or down.
9 | 
10 | Each function takes the current training step or epoch, and the
11 | ramp length in the same format, and returns a multiplier between
12 | 0 and 1.
13 | """
14 | 
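A short usage sketch for the ramp helpers defined below, as they are typically wired into consistency training (hedged: the default values mirror the `--consistency` / `--consistency_rampup` arguments in net_factory.py, and `get_current_consistency_weight` is a hypothetical helper, not a function in this file):

    from utils import ramps

    def get_current_consistency_weight(epoch, consistency=0.1, consistency_rampup=200.0):
        # scale the consistency-loss weight from ~0 up to `consistency`
        return consistency * ramps.sigmoid_rampup(epoch, consistency_rampup)

    for epoch in (0, 50, 100, 200):
        print(epoch, round(get_current_consistency_weight(epoch), 4))
    # expected: ~0.0007 at epoch 0, rising to the full 0.1 once the ramp completes at epoch 200
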
13 | """ 14 | 15 | 16 | import numpy as np 17 | 18 | 19 | def sigmoid_rampup(current, rampup_length): 20 | """Exponential rampup from https://arxiv.org/abs/1610.02242""" 21 | if rampup_length == 0: 22 | return 1.0 23 | else: 24 | current = np.clip(current, 0.0, rampup_length) 25 | phase = 1.0 - current / rampup_length 26 | return float(np.exp(-5.0 * phase * phase)) 27 | 28 | 29 | def linear_rampup(current, rampup_length): 30 | """Linear rampup""" 31 | assert current >= 0 and rampup_length >= 0 32 | if current >= rampup_length: 33 | return 1.0 34 | else: 35 | return current / rampup_length 36 | 37 | 38 | def cosine_rampdown(current, rampdown_length): 39 | """Cosine rampdown from https://arxiv.org/abs/1608.03983""" 40 | assert 0 <= current <= rampdown_length 41 | return float(.5 * (np.cos(np.pi * current / rampdown_length) + 1)) 42 | -------------------------------------------------------------------------------- /code/utils/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import numpy as np 4 | from scipy.ndimage import distance_transform_edt as distance 5 | from skimage import segmentation as skimage_seg 6 | import torch 7 | from torch.utils.data.sampler import Sampler 8 | 9 | import networks 10 | 11 | def load_model(path): 12 | """Loads model and return it without DataParallel table.""" 13 | if os.path.isfile(path): 14 | print("=> loading checkpoint '{}'".format(path)) 15 | checkpoint = torch.load(path) 16 | 17 | # size of the top layer 18 | N = checkpoint['state_dict']['top_layer.bias'].size() 19 | 20 | # build skeleton of the model 21 | sob = 'sobel.0.weight' in checkpoint['state_dict'].keys() 22 | model = models.__dict__[checkpoint['arch']](sobel=sob, out=int(N[0])) 23 | 24 | # deal with a dataparallel table 25 | def rename_key(key): 26 | if not 'module' in key: 27 | return key 28 | return ''.join(key.split('.module')) 29 | 30 | checkpoint['state_dict'] = {rename_key(key): val 31 | for key, val 32 | in checkpoint['state_dict'].items()} 33 | 34 | # load weights 35 | model.load_state_dict(checkpoint['state_dict']) 36 | print("Loaded") 37 | else: 38 | model = None 39 | print("=> no checkpoint found at '{}'".format(path)) 40 | return model 41 | 42 | 43 | class UnifLabelSampler(Sampler): 44 | """Samples elements uniformely accross pseudolabels. 45 | Args: 46 | N (int): size of returned iterator. 
47 |         images_lists: dict of key (target), value (list of data with this target)
48 |     """
49 | 
50 |     def __init__(self, N, images_lists):
51 |         self.N = N
52 |         self.images_lists = images_lists
53 |         self.indexes = self.generate_indexes_epoch()
54 | 
55 |     def generate_indexes_epoch(self):
56 |         size_per_pseudolabel = int(self.N / len(self.images_lists)) + 1
57 |         res = np.zeros(size_per_pseudolabel * len(self.images_lists))
58 | 
59 |         for i in range(len(self.images_lists)):
60 |             indexes = np.random.choice(
61 |                 self.images_lists[i],
62 |                 size_per_pseudolabel,
63 |                 replace=(len(self.images_lists[i]) <= size_per_pseudolabel)
64 |             )
65 |             res[i * size_per_pseudolabel: (i + 1) * size_per_pseudolabel] = indexes
66 | 
67 |         np.random.shuffle(res)
68 |         return res[:self.N].astype('int')
69 | 
70 |     def __iter__(self):
71 |         return iter(self.indexes)
72 | 
73 |     def __len__(self):
74 |         return self.N
75 | 
76 | 
77 | class AverageMeter(object):
78 |     """Computes and stores the average and current value"""
79 |     def __init__(self):
80 |         self.reset()
81 | 
82 |     def reset(self):
83 |         self.val = 0
84 |         self.avg = 0
85 |         self.sum = 0
86 |         self.count = 0
87 | 
88 |     def update(self, val, n=1):
89 |         self.val = val
90 |         self.sum += val * n
91 |         self.count += n
92 |         self.avg = self.sum / self.count
93 | 
94 | 
95 | def learning_rate_decay(optimizer, t, lr_0):
96 |     for param_group in optimizer.param_groups:
97 |         lr = lr_0 / np.sqrt(1 + lr_0 * param_group['weight_decay'] * t)
98 |         param_group['lr'] = lr
99 | 
100 | 
101 | class Logger():
102 |     """ Class updated every epoch to keep track of the results.
103 |     Methods:
104 |         - log() log and save
105 |     """
106 | 
107 |     def __init__(self, path):
108 |         self.path = path
109 |         self.data = []
110 | 
111 |     def log(self, train_point):
112 |         self.data.append(train_point)
113 |         with open(self.path, 'wb') as fp:
114 |             pickle.dump(self.data, fp, -1)
115 | 
116 | 
117 | def compute_sdf(img_gt, out_shape):
118 |     """
119 |     Compute the signed distance map of a binary mask.
120 |     input: segmentation, shape = (batch_size, x, y, z)
121 |     output: the Signed Distance Map (SDM), where for boundary set B:
122 |         sdf(x) =  0                     if x is on the segmentation boundary
123 |                  -min_{y in B} |x - y|  if x is inside the segmentation
124 |                  +min_{y in B} |x - y|  if x is outside the segmentation
125 |     The sdf is normalized to [-1, 1].
126 |     """
127 | 
128 |     img_gt = img_gt.astype(np.uint8)
129 |     normalized_sdf = np.zeros(out_shape)
130 | 
131 |     for b in range(out_shape[0]):  # batch size
132 |         posmask = img_gt[b].astype(bool)  # np.bool was removed in NumPy >= 1.24
133 |         if posmask.any():
134 |             negmask = ~posmask
135 |             posdis = distance(posmask)
136 |             negdis = distance(negmask)
137 |             boundary = skimage_seg.find_boundaries(posmask, mode='inner').astype(np.uint8)
138 |             sdf = (negdis - np.min(negdis)) / (np.max(negdis) - np.min(negdis)) - (posdis - np.min(posdis)) / (np.max(posdis) - np.min(posdis))
139 |             sdf[boundary == 1] = 0
140 |             normalized_sdf[b] = sdf
141 |             # assert np.min(sdf) == -1.0, print(np.min(posdis), np.max(posdis), np.min(negdis), np.max(negdis))
142 |             # assert np.max(sdf) == 1.0, print(np.min(posdis), np.min(negdis), np.max(posdis), np.max(negdis))
143 | 
144 |     return normalized_sdf
--------------------------------------------------------------------------------
/code/val_2D.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from medpy import metric
4 | from scipy.ndimage import zoom
5 | 
6 | 
7 | def calculate_metric_percase(pred, gt):
8 |     pred[pred > 0] = 1
9 |     gt[gt > 0] = 1
10 |     if pred.sum() > 0:
11 |         dice = metric.binary.dc(pred, gt)
12 |         hd95 = metric.binary.hd95(pred, gt)
13 | return dice, hd95 14 | else: 15 | return 0, 0 16 | 17 | 18 | def test_single_volume(image, label, net, classes, patch_size=[256, 256]): 19 | image, label = image.squeeze(0).cpu().detach( 20 | ).numpy(), label.squeeze(0).cpu().detach().numpy() 21 | if len(image.shape) == 3: 22 | prediction = np.zeros_like(label) 23 | for ind in range(image.shape[0]): 24 | slice = image[ind, :, :] 25 | x, y = slice.shape[0], slice.shape[1] 26 | slice = zoom( 27 | slice, (patch_size[0] / x, patch_size[1] / y), order=0) 28 | input = torch.from_numpy(slice).unsqueeze( 29 | 0).unsqueeze(0).float().cuda() 30 | net.eval() 31 | with torch.no_grad(): 32 | out = torch.argmax(torch.softmax( 33 | net(input), dim=1), dim=1).squeeze(0) 34 | out = out.cpu().detach().numpy() 35 | pred = zoom( 36 | out, (x / patch_size[0], y / patch_size[1]), order=0) 37 | prediction[ind] = pred 38 | else: 39 | input = torch.from_numpy(image).unsqueeze( 40 | 0).unsqueeze(0).float().cuda() 41 | net.eval() 42 | with torch.no_grad(): 43 | out = torch.argmax(torch.softmax( 44 | net(input), dim=1), dim=1).squeeze(0) 45 | prediction = out.cpu().detach().numpy() 46 | metric_list = [] 47 | for i in range(1, classes): 48 | metric_list.append(calculate_metric_percase( 49 | prediction == i, label == i)) 50 | return metric_list 51 | 52 | 53 | def test_single_volume_ds(image, label, net, classes, patch_size=[256, 256]): 54 | image, label = image.squeeze(0).cpu().detach( 55 | ).numpy(), label.squeeze(0).cpu().detach().numpy() 56 | if len(image.shape) == 3: 57 | prediction = np.zeros_like(label) 58 | for ind in range(image.shape[0]): 59 | slice = image[ind, :, :] 60 | x, y = slice.shape[0], slice.shape[1] 61 | slice = zoom( 62 | slice, (patch_size[0] / x, patch_size[1] / y), order=0) 63 | input = torch.from_numpy(slice).unsqueeze( 64 | 0).unsqueeze(0).float().cuda() 65 | net.eval() 66 | with torch.no_grad(): 67 | output_main, _, _, _ = net(input) 68 | out = torch.argmax(torch.softmax( 69 | output_main, dim=1), dim=1).squeeze(0) 70 | out = out.cpu().detach().numpy() 71 | pred = zoom( 72 | out, (x / patch_size[0], y / patch_size[1]), order=0) 73 | prediction[ind] = pred 74 | else: 75 | input = torch.from_numpy(image).unsqueeze( 76 | 0).unsqueeze(0).float().cuda() 77 | net.eval() 78 | with torch.no_grad(): 79 | output_main, _, _, _ = net(input) 80 | out = torch.argmax(torch.softmax( 81 | output_main, dim=1), dim=1).squeeze(0) 82 | prediction = out.cpu().detach().numpy() 83 | metric_list = [] 84 | for i in range(1, classes): 85 | metric_list.append(calculate_metric_percase( 86 | prediction == i, label == i)) 87 | return metric_list 88 | 89 | 90 | def test_single_volume_cct(image, label, net, classes, patch_size=[256, 256]): 91 | image, label = image.squeeze(0).cpu().detach( 92 | ).numpy(), label.squeeze(0).cpu().detach().numpy() 93 | if len(image.shape) == 3: 94 | prediction = np.zeros_like(label) 95 | for ind in range(image.shape[0]): 96 | slice = image[ind, :, :] 97 | x, y = slice.shape[0], slice.shape[1] 98 | slice = zoom( 99 | slice, (patch_size[0] / x, patch_size[1] / y), order=0) 100 | input = torch.from_numpy(slice).unsqueeze( 101 | 0).unsqueeze(0).float().cuda() 102 | net.eval() 103 | with torch.no_grad(): 104 | output_main = net(input)[0] 105 | out = torch.argmax(torch.softmax( 106 | output_main, dim=1), dim=1).squeeze(0) 107 | out = out.cpu().detach().numpy() 108 | pred = zoom( 109 | out, (x / patch_size[0], y / patch_size[1]), order=0) 110 | prediction[ind] = pred 111 | else: 112 | input = torch.from_numpy(image).unsqueeze( 113 | 
0).unsqueeze(0).float().cuda() 114 | net.eval() 115 | with torch.no_grad(): 116 | output_main, _, _, _ = net(input) 117 | out = torch.argmax(torch.softmax( 118 | output_main, dim=1), dim=1).squeeze(0) 119 | prediction = out.cpu().detach().numpy() 120 | metric_list = [] 121 | for i in range(1, classes): 122 | metric_list.append(calculate_metric_percase( 123 | prediction == i, label == i)) 124 | return metric_list 125 | -------------------------------------------------------------------------------- /data/readme.txt: -------------------------------------------------------------------------------- 1 | 2 | Google Drive: 3 | 4 | https://drive.google.com/file/d/1XR_Id0wdvXY9QeKtdOdgJHKVJ-nVr2j1/view?usp=sharing 5 | 6 | or 7 | 8 | Baidu Netdisk: 9 | 10 | https://pan.baidu.com/s/1dHkp9daqE3kLEbAP6zl7Jw with passcode: 'rwv2' 11 | -------------------------------------------------------------------------------- /img/results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/img/results.png -------------------------------------------------------------------------------- /img/wslframework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/img/wslframework.png -------------------------------------------------------------------------------- /img/wslintro.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/img/wslintro.png -------------------------------------------------------------------------------- /mamba/.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "3rdparty/lm-evaluation-harness"] 2 | path = 3rdparty/lm-evaluation-harness 3 | url = https://github.com/EleutherAI/lm-evaluation-harness/ 4 | -------------------------------------------------------------------------------- /mamba/AUTHORS: -------------------------------------------------------------------------------- 1 | Tri Dao, tri@tridao.me 2 | Albert Gu, agu@andrew.cmu.edu 3 | -------------------------------------------------------------------------------- /mamba/README.md: -------------------------------------------------------------------------------- 1 | # Mamba 2 | 3 | ![Mamba](assets/selection.png "Selective State Space") 4 | > **Mamba: Linear-Time Sequence Modeling with Selective State Spaces**\ 5 | > Albert Gu*, Tri Dao*\ 6 | > Paper: https://arxiv.org/abs/2312.00752 7 | 8 | ## About 9 | 10 | Mamba is a new state space model architecture showing promising performance on information-dense data such as language modeling, where previous subquadratic models fall short of Transformers. 11 | It is based on the line of progress on [structured state space models](https://github.com/state-spaces/s4), 12 | with an efficient hardware-aware design and implementation in the spirit of [FlashAttention](https://github.com/Dao-AILab/flash-attention). 13 | 14 | ## Installation 15 | 16 | - `pip install causal-conv1d`: an efficient implementation of a simple causal Conv1d layer used inside the Mamba block. 17 | - `pip install mamba-ssm`: the core Mamba package. 18 | 19 | It can also be built from source with `pip install .` from this repository. 
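A from-source build might look like the following sketch (a minimal example assuming a Linux machine with an NVIDIA GPU and a CUDA toolkit matching your PyTorch build; repository URL as published upstream):

```
git clone https://github.com/state-spaces/mamba.git
cd mamba
pip install causal-conv1d   # the causal Conv1d kernel used inside the Mamba block
pip install .               # compiles the selective-scan CUDA extension under csrc/
```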
20 | 21 | If `pip` complains about PyTorch versions, try passing `--no-build-isolation` to `pip`. 22 | 23 | Other requirements: 24 | - Linux 25 | - NVIDIA GPU 26 | - PyTorch 1.12+ 27 | - CUDA 11.6+ 28 | 29 | ## Usage 30 | 31 | We expose several levels of interface with the Mamba model. 32 | 33 | ### Selective SSM 34 | 35 | Mamba is based on a selective SSM layer, which is the focus of the paper (Section 3; Algorithm 2). 36 | 37 | Source: [ops/selective_scan_interface.py](mamba_ssm/ops/selective_scan_interface.py). 38 | 39 | ### Mamba Block 40 | 41 | The main module of this repository is the Mamba architecture block wrapping the selective SSM. 42 | 43 | Source: [modules/mamba_simple.py](mamba_ssm/modules/mamba_simple.py). 44 | 45 | Usage: 46 | ``` 47 | import torch 48 | from mamba_ssm import Mamba 49 | batch, length, dim = 2, 64, 16 50 | x = torch.randn(batch, length, dim).to("cuda") 51 | model = Mamba( 52 | # This module uses roughly 3 * expand * d_model^2 parameters 53 | d_model=dim, # Model dimension d_model 54 | d_state=16, # SSM state expansion factor 55 | d_conv=4, # Local convolution width 56 | expand=2, # Block expansion factor 57 | ).to("cuda") 58 | y = model(x) 59 | assert y.shape == x.shape 60 | ``` 61 | 62 | ### Mamba Language Model 63 | 64 | Finally, we provide an example of a complete language model: a deep sequence model backbone (with repeating Mamba blocks) + language model head. 65 | 66 | Source: [models/mixer_seq_simple.py](mamba_ssm/models/mixer_seq_simple.py). 67 | 68 | This is an example of how to integrate Mamba into an end-to-end neural network. 69 | This example is used in the generation scripts below. 70 | 71 | 72 | 73 | ## Pretrained Models 74 | 75 | Pretrained models are uploaded to 76 | [HuggingFace](https://huggingface.co/state-spaces): `mamba-130m`, `mamba-370m`, 77 | `mamba-790m`, `mamba-1.4b`, `mamba-2.8b`. 78 | 79 | The models will be autodownloaded by the generation script below. 80 | 81 | These models were trained on the [Pile](https://huggingface.co/datasets/EleutherAI/pile), and follow the standard model dimensions described by GPT-3 and followed by many open source models: 82 | 83 | | Parameters | Layers | Model dim. | 84 | |------------|--------|------------| 85 | | 130M | 12 | 768 | 86 | | 370M | 24 | 1024 | 87 | | 790M | 24 | 1536 | 88 | | 1.4B | 24 | 2048 | 89 | | 2.8B | 32 | 2560 | 90 | 91 | (The layer count of Mamba should be doubled, as two Mamba blocks are needed for each "layer" (MHA block + MLP block) of a Transformer.) 92 | 93 | Note: these are base models trained only for 300B tokens, without any form of downstream modification (instruction tuning, etc.). 94 | Performance is expected to be comparable or better than other architectures trained on similar data, but not to match larger or fine-tuned models. 95 | 96 | 97 | ## Evaluations 98 | 99 | To run zero-shot evaluations of models (corresponding to Table 3 of the paper), 100 | we use the 101 | [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/big-refactor) 102 | library. 103 | 104 | 1. Pull the `lm-evaluation-harness` repo by `git submodule update --init 105 | --recursive`. We use the `big-refactor` branch. 106 | 2. Install `lm-evaluation-harness`: `pip install -e 3rdparty/lm-evaluation-harness` 107 | 3. 
Run evaluation with (more documentation at the [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/big-refactor) repo): 108 | ``` 109 | python evals/lm_harness_eval.py --model mamba --model_args pretrained=state-spaces/mamba-130m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64 110 | python evals/lm_harness_eval.py --model hf --model_args pretrained=EleutherAI/pythia-160m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64 111 | ``` 112 | 113 | Note that the result of each task might differ from reported values by 0.1-0.3 due to noise in the evaluation process. 114 | 115 | ## Inference 116 | 117 | The script [benchmarks/benchmark_generation_mamba_simple.py](benchmarks/benchmark_generation_mamba_simple.py) 118 | 1. autoloads a model from the HuggingFace Hub, 119 | 2. generates completions of a user-specified prompt, 120 | 3. benchmarks the inference speed of this generation. 121 | 122 | Other configurable options include the top-p (nucleus sampling) probability, and the softmax temperature. 123 | 124 | ### Examples 125 | 126 | To test generation latency (e.g. batch size = 1) with different sampling strategies: 127 | 128 | ``` 129 | python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.5 130 | python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.5 131 | ``` 132 | 133 | To test generation throughput with random prompts (e.g. large batch size): 134 | ``` 135 | python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --batch 128 136 | python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --batch 128 137 | ``` 138 | 139 | ## Citation 140 | 141 | If you use this codebase, or otherwise found our work valuable, please cite Mamba: 142 | ``` 143 | @article{mamba, 144 | title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces}, 145 | author={Gu, Albert and Dao, Tri}, 146 | journal={arXiv preprint arXiv:2312.00752}, 147 | year={2023} 148 | } 149 | ``` 150 | -------------------------------------------------------------------------------- /mamba/assets/selection.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/mamba/assets/selection.png -------------------------------------------------------------------------------- /mamba/benchmarks/benchmark_generation_mamba_simple.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Tri Dao, Albert Gu. 
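# Overview of this script: loads either a Mamba checkpoint (via MambaLMHeadModel)
# or a HuggingFace causal LM, generates `--genlen` tokens from a user prompt (or
# from random token ids when no prompt is given), and reports the average prompt
# processing + decoding time over a few timed repeats.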
2 | 3 | import argparse 4 | import time 5 | import json 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | from einops import rearrange 11 | 12 | from transformers import AutoTokenizer, AutoModelForCausalLM 13 | 14 | from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel 15 | 16 | 17 | parser = argparse.ArgumentParser(description="Generation benchmarking") 18 | parser.add_argument("--model-name", type=str, default="state-spaces/mamba-130m") 19 | parser.add_argument("--prompt", type=str, default=None) 20 | parser.add_argument("--promptlen", type=int, default=100) 21 | parser.add_argument("--genlen", type=int, default=100) 22 | parser.add_argument("--temperature", type=float, default=1.0) 23 | parser.add_argument("--topk", type=int, default=1) 24 | parser.add_argument("--topp", type=float, default=1.0) 25 | parser.add_argument("--batch", type=int, default=1) 26 | args = parser.parse_args() 27 | 28 | repeats = 3 29 | device = "cuda" 30 | dtype = torch.float16 31 | 32 | print(f"Loading model {args.model_name}") 33 | is_mamba = args.model_name.startswith("state-spaces/mamba-") or "mamba" in args.model_name 34 | 35 | if is_mamba: 36 | tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b") 37 | model = MambaLMHeadModel.from_pretrained(args.model_name, device=device, dtype=dtype) 38 | else: 39 | tokenizer = AutoTokenizer.from_pretrained(args.model_name) 40 | model = AutoModelForCausalLM.from_pretrained(args.model_name, device_map={"": device}, torch_dtype=dtype) 41 | model.eval() 42 | print(f"Number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") 43 | 44 | torch.random.manual_seed(0) 45 | if args.prompt is None: 46 | input_ids = torch.randint(1, 1000, (args.batch, args.promptlen), dtype=torch.long, device="cuda") 47 | attn_mask = torch.ones_like(input_ids, dtype=torch.long, device="cuda") 48 | else: 49 | tokens = tokenizer(args.prompt, return_tensors="pt") 50 | input_ids = tokens.input_ids.to(device=device) 51 | attn_mask = tokens.attention_mask.to(device=device) 52 | max_length = input_ids.shape[1] + args.genlen 53 | 54 | if is_mamba: 55 | fn = lambda: model.generate( 56 | input_ids=input_ids, 57 | max_length=max_length, 58 | cg=True, 59 | return_dict_in_generate=True, 60 | output_scores=True, 61 | enable_timing=False, 62 | temperature=args.temperature, 63 | top_k=args.topk, 64 | top_p=args.topp, 65 | ) 66 | else: 67 | fn = lambda: model.generate( 68 | input_ids=input_ids, 69 | attention_mask=attn_mask, 70 | max_length=max_length, 71 | return_dict_in_generate=True, 72 | pad_token_id=tokenizer.eos_token_id, 73 | do_sample=True, 74 | temperature=args.temperature, 75 | top_k=args.topk, 76 | top_p=args.topp, 77 | ) 78 | out = fn() 79 | if args.prompt is not None: 80 | print(tokenizer.batch_decode(out.sequences.tolist())) 81 | 82 | torch.cuda.synchronize() 83 | start = time.time() 84 | for _ in range(repeats): 85 | fn() 86 | torch.cuda.synchronize() 87 | print(f"Prompt length: {len(input_ids[0])}, generation length: {len(out.sequences[0]) - len(input_ids[0])}") 88 | print(f"{args.model_name} prompt processing + decoding time: {(time.time() - start) / repeats * 1000:.0f}ms") 89 | -------------------------------------------------------------------------------- /mamba/csrc/selective_scan/selective_scan.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023,
Tri Dao. 3 | ******************************************************************************/ 4 | 5 | #pragma once 6 | 7 | //////////////////////////////////////////////////////////////////////////////////////////////////// 8 | 9 | struct SSMScanParamsBase { 10 | using index_t = uint32_t; 11 | 12 | int batch, seqlen, n_chunks; 13 | index_t a_batch_stride; 14 | index_t b_batch_stride; 15 | index_t out_batch_stride; 16 | 17 | // Common data pointers. 18 | void *__restrict__ a_ptr; 19 | void *__restrict__ b_ptr; 20 | void *__restrict__ out_ptr; 21 | void *__restrict__ x_ptr; 22 | }; 23 | 24 | //////////////////////////////////////////////////////////////////////////////////////////////////// 25 | 26 | struct SSMParamsBase { 27 | using index_t = uint32_t; 28 | 29 | int batch, dim, seqlen, dstate, n_groups, n_chunks; 30 | int dim_ngroups_ratio; 31 | bool is_variable_B; 32 | bool is_variable_C; 33 | 34 | bool delta_softplus; 35 | 36 | index_t A_d_stride; 37 | index_t A_dstate_stride; 38 | index_t B_batch_stride; 39 | index_t B_d_stride; 40 | index_t B_dstate_stride; 41 | index_t B_group_stride; 42 | index_t C_batch_stride; 43 | index_t C_d_stride; 44 | index_t C_dstate_stride; 45 | index_t C_group_stride; 46 | index_t u_batch_stride; 47 | index_t u_d_stride; 48 | index_t delta_batch_stride; 49 | index_t delta_d_stride; 50 | index_t z_batch_stride; 51 | index_t z_d_stride; 52 | index_t out_batch_stride; 53 | index_t out_d_stride; 54 | index_t out_z_batch_stride; 55 | index_t out_z_d_stride; 56 | 57 | // Common data pointers. 58 | void *__restrict__ A_ptr; 59 | void *__restrict__ B_ptr; 60 | void *__restrict__ C_ptr; 61 | void *__restrict__ D_ptr; 62 | void *__restrict__ u_ptr; 63 | void *__restrict__ delta_ptr; 64 | void *__restrict__ delta_bias_ptr; 65 | void *__restrict__ out_ptr; 66 | void *__restrict__ x_ptr; 67 | void *__restrict__ z_ptr; 68 | void *__restrict__ out_z_ptr; 69 | }; 70 | 71 | struct SSMParamsBwd: public SSMParamsBase { 72 | index_t dout_batch_stride; 73 | index_t dout_d_stride; 74 | index_t dA_d_stride; 75 | index_t dA_dstate_stride; 76 | index_t dB_batch_stride; 77 | index_t dB_group_stride; 78 | index_t dB_d_stride; 79 | index_t dB_dstate_stride; 80 | index_t dC_batch_stride; 81 | index_t dC_group_stride; 82 | index_t dC_d_stride; 83 | index_t dC_dstate_stride; 84 | index_t du_batch_stride; 85 | index_t du_d_stride; 86 | index_t dz_batch_stride; 87 | index_t dz_d_stride; 88 | index_t ddelta_batch_stride; 89 | index_t ddelta_d_stride; 90 | 91 | // Common data pointers. 92 | void *__restrict__ dout_ptr; 93 | void *__restrict__ dA_ptr; 94 | void *__restrict__ dB_ptr; 95 | void *__restrict__ dC_ptr; 96 | void *__restrict__ dD_ptr; 97 | void *__restrict__ du_ptr; 98 | void *__restrict__ dz_ptr; 99 | void *__restrict__ ddelta_ptr; 100 | void *__restrict__ ddelta_bias_ptr; 101 | }; 102 | -------------------------------------------------------------------------------- /mamba/csrc/selective_scan/selective_scan_bwd_bf16_complex.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 
3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in parallel 6 | 7 | #include "selective_scan_bwd_kernel.cuh" 8 | 9 | template void selective_scan_bwd_cuda<at::BFloat16, complex_t>(SSMParamsBwd &params, cudaStream_t stream); -------------------------------------------------------------------------------- /mamba/csrc/selective_scan/selective_scan_bwd_bf16_real.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in parallel 6 | 7 | #include "selective_scan_bwd_kernel.cuh" 8 | 9 | template void selective_scan_bwd_cuda<at::BFloat16, float>(SSMParamsBwd &params, cudaStream_t stream); -------------------------------------------------------------------------------- /mamba/csrc/selective_scan/selective_scan_bwd_fp16_complex.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in parallel 6 | 7 | #include "selective_scan_bwd_kernel.cuh" 8 | 9 | template void selective_scan_bwd_cuda<at::Half, complex_t>(SSMParamsBwd &params, cudaStream_t stream); -------------------------------------------------------------------------------- /mamba/csrc/selective_scan/selective_scan_bwd_fp16_real.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in parallel 6 | 7 | #include "selective_scan_bwd_kernel.cuh" 8 | 9 | template void selective_scan_bwd_cuda<at::Half, float>(SSMParamsBwd &params, cudaStream_t stream); -------------------------------------------------------------------------------- /mamba/csrc/selective_scan/selective_scan_bwd_fp32_complex.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in parallel 6 | 7 | #include "selective_scan_bwd_kernel.cuh" 8 | 9 | template void selective_scan_bwd_cuda<float, complex_t>(SSMParamsBwd &params, cudaStream_t stream); -------------------------------------------------------------------------------- /mamba/csrc/selective_scan/selective_scan_bwd_fp32_real.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in parallel 6 | 7 | #include "selective_scan_bwd_kernel.cuh" 8 | 9 | template void selective_scan_bwd_cuda<float, float>(SSMParamsBwd &params, cudaStream_t stream); -------------------------------------------------------------------------------- /mamba/csrc/selective_scan/selective_scan_fwd_bf16.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in parallel 6 | 7 | #include "selective_scan_fwd_kernel.cuh" 8 | 9 | template void selective_scan_fwd_cuda<at::BFloat16, float>(SSMParamsBase &params, cudaStream_t stream); 10 | template void selective_scan_fwd_cuda<at::BFloat16, complex_t>(SSMParamsBase &params, cudaStream_t stream); -------------------------------------------------------------------------------- /mamba/csrc/selective_scan/selective_scan_fwd_fp16.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in parallel 6 | 7 | #include "selective_scan_fwd_kernel.cuh" 8 | 9 | template void selective_scan_fwd_cuda<at::Half, float>(SSMParamsBase &params, cudaStream_t stream); 10 | template void selective_scan_fwd_cuda<at::Half, complex_t>(SSMParamsBase &params, cudaStream_t stream); -------------------------------------------------------------------------------- /mamba/csrc/selective_scan/selective_scan_fwd_fp32.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in parallel 6 | 7 | #include "selective_scan_fwd_kernel.cuh" 8 | 9 | template void selective_scan_fwd_cuda<float, float>(SSMParamsBase &params, cudaStream_t stream); 10 | template void selective_scan_fwd_cuda<float, complex_t>(SSMParamsBase &params, cudaStream_t stream); -------------------------------------------------------------------------------- /mamba/csrc/selective_scan/static_switch.h: -------------------------------------------------------------------------------- 1 | // Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h 2 | // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h 3 | 4 | #pragma once 5 | 6 | /// @param COND - a boolean expression to switch by 7 | /// @param CONST_NAME - a name given for the constexpr bool variable. 8 | /// @param ... - code to execute for true and false 9 | /// 10 | /// Usage: 11 | /// ``` 12 | /// BOOL_SWITCH(flag, BoolConst, [&] { 13 | /// some_function(...); 14 | /// }); 15 | /// ``` 16 | #define BOOL_SWITCH(COND, CONST_NAME, ...)
\ 17 | [&] { \ 18 | if (COND) { \ 19 | constexpr bool CONST_NAME = true; \ 20 | return __VA_ARGS__(); \ 21 | } else { \ 22 | constexpr bool CONST_NAME = false; \ 23 | return __VA_ARGS__(); \ 24 | } \ 25 | }() 26 | -------------------------------------------------------------------------------- /mamba/csrc/selective_scan/uninitialized_copy.cuh: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2011-2022, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Redistribution and use in source and binary forms, with or without 5 | * modification, are permitted provided that the following conditions are met: 6 | * * Redistributions of source code must retain the above copyright 7 | * notice, this list of conditions and the following disclaimer. 8 | * * Redistributions in binary form must reproduce the above copyright 9 | * notice, this list of conditions and the following disclaimer in the 10 | * documentation and/or other materials provided with the distribution. 11 | * * Neither the name of the NVIDIA CORPORATION nor the 12 | * names of its contributors may be used to endorse or promote products 13 | * derived from this software without specific prior written permission. 14 | * 15 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 16 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 18 | * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 19 | * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
25 | * 26 | ******************************************************************************/ 27 | 28 | #pragma once 29 | 30 | #include <cub/config.cuh> 31 | 32 | #include <cuda/std/utility> 33 | 34 | 35 | namespace detail 36 | { 37 | 38 | #if defined(_NVHPC_CUDA) 39 | template <typename T, typename U> 40 | __host__ __device__ void uninitialized_copy(T *ptr, U &&val) 41 | { 42 | // NVBug 3384810 43 | new (ptr) T(::cuda::std::forward<U>(val)); 44 | } 45 | #else 46 | template <typename T, 47 | typename U, 48 | typename ::cuda::std::enable_if< 49 | ::cuda::std::is_trivially_copyable<T>::value, 50 | int 51 | >::type = 0> 52 | __host__ __device__ void uninitialized_copy(T *ptr, U &&val) 53 | { 54 | *ptr = ::cuda::std::forward<U>(val); 55 | } 56 | 57 | template <typename T, 58 | typename U, 59 | typename ::cuda::std::enable_if< 60 | !::cuda::std::is_trivially_copyable<T>::value, 61 | int 62 | >::type = 0> 63 | __host__ __device__ void uninitialized_copy(T *ptr, U &&val) 64 | { 65 | new (ptr) T(::cuda::std::forward<U>(val)); 66 | } 67 | #endif 68 | 69 | } // namespace detail 70 | -------------------------------------------------------------------------------- /mamba/evals/lm_harness_eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import transformers 4 | from transformers import AutoTokenizer 5 | 6 | from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel 7 | 8 | from lm_eval.api.model import LM 9 | from lm_eval.models.huggingface import HFLM 10 | from lm_eval.api.registry import register_model 11 | from lm_eval.__main__ import cli_evaluate 12 | 13 | 14 | @register_model("mamba") 15 | class MambaEvalWrapper(HFLM): 16 | 17 | AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM 18 | 19 | def __init__(self, pretrained="state-spaces/mamba-2.8b", max_length=2048, batch_size=None, device="cuda", 20 | dtype=torch.float16): 21 | LM.__init__(self) 22 | self._model = MambaLMHeadModel.from_pretrained(pretrained, device=device, dtype=dtype) 23 | self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b") 24 | self.tokenizer.pad_token_id = self.tokenizer.eos_token_id 25 | self.vocab_size = self.tokenizer.vocab_size 26 | self._batch_size = batch_size if batch_size is not None else 64 27 | self._max_length = max_length 28 | self._device = torch.device(device) 29 | 30 | @property 31 | def batch_size(self): 32 | return self._batch_size 33 | 34 | def _model_generate(self, context, max_length, stop, **generation_kwargs): 35 | raise NotImplementedError() 36 | 37 | 38 | if __name__ == "__main__": 39 | cli_evaluate() 40 | -------------------------------------------------------------------------------- /mamba/mamba_ssm/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.0.1" 2 | 3 | from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn, bimamba_inner_fn 4 | from mamba_ssm.modules.mamba_simple import Mamba 5 | from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel 6 | -------------------------------------------------------------------------------- /mamba/mamba_ssm/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/mamba/mamba_ssm/models/__init__.py -------------------------------------------------------------------------------- /mamba/mamba_ssm/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/mamba/mamba_ssm/modules/__init__.py --------------------------------------------------------------------------------
/mamba/mamba_ssm/ops/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/mamba/mamba_ssm/ops/__init__.py -------------------------------------------------------------------------------- /mamba/mamba_ssm/ops/triton/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/mamba/mamba_ssm/ops/triton/__init__.py -------------------------------------------------------------------------------- /mamba/mamba_ssm/ops/triton/selective_state_update.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Tri Dao. 2 | 3 | """We want triton==2.1.0 for this 4 | """ 5 | 6 | import math 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | import triton 11 | import triton.language as tl 12 | 13 | from einops import rearrange, repeat 14 | 15 | 16 | @triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None}) 17 | @triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None}) 18 | @triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None}) 19 | @triton.heuristics({"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}) 20 | @triton.jit 21 | def _selective_scan_update_kernel( 22 | # Pointers to matrices 23 | state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr, 24 | # Matrix dimensions 25 | batch, dim, dstate, 26 | # Strides 27 | stride_state_batch, stride_state_dim, stride_state_dstate, 28 | stride_x_batch, stride_x_dim, 29 | stride_dt_batch, stride_dt_dim, 30 | stride_dt_bias_dim, 31 | stride_A_dim, stride_A_dstate, 32 | stride_B_batch, stride_B_dstate, 33 | stride_C_batch, stride_C_dstate, 34 | stride_D_dim, 35 | stride_z_batch, stride_z_dim, 36 | stride_out_batch, stride_out_dim, 37 | # Meta-parameters 38 | DT_SOFTPLUS: tl.constexpr, 39 | BLOCK_SIZE_M: tl.constexpr, 40 | HAS_DT_BIAS: tl.constexpr, 41 | HAS_D: tl.constexpr, 42 | HAS_Z: tl.constexpr, 43 | BLOCK_SIZE_DSTATE: tl.constexpr, 44 | ): 45 | pid_m = tl.program_id(axis=0) 46 | pid_b = tl.program_id(axis=1) 47 | state_ptr += pid_b * stride_state_batch 48 | x_ptr += pid_b * stride_x_batch 49 | dt_ptr += pid_b * stride_dt_batch 50 | B_ptr += pid_b * stride_B_batch 51 | C_ptr += pid_b * stride_C_batch 52 | if HAS_Z: 53 | z_ptr += pid_b * stride_z_batch 54 | out_ptr += pid_b * stride_out_batch 55 | 56 | offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) 57 | offs_n = tl.arange(0, BLOCK_SIZE_DSTATE) 58 | state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate) 59 | x_ptrs = x_ptr + offs_m * stride_x_dim 60 | dt_ptrs = dt_ptr + offs_m * stride_dt_dim 61 | if HAS_DT_BIAS: 62 | dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim 63 | A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate) 64 | B_ptrs = B_ptr + offs_n * stride_B_dstate 65 | C_ptrs = C_ptr + offs_n * stride_C_dstate 66 | if HAS_D: 67 | D_ptrs = D_ptr + offs_m * stride_D_dim 68 | if HAS_Z: 69 | z_ptrs = z_ptr + offs_m * stride_z_dim 70 | out_ptrs = out_ptr + offs_m * stride_out_dim 71 | 72 | state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0) 73 | x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) 74 | dt = 
tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) 75 | if HAS_DT_BIAS: 76 | dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) 77 | if DT_SOFTPLUS: 78 | dt = tl.log(1.0 + tl.exp(dt)) 79 | A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32) 80 | dA = tl.exp(A * dt[:, None]) 81 | B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) 82 | C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) 83 | if HAS_D: 84 | D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) 85 | if HAS_Z: 86 | z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) 87 | 88 | dB = B[None, :] * dt[:, None] 89 | state = state * dA + dB * x[:, None] 90 | tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate)) 91 | out = tl.sum(state * C[None, :], axis=1) 92 | if HAS_D: 93 | out += x * D 94 | if HAS_Z: 95 | out *= z * tl.sigmoid(z) 96 | tl.store(out_ptrs, out, mask=offs_m < dim) 97 | 98 | 99 | def selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False): 100 | """ 101 | Argument: 102 | state: (batch, dim, dstate) 103 | x: (batch, dim) 104 | dt: (batch, dim) 105 | A: (dim, dstate) 106 | B: (batch, dstate) 107 | C: (batch, dstate) 108 | D: (dim,) 109 | z: (batch, dim) 110 | dt_bias: (dim,) 111 | Return: 112 | out: (batch, dim) 113 | """ 114 | batch, dim, dstate = state.shape 115 | assert x.shape == (batch, dim) 116 | assert dt.shape == x.shape 117 | assert A.shape == (dim, dstate) 118 | assert B.shape == (batch, dstate) 119 | assert C.shape == B.shape 120 | if D is not None: 121 | assert D.shape == (dim,) 122 | if z is not None: 123 | assert z.shape == x.shape 124 | if dt_bias is not None: 125 | assert dt_bias.shape == (dim,) 126 | out = torch.empty_like(x) 127 | grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch) 128 | z_strides = ((z.stride(0), z.stride(1)) if z is not None else (0, 0)) 129 | # We don't want autotune since it will overwrite the state 130 | # We instead tune by hand. 
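    # Rationale for the hand-picked sizes below: each program touches a
    # (BLOCK_SIZE_M, dstate) tile of the state, so BLOCK_SIZE_M is halved as
    # dstate grows (32 -> 16 -> 8 -> 4) to bound per-program work, and
    # num_warps is raised to 8 only for the largest states (dstate > 128).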
131 | BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16 132 | else ((16, 4) if dstate <= 32 else 133 | ((8, 4) if dstate <= 64 else 134 | ((4, 4) if dstate <= 128 else 135 | ((4, 8)))))) 136 | with torch.cuda.device(x.device.index): 137 | _selective_scan_update_kernel[grid]( 138 | state, x, dt, dt_bias, A, B, C, D, z, out, 139 | batch, dim, dstate, 140 | state.stride(0), state.stride(1), state.stride(2), 141 | x.stride(0), x.stride(1), 142 | dt.stride(0), dt.stride(1), 143 | dt_bias.stride(0) if dt_bias is not None else 0, 144 | A.stride(0), A.stride(1), 145 | B.stride(0), B.stride(1), 146 | C.stride(0), C.stride(1), 147 | D.stride(0) if D is not None else 0, 148 | z_strides[0], z_strides[1], 149 | out.stride(0), out.stride(1), 150 | dt_softplus, 151 | BLOCK_SIZE_M, 152 | num_warps=num_warps, 153 | ) 154 | return out 155 | 156 | 157 | def selective_state_update_ref(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False): 158 | """ 159 | Argument: 160 | state: (batch, dim, dstate) 161 | x: (batch, dim) 162 | dt: (batch, dim) 163 | A: (dim, dstate) 164 | B: (batch, dstate) 165 | C: (batch, dstate) 166 | D: (dim,) 167 | z: (batch, dim) 168 | dt_bias: (dim,) 169 | Return: 170 | out: (batch, dim) 171 | """ 172 | batch, dim, dstate = state.shape 173 | assert x.shape == (batch, dim) 174 | assert dt.shape == x.shape 175 | assert A.shape == (dim, dstate) 176 | assert B.shape == (batch, dstate) 177 | assert C.shape == B.shape 178 | if D is not None: 179 | assert D.shape == (dim,) 180 | if z is not None: 181 | assert z.shape == x.shape 182 | if dt_bias is not None: 183 | assert dt_bias.shape == (dim,) 184 | dt = dt + dt_bias if dt_bias is not None else dt 185 | dt = F.softplus(dt) if dt_softplus else dt 186 | dA = torch.exp(rearrange(dt, "b d -> b d 1") * A) # (batch, dim, dstate) 187 | dB = rearrange(dt, "b d -> b d 1") * rearrange(B, "b n -> b 1 n") # (batch, dim, dstate) 188 | state.copy_(state * dA + dB * rearrange(x, "b d -> b d 1")) # (batch, dim, dstate) 189 | out = torch.einsum("bdn,bn->bd", state.to(C.dtype), C) 190 | if D is not None: 191 | out += (x * D).to(out.dtype) 192 | return (out if z is None else out * F.silu(z)).to(x.dtype) 193 | -------------------------------------------------------------------------------- /mamba/mamba_ssm/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/mamba/mamba_ssm/utils/__init__.py -------------------------------------------------------------------------------- /mamba/mamba_ssm/utils/hf.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import torch 4 | 5 | from transformers.utils import WEIGHTS_NAME, CONFIG_NAME 6 | from transformers.utils.hub import cached_file 7 | 8 | 9 | def load_config_hf(model_name): 10 | resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False) 11 | return json.load(open(resolved_archive_file)) 12 | 13 | 14 | def load_state_dict_hf(model_name, device=None, dtype=None): 15 | # If not fp32, then we don't want to load directly to the GPU 16 | mapped_device = "cpu" if dtype not in [torch.float32, None] else device 17 | resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False) 18 | state_dict = torch.load(resolved_archive_file, map_location=mapped_device) 19 | # Convert dtype before moving to GPU to save memory 20 | if dtype is not None: 21 | state_dict
= {k: v.to(dtype=dtype) for k, v in state_dict.items()} 22 | state_dict = {k: v.to(device=device) for k, v in state_dict.items()} 23 | return state_dict 24 | -------------------------------------------------------------------------------- /mamba/test_mamba_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from mamba_ssm import Mamba 3 | 4 | batch, length, dim = 2, 64, 768 5 | x = torch.randn(batch, length, dim).to("cuda") 6 | model = Mamba( 7 | # This module uses roughly 3 * expand * d_model^2 parameters 8 | d_model=dim, # Model dimension d_model 9 | d_state=16, # SSM state expansion factor # 64 10 | d_conv=4, # Local convolution width 11 | expand=2, # Block expansion factor 12 | use_fast_path=False, 13 | ).to("cuda") 14 | y = model(x) 15 | assert y.shape == x.shape 16 | -------------------------------------------------------------------------------- /mamba/tests/ops/triton/test_selective_state_update.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2023, Tri Dao. 2 | 3 | import math 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | import pytest 8 | 9 | from einops import rearrange 10 | 11 | from mamba_ssm.ops.triton.selective_state_update import selective_state_update, selective_state_update_ref 12 | 13 | 14 | @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) 15 | # @pytest.mark.parametrize('itype', [torch.float16]) 16 | @pytest.mark.parametrize("has_z", [False, True]) 17 | # @pytest.mark.parametrize('has_z', [True]) 18 | @pytest.mark.parametrize("dstate", [16, 32, 64]) 19 | # @pytest.mark.parametrize("dstate", [16]) 20 | @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) 21 | # @pytest.mark.parametrize("dim", [2048]) 22 | def test_selective_state_update(dim, dstate, has_z, itype): 23 | device = "cuda" 24 | rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2) 25 | if itype == torch.bfloat16: 26 | rtol, atol = 1e-2, 5e-2 27 | # set seed 28 | torch.random.manual_seed(0) 29 | batch_size = 2 30 | state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device) 31 | x = torch.randn(batch_size, dim, device=device, dtype=itype) 32 | dt = torch.randn(batch_size, dim, device=device, dtype=itype) 33 | dt_bias = torch.rand(dim, device=device) - 4.0 34 | A = -torch.rand(dim, dstate, device=device) - 1.0 35 | B = torch.randn(batch_size, dstate, device=device) 36 | C = torch.randn(batch_size, dstate, device=device) 37 | D = torch.randn(dim, device=device) 38 | if has_z: 39 | z = torch.randn_like(x) 40 | else: 41 | z = None 42 | state_ref = state.detach().clone() 43 | out = selective_state_update(state, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True) 44 | out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True) 45 | 46 | print(f"Output max diff: {(out - out_ref).abs().max().item()}") 47 | print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") 48 | assert torch.allclose(state, state_ref, rtol=rtol, atol=atol) 49 | assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) 50 | --------------------------------------------------------------------------------