├── LICENSE
├── README.md
├── causal-conv1d
│   ├── AUTHORS
│   ├── LICENSE
│   ├── README.md
│   ├── causal_conv1d
│   │   ├── __init__.py
│   │   └── causal_conv1d_interface.py
│   ├── csrc
│   │   ├── causal_conv1d.cpp
│   │   ├── causal_conv1d.h
│   │   ├── causal_conv1d_bwd.cu
│   │   ├── causal_conv1d_common.h
│   │   ├── causal_conv1d_fwd.cu
│   │   ├── causal_conv1d_update.cu
│   │   └── static_switch.h
│   ├── setup.py
│   └── tests
│       └── test_causal_conv1d.py
├── code
│   ├── augmentations
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   └── ctaugment.cpython-38.pyc
│   │   └── ctaugment.py
│   ├── config.py
│   ├── configs
│   │   ├── swin_tiny_patch4_window7_224_lite.yaml
│   │   └── vmamba_tiny.yaml
│   ├── dataloaders
│   │   ├── __pycache__
│   │   │   ├── dataset.cpython-310.pyc
│   │   │   ├── dataset.cpython-311.pyc
│   │   │   ├── dataset.cpython-38.pyc
│   │   │   ├── dataset_s2l.cpython-38.pyc
│   │   │   ├── dataset_semi.cpython-38.pyc
│   │   │   ├── utils.cpython-310.pyc
│   │   │   ├── utils.cpython-311.pyc
│   │   │   └── utils.cpython-38.pyc
│   │   ├── acdc_data_processing.py
│   │   ├── acdc_pseudo_label_random_walker.py
│   │   ├── dataset.py
│   │   ├── dataset_s2l.py
│   │   ├── dataset_semi.py
│   │   └── utils.py
│   ├── networks
│   │   ├── VoxResNet.py
│   │   ├── __pycache__
│   │   │   ├── attention.cpython-310.pyc
│   │   │   ├── attention.cpython-38.pyc
│   │   │   ├── config.cpython-310.pyc
│   │   │   ├── config.cpython-38.pyc
│   │   │   ├── discriminator.cpython-38.pyc
│   │   │   ├── efficient_encoder.cpython-310.pyc
│   │   │   ├── efficient_encoder.cpython-38.pyc
│   │   │   ├── efficientunet.cpython-310.pyc
│   │   │   ├── efficientunet.cpython-38.pyc
│   │   │   ├── enet.cpython-310.pyc
│   │   │   ├── enet.cpython-38.pyc
│   │   │   ├── mamba_sys.cpython-310.pyc
│   │   │   ├── net_factory.cpython-310.pyc
│   │   │   ├── net_factory.cpython-38.pyc
│   │   │   ├── neural_network.cpython-310.pyc
│   │   │   ├── neural_network.cpython-38.pyc
│   │   │   ├── nnunet.cpython-310.pyc
│   │   │   ├── nnunet.cpython-38.pyc
│   │   │   ├── pnet.cpython-310.pyc
│   │   │   ├── pnet.cpython-38.pyc
│   │   │   ├── swin_transformer_unet_skip_expand_decoder_sys.cpython-310.pyc
│   │   │   ├── swin_transformer_unet_skip_expand_decoder_sys.cpython-38.pyc
│   │   │   ├── unet.cpython-310.pyc
│   │   │   ├── unet.cpython-38.pyc
│   │   │   ├── vision_mamba.cpython-310.pyc
│   │   │   ├── vision_transformer.cpython-310.pyc
│   │   │   └── vision_transformer.cpython-38.pyc
│   │   ├── attention.py
│   │   ├── attention_unet.py
│   │   ├── config.py
│   │   ├── discriminator.py
│   │   ├── efficient_encoder.py
│   │   ├── efficientunet.py
│   │   ├── encoder_tool.py
│   │   ├── enet.py
│   │   ├── grid_attention_layer.py
│   │   ├── mamba_sys.py
│   │   ├── net_factory.py
│   │   ├── net_factory_3d.py
│   │   ├── networks_other.py
│   │   ├── neural_network.py
│   │   ├── nnunet.py
│   │   ├── pnet.py
│   │   ├── segmamba.py
│   │   ├── swin_transformer_unet_skip_expand_decoder_sys.py
│   │   ├── unet.py
│   │   ├── unet_3D.py
│   │   ├── utils.py
│   │   ├── vision_mamba.py
│   │   ├── vision_transformer.py
│   │   └── vnet.py
│   ├── pretrained_ckpt
│   │   └── readme.txt
│   ├── scribbles_generator.py
│   ├── test_2D.py
│   ├── test_2D_ViT.py
│   ├── train_weak_mamba_unet.py
│   ├── train_weakly_supervised_pCE_2D.py
│   ├── train_weakly_supervised_pCE_2D_ViT.py
│   ├── train_weakly_supervised_ustm_2D_ViT.py
│   ├── utils
│   │   ├── __pycache__
│   │   │   ├── gate_crf_loss.cpython-310.pyc
│   │   │   ├── gate_crf_loss.cpython-38.pyc
│   │   │   ├── losses.cpython-310.pyc
│   │   │   ├── losses.cpython-38.pyc
│   │   │   ├── metrics.cpython-310.pyc
│   │   │   ├── metrics.cpython-38.pyc
│   │   │   ├── ramps.cpython-310.pyc
│   │   │   └── ramps.cpython-38.pyc
│   │   ├── gate_crf_loss.py
│   │   ├── losses.py
│   │   ├── metrics.py
│   │   ├── ramps.py
│   │   └── util.py
│   └── val_2D.py
├── data
│   └── readme.txt
├── img
│   ├── results.png
│   ├── wslframework.png
│   └── wslintro.png
└── mamba
    ├── .gitmodules
    ├── AUTHORS
    ├── LICENSE
    ├── README.md
    ├── assets
    │   └── selection.png
    ├── benchmarks
    │   └── benchmark_generation_mamba_simple.py
    ├── csrc
    │   └── selective_scan
    │       ├── reverse_scan.cuh
    │       ├── selective_scan.cpp
    │       ├── selective_scan.h
    │       ├── selective_scan_bwd_bf16_complex.cu
    │       ├── selective_scan_bwd_bf16_real.cu
    │       ├── selective_scan_bwd_fp16_complex.cu
    │       ├── selective_scan_bwd_fp16_real.cu
    │       ├── selective_scan_bwd_fp32_complex.cu
    │       ├── selective_scan_bwd_fp32_real.cu
    │       ├── selective_scan_bwd_kernel.cuh
    │       ├── selective_scan_common.h
    │       ├── selective_scan_fwd_bf16.cu
    │       ├── selective_scan_fwd_fp16.cu
    │       ├── selective_scan_fwd_fp32.cu
    │       ├── selective_scan_fwd_kernel.cuh
    │       ├── static_switch.h
    │       └── uninitialized_copy.cuh
    ├── evals
    │   └── lm_harness_eval.py
    ├── mamba_ssm
    │   ├── __init__.py
    │   ├── models
    │   │   ├── __init__.py
    │   │   └── mixer_seq_simple.py
    │   ├── modules
    │   │   ├── __init__.py
    │   │   └── mamba_simple.py
    │   ├── ops
    │   │   ├── __init__.py
    │   │   ├── selective_scan_interface.py
    │   │   └── triton
    │   │       ├── __init__.py
    │   │       ├── layernorm.py
    │   │       └── selective_state_update.py
    │   └── utils
    │       ├── __init__.py
    │       ├── generation.py
    │       └── hf.py
    ├── setup.py
    ├── test_mamba_module.py
    └── tests
        └── ops
            ├── test_selective_scan.py
            └── triton
                └── test_selective_state_update.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 ziyangwang007
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Weak-Mamba-UNet: Visual Mamba Makes CNN and ViT Work Better for Scribble-based Medical Image Segmentation
3 |
4 | [arXiv:2402.10887](https://arxiv.org/abs/2402.10887)
5 |
6 |
7 |
8 | > This repo provides an implementation of the training and inference pipeline for [Weak-Mamba-UNet](https://arxiv.org/abs/2402.10887).
9 |
10 |
11 | ## Contents
12 | - [Graphical Abstract](#Graphical-Abstract)
13 | - [Results](#Results)
14 | - [Requirements](#Requirements)
15 | - [Usage](#Usage)
16 | - [Reference](#Reference)
17 | - [Contact](#Contact)
18 |
19 |
20 |
21 |
22 | ## Graphical Abstract
23 |
24 | The introduction of Scribble Annotation
25 |
26 | ![Scribble annotation](img/wslintro.png)
27 |
28 | The proposed Framework
29 |
30 | ![The proposed framework](img/wslframework.png)
31 |
32 |
33 | ## Results
34 |
35 | ![Results](img/results.png)
36 |
37 |
38 |
39 |
40 | ## Requirements
41 | * PyTorch, MONAI
42 | * Some basic Python packages: TorchIO, NumPy, scikit-image, SimpleITK, SciPy, MedPy, nibabel, tqdm, etc.
43 |
44 | ```shell
45 | cd causal-conv1d
46 |
47 | python setup.py install
48 | ```
49 |
50 | ```shell
51 | cd mamba
52 |
53 | python setup.py install
54 | ```
55 |
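If both builds succeed, a quick smoke test on a CUDA machine should print two tensor shapes. This is a minimal sketch, not part of the original README; the shapes are arbitrary and the `Mamba` import path refers to the vendored `mamba_ssm` package:

```python
# Editor's sketch: verify the two CUDA extensions import and run.
import torch
from causal_conv1d import causal_conv1d_fn
from mamba_ssm.modules.mamba_simple import Mamba

x = torch.randn(2, 16, 64, device="cuda")      # (batch, dim, seqlen)
w = torch.randn(16, 4, device="cuda")          # (dim, width)
print(causal_conv1d_fn(x, w).shape)            # torch.Size([2, 16, 64])

seq = torch.randn(2, 64, 16, device="cuda")    # (batch, seqlen, d_model)
print(Mamba(d_model=16).cuda()(seq).shape)     # torch.Size([2, 64, 16])
```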
56 |
57 |
58 | ## Usage
59 |
60 | 1. Clone the repo:
61 | ```shell
62 | git clone https://github.com/ziyangwang007/Weak-Mamba-UNet.git
63 | cd Weak-Mamba-UNet
64 | ```
65 |
66 | 2. Download Pretrained Model
67 |
68 | Download the pretrained weights through [Google Drive](https://drive.google.com/file/d/14RzbbBDjbKbgr0ordKlWbb69EFkHuplr/view?usp=sharing) for SwinUNet and [Google Drive](https://drive.google.com/file/d/1uUPsr7XeqayCxlspqBHbg5zIWx0JYtSX/view?usp=sharing) for Mamba-UNet, and save them in `../code/pretrained_ckpt`.
69 |
70 | 3. Download Dataset
71 |
72 | Download the ACDC dataset for weakly-supervised learning through [Google Drive](https://drive.google.com/file/d/1XR_Id0wdvXY9QeKtdOdgJHKVJ-nVr2j1/view?usp=sharing) or [Baidu Netdisk](https://pan.baidu.com/s/1dHkp9daqE3kLEbAP6zl7Jw) with passcode 'rwv2', and save it in the `../data/ACDC` folder.
73 |
74 |
75 | 4. Train (run all training commands below from the `code` folder)
76 |
77 | ```shell
78 | cd code
79 | ```
80 |
81 | 5. Train 2D UNet with pCE
82 |
83 | ```shell
84 | python train_weakly_supervised_pCE_2D.py
85 | ```
86 |
87 | 6. Train 2D SwinUNet with pCE
88 | ```shell
89 | python train_weakly_supervised_pCE_2D_ViT.py
90 | ```
91 |
92 | 7. Train 2D SwinUNet with MT and pCE
93 | ```shell
94 | python train_weakly_supervised_ustm_2D_ViT.py
95 | ```
96 |
97 | 8. Train 2D Weak-Mamba-UNet with pCE
98 | ```shell
99 | python train_weak_mamba_unet.py
100 | ```
101 |
102 | 9. Test
103 |
104 | Test CNN-based model
105 | ```shell
106 | python test_2D.py --root_path ../data/XXX --exp ACDC/XXX
107 | ```
108 | Test ViT/Mamba-based model
109 | ```shell
110 | python test_2D_ViT.py --root_path ../data/XXX --exp ACDC/XXX
111 | ```
112 |
113 |
114 | ## Reference
115 | Wang, Ziyang, et al. "Mamba-unet: Unet-like pure visual mamba for medical image segmentation." arXiv preprint arXiv:2402.05079 (2024).
116 |
117 | Wang, Ziyang, and Chao Ma. "Weak-Mamba-UNet: Visual Mamba Makes CNN and ViT Work Better for Scribble-based Medical Image Segmentation." arXiv preprint arXiv:2402.10887 (2024).
118 |
119 |
120 | ```bibtex
121 | @article{wang2024mamba,
122 | title={Mamba-unet: Unet-like pure visual mamba for medical image segmentation},
123 | author={Wang, Ziyang and Zheng, Jian-Qing and Zhang, Yichi and Cui, Ge and Li, Lei},
124 | journal={arXiv preprint arXiv:2402.05079},
125 | year={2024}
126 | }
127 |
128 | @article{wang2024weakmamba,
129 | title={Weak-Mamba-UNet: Visual Mamba Makes CNN and ViT Work Better for Scribble-based Medical Image Segmentation},
130 | author={Wang, Ziyang and Ma, Chao},
131 | journal={arXiv preprint arXiv:2402.10887},
132 | year={2024}
133 | }
134 | ```
135 | ## Contact
136 |
137 | ziyang [dot] wang17 [at] gmail [dot] com
138 |
--------------------------------------------------------------------------------
/causal-conv1d/AUTHORS:
--------------------------------------------------------------------------------
1 | Tri Dao, tri@tridao.me
2 |
--------------------------------------------------------------------------------
/causal-conv1d/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file.
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | * Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | * Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | * Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/causal-conv1d/README.md:
--------------------------------------------------------------------------------
1 | # Causal depthwise conv1d in CUDA with a PyTorch interface
2 |
--------------------------------------------------------------------------------
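"Causal depthwise" means each channel is convolved with its own length-`width` filter (`groups=dim`) and output position t sees only inputs at positions <= t, via left padding of `width - 1`. A minimal plain-PyTorch sketch of the same computation (this mirrors the `causal_conv1d_ref` function in the interface file below; the shapes are arbitrary):

```python
# Editor's sketch of what "causal depthwise conv1d" computes, in plain PyTorch.
import torch
import torch.nn.functional as F

batch, dim, seqlen, width = 2, 8, 16, 4
x = torch.randn(batch, dim, seqlen)
weight = torch.randn(dim, width)   # one filter per channel

# Pad by width-1 and crop back to seqlen => output t depends only on inputs <= t.
out = F.conv1d(x, weight.unsqueeze(1), padding=width - 1, groups=dim)[..., :seqlen]
print(out.shape)                   # torch.Size([2, 8, 16])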
/causal-conv1d/causal_conv1d/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "1.0.0"
2 |
3 | from causal_conv1d.causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
4 |
--------------------------------------------------------------------------------
/causal-conv1d/causal_conv1d/causal_conv1d_interface.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, Tri Dao.
2 |
3 | import torch
4 | import torch.nn.functional as F
5 |
6 |
7 | import causal_conv1d_cuda
8 |
9 |
10 | class CausalConv1dFn(torch.autograd.Function):
11 | @staticmethod
12 | def forward(ctx, x, weight, bias=None, activation=None):
13 | if activation not in [None, "silu", "swish"]:
14 | raise NotImplementedError("activation must be None, silu, or swish")
15 | if x.stride(2) != 1 and x.stride(1) != 1:
16 | x = x.contiguous()
17 | bias = bias.contiguous() if bias is not None else None
18 | ctx.save_for_backward(x, weight, bias)
19 | ctx.activation = activation in ["silu", "swish"]
20 | out = causal_conv1d_cuda.causal_conv1d_fwd(x, weight, bias, ctx.activation)
21 | return out
22 |
23 | @staticmethod
24 | def backward(ctx, dout):
25 | x, weight, bias = ctx.saved_tensors
26 | if dout.stride(2) != 1 and dout.stride(1) != 1:
27 | dout = dout.contiguous()
28 | # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
29 | # backward of conv1d with the backward of chunk).
30 | # Here we just pass in None and dx will be allocated in the C++ code.
31 | dx, dweight, dbias = causal_conv1d_cuda.causal_conv1d_bwd(
32 | x, weight, bias, dout, None, ctx.activation
33 | )
34 | return dx, dweight, dbias if bias is not None else None, None
35 |
36 |
37 | def causal_conv1d_fn(x, weight, bias=None, activation=None):
38 | """
39 | x: (batch, dim, seqlen)
40 | weight: (dim, width)
41 | bias: (dim,)
42 | activation: either None or "silu" or "swish"
43 |
44 | out: (batch, dim, seqlen)
45 | """
46 | return CausalConv1dFn.apply(x, weight, bias, activation)
47 |
48 |
49 | def causal_conv1d_ref(x, weight, bias=None, activation=None):
50 | """
51 | x: (batch, dim, seqlen)
52 | weight: (dim, width)
53 | bias: (dim,)
54 |
55 | out: (batch, dim, seqlen)
56 | """
57 | if activation not in [None, "silu", "swish"]:
58 | raise NotImplementedError("activation must be None, silu, or swish")
59 | dtype_in = x.dtype
60 | x = x.to(weight.dtype)
61 | seqlen = x.shape[-1]
62 | dim, width = weight.shape
63 | out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
64 | out = out[..., :seqlen]
65 | return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
66 |
67 |
68 | def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None):
69 | """
70 | x: (batch, dim)
71 | conv_state: (batch, dim, width)
72 | weight: (dim, width)
73 | bias: (dim,)
74 |
75 | out: (batch, dim)
76 | """
77 | if activation not in [None, "silu", "swish"]:
78 | raise NotImplementedError("activation must be None, silu, or swish")
79 | activation = activation in ["silu", "swish"]
80 | return causal_conv1d_cuda.causal_conv1d_update(x, conv_state, weight, bias, activation)
81 |
82 |
83 | def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None):
84 | """
85 | x: (batch, dim)
86 | conv_state: (batch, dim, width)
87 | weight: (dim, width)
88 | bias: (dim,)
89 |
90 | out: (batch, dim)
91 | """
92 | if activation not in [None, "silu", "swish"]:
93 | raise NotImplementedError("activation must be None, silu, or swish")
94 | dtype_in = x.dtype
95 | batch, dim = x.shape
96 | width = weight.shape[1]
97 | assert conv_state.shape == (batch, dim, width)
98 | assert weight.shape == (dim, width)
99 | conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
100 | conv_state[:, :, -1] = x
101 | out = torch.sum(conv_state * weight, dim=-1) # (B D)
102 | if bias is not None:
103 | out += bias
104 | return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
105 |
--------------------------------------------------------------------------------
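As the docstrings above state, `causal_conv1d_update` consumes one timestep while rolling the `(batch, dim, width)` `conv_state`; stepping through a sequence one element at a time should therefore reproduce the full causal convolution. A minimal consistency sketch using the reference implementations (editor's addition, zero-initialized state assumed):

```python
# Editor's sketch: step-by-step decoding matches the full causal conv.
import torch
from causal_conv1d.causal_conv1d_interface import (
    causal_conv1d_ref,
    causal_conv1d_update_ref,
)

batch, dim, seqlen, width = 2, 8, 16, 4
x = torch.randn(batch, dim, seqlen)
weight = torch.randn(dim, width)

full = causal_conv1d_ref(x, weight)                    # whole sequence at once
state = torch.zeros(batch, dim, width)                 # empty decoding state
steps = [causal_conv1d_update_ref(x[..., t], state, weight) for t in range(seqlen)]
assert torch.allclose(full, torch.stack(steps, dim=-1), atol=1e-5)
```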
/causal-conv1d/csrc/causal_conv1d.h:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | #pragma once
6 |
7 | ////////////////////////////////////////////////////////////////////////////////////////////////////
8 |
9 | struct ConvParamsBase {
10 | using index_t = uint32_t;
11 |
12 | int batch, dim, seqlen, width;
13 | bool silu_activation;
14 |
15 | index_t x_batch_stride;
16 | index_t x_c_stride;
17 | index_t x_l_stride;
18 | index_t weight_c_stride;
19 | index_t weight_width_stride;
20 | index_t out_batch_stride;
21 | index_t out_c_stride;
22 | index_t out_l_stride;
23 |
24 | index_t conv_state_batch_stride;
25 | index_t conv_state_c_stride;
26 | index_t conv_state_l_stride;
27 |
28 | // Common data pointers.
29 | void *__restrict__ x_ptr;
30 | void *__restrict__ weight_ptr;
31 | void *__restrict__ bias_ptr;
32 | void *__restrict__ out_ptr;
33 |
34 | void *__restrict__ conv_state_ptr;
35 | };
36 |
37 | struct ConvParamsBwd: public ConvParamsBase {
38 | index_t dx_batch_stride;
39 | index_t dx_c_stride;
40 | index_t dx_l_stride;
41 | index_t dweight_c_stride;
42 | index_t dweight_width_stride;
43 | index_t dout_batch_stride;
44 | index_t dout_c_stride;
45 | index_t dout_l_stride;
46 |
47 | // Common data pointers.
48 | void *__restrict__ dx_ptr;
49 | void *__restrict__ dweight_ptr;
50 | void *__restrict__ dbias_ptr;
51 | void *__restrict__ dout_ptr;
52 | };
53 |
54 |
--------------------------------------------------------------------------------
/causal-conv1d/csrc/causal_conv1d_common.h:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | #pragma once
6 |
7 | #include <cuda_bf16.h>
8 | #include <cuda_fp16.h>
9 |
10 | ////////////////////////////////////////////////////////////////////////////////////////////////////
11 |
12 | template<int BYTES> struct BytesToType {};
13 |
14 | template<> struct BytesToType<16> {
15 | using Type = uint4;
16 | static_assert(sizeof(Type) == 16);
17 | };
18 |
19 | template<> struct BytesToType<8> {
20 | using Type = uint64_t;
21 | static_assert(sizeof(Type) == 8);
22 | };
23 |
24 | template<> struct BytesToType<4> {
25 | using Type = uint32_t;
26 | static_assert(sizeof(Type) == 4);
27 | };
28 |
29 | template<> struct BytesToType<2> {
30 | using Type = uint16_t;
31 | static_assert(sizeof(Type) == 2);
32 | };
33 |
34 | template<> struct BytesToType<1> {
35 | using Type = uint8_t;
36 | static_assert(sizeof(Type) == 1);
37 | };
38 |
39 | ////////////////////////////////////////////////////////////////////////////////////////////////////
40 |
41 | template<typename T>
42 | struct SumOp {
43 | __device__ inline T operator()(T const & x, T const & y) { return x + y; }
44 | };
45 |
46 | template<int THREADS>
47 | struct Allreduce {
48 | static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
49 | template<typename T, typename Operator>
50 | static __device__ inline T run(T x, Operator &op) {
51 | constexpr int OFFSET = THREADS / 2;
52 | x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
53 | return Allreduce<OFFSET>::run(x, op);
54 | }
55 | };
56 |
57 | template<>
58 | struct Allreduce<2> {
59 | template<typename T, typename Operator>
60 | static __device__ inline T run(T x, Operator &op) {
61 | x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
62 | return x;
63 | }
64 | };
65 |
--------------------------------------------------------------------------------
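The `Allreduce` struct above implements a warp-level butterfly reduction: each round, every lane combines its value with the lane whose id differs in one bit (`__shfl_xor_sync` with a halving offset), so after log2(THREADS) rounds every lane holds the full sum. A NumPy simulation of that shuffle pattern (editor's sketch, illustrative only):

```python
# Editor's sketch (NumPy): the butterfly pattern behind Allreduce.
import numpy as np

vals = np.arange(32, dtype=np.float64)           # one value per "lane"
offset = 16
while offset >= 1:
    vals = vals + vals[np.arange(32) ^ offset]   # lane i reads lane i ^ offset
    offset //= 2
assert np.all(vals == np.arange(32).sum())       # every lane now holds 496
```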
/causal-conv1d/csrc/causal_conv1d_update.cu:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | #include <c10/util/BFloat16.h>
6 | #include <c10/util/Half.h>
7 | #include <c10/cuda/CUDAException.h>  // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
8 |
9 | #include <cub/block/block_load.cuh>
10 | #include <cub/block/block_store.cuh>
11 |
12 | #include "causal_conv1d.h"
13 | #include "causal_conv1d_common.h"
14 | #include "static_switch.h"
15 |
16 | template<int kNThreads_, int kWidth_, typename input_t_, typename weight_t_>
17 | struct Causal_conv1d_update_kernel_traits {
18 | using input_t = input_t_;
19 | using weight_t = weight_t_;
20 | static constexpr int kNThreads = kNThreads_;
21 | static constexpr int kWidth = kWidth_;
22 | static constexpr int kNBytes = sizeof(input_t);
23 | static_assert(kNBytes == 2 || kNBytes == 4);
24 | };
25 |
26 | template<typename Ktraits>
27 | __global__ __launch_bounds__(Ktraits::kNThreads)
28 | void causal_conv1d_update_kernel(ConvParamsBase params) {
29 | constexpr int kWidth = Ktraits::kWidth;
30 | constexpr int kNThreads = Ktraits::kNThreads;
31 | using input_t = typename Ktraits::input_t;
32 | using weight_t = typename Ktraits::weight_t;
33 |
34 | const int tidx = threadIdx.x;
35 | const int batch_id = blockIdx.x;
36 | const int channel_id = blockIdx.y * kNThreads + tidx;
37 | input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
38 | + channel_id * params.x_c_stride;
39 | input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr) + batch_id * params.conv_state_batch_stride
40 | + channel_id * params.conv_state_c_stride;
41 | weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
42 | input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
43 | + channel_id * params.out_c_stride;
44 | float bias_val = params.bias_ptr == nullptr || channel_id >= params.dim ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
45 |
46 | float weight_vals[kWidth] = {0};
47 | if (channel_id < params.dim) {
48 | #pragma unroll
49 | for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
50 | }
51 |
52 | float x_vals[kWidth] = {0};
53 | if (channel_id < params.dim) {
54 | #pragma unroll
55 | for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = float(conv_state[(i + 1) * params.conv_state_l_stride]); }
56 | x_vals[kWidth - 1] = float(x[0]);
57 | #pragma unroll
58 | for (int i = 0; i < kWidth; ++i) { conv_state[i * params.conv_state_l_stride] = input_t(x_vals[i]); }
59 | }
60 |
61 | float out_val = bias_val;
62 | #pragma unroll
63 | for (int i = 0; i < kWidth; ++i) { out_val += weight_vals[i] * x_vals[i]; }
64 | if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); }
65 | if (channel_id < params.dim) { out[0] = input_t(out_val); }
66 | }
67 |
68 | template<int kNThreads, int kWidth, typename input_t, typename weight_t>
69 | void causal_conv1d_update_launch(ConvParamsBase &params, cudaStream_t stream) {
70 | using Ktraits = Causal_conv1d_update_kernel_traits<kNThreads, kWidth, input_t, weight_t>;
71 | dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads);
72 | auto kernel = &causal_conv1d_update_kernel<Ktraits>;
73 | kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
74 | C10_CUDA_KERNEL_LAUNCH_CHECK();
75 | }
76 |
77 | template<typename input_t, typename weight_t>
78 | void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream) {
79 | if (params.width == 2) {
80 | causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream);
81 | } else if (params.width == 3) {
82 | causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream);
83 | } else if (params.width == 4) {
84 | causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream);
85 | }
86 | }
87 |
88 | template void causal_conv1d_update_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
89 | template void causal_conv1d_update_cuda<at::Half, float>(ConvParamsBase &params, cudaStream_t stream);
90 | template void causal_conv1d_update_cuda<at::BFloat16, float>(ConvParamsBase &params, cudaStream_t stream);
91 | template void causal_conv1d_update_cuda<float, at::Half>(ConvParamsBase &params, cudaStream_t stream);
92 | template void causal_conv1d_update_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
93 | template void causal_conv1d_update_cuda<at::BFloat16, at::Half>(ConvParamsBase &params, cudaStream_t stream);
94 | template void causal_conv1d_update_cuda<float, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
95 | template void causal_conv1d_update_cuda<at::Half, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
96 | template void causal_conv1d_update_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
--------------------------------------------------------------------------------
/causal-conv1d/csrc/static_switch.h:
--------------------------------------------------------------------------------
1 | // Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
2 | // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
3 |
4 | #pragma once
5 |
6 | /// @param COND - a boolean expression to switch by
7 | /// @param CONST_NAME - a name given for the constexpr bool variable.
8 | /// @param ... - code to execute for true and false
9 | ///
10 | /// Usage:
11 | /// ```
12 | /// BOOL_SWITCH(flag, BoolConst, [&] {
13 | /// some_function(...);
14 | /// });
15 | /// ```
16 | #define BOOL_SWITCH(COND, CONST_NAME, ...) \
17 | [&] { \
18 | if (COND) { \
19 | static constexpr bool CONST_NAME = true; \
20 | return __VA_ARGS__(); \
21 | } else { \
22 | static constexpr bool CONST_NAME = false; \
23 | return __VA_ARGS__(); \
24 | } \
25 | }()
26 |
--------------------------------------------------------------------------------
/causal-conv1d/tests/test_causal_conv1d.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2023, Tri Dao.
2 |
3 | import math
4 |
5 | import torch
6 | import pytest
7 |
8 | from einops import rearrange
9 |
10 | from causal_conv1d.causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_ref
11 | from causal_conv1d.causal_conv1d_interface import causal_conv1d_update, causal_conv1d_update_ref
12 |
13 |
14 | @pytest.mark.parametrize("channel_last", [False, True])
15 | # @pytest.mark.parametrize('channel_last', [True])
16 | @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
17 | # @pytest.mark.parametrize('itype', [torch.float16])
18 | @pytest.mark.parametrize("silu_activation", [False, True])
19 | # @pytest.mark.parametrize('silu_activation', [True])
20 | @pytest.mark.parametrize("has_bias", [False, True])
21 | # @pytest.mark.parametrize('has_bias', [True])
22 | @pytest.mark.parametrize("width", [2, 3, 4])
23 | # @pytest.mark.parametrize('width', [2])
24 | @pytest.mark.parametrize(
25 | "seqlen", [8, 16, 32, 64, 128, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096]
26 | )
27 | # @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096])
28 | # @pytest.mark.parametrize('seqlen', [128])
29 | def test_causal_conv1d(seqlen, width, has_bias, silu_activation, itype, channel_last):
30 | device = "cuda"
31 | rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
32 | if itype == torch.bfloat16:
33 | rtol, atol = 1e-2, 5e-2
34 | rtolw, atolw = (1e-3, 1e-3)
35 | # set seed
36 | torch.random.manual_seed(0)
37 | batch_size = 2
38 | # batch_size = 1
39 | dim = 4096 + 32 # Try dim not divisible by 64
40 | # dim = 64
41 | if not channel_last:
42 | x = torch.randn(batch_size, 4096 + dim + 64, seqlen, device=device, dtype=itype)[:, 4096:4096 + dim, :].requires_grad_()
43 | else:
44 | x = rearrange(
45 | torch.randn(batch_size, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s"
46 | ).requires_grad_()
47 | weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
48 | if has_bias:
49 | bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
50 | else:
51 | bias = None
52 | x_ref = x.detach().clone().requires_grad_()
53 | weight_ref = weight.detach().clone().requires_grad_()
54 | bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None
55 | activation = None if not silu_activation else "silu"
56 | out = causal_conv1d_fn(x, weight, bias, activation=activation)
57 | out_ref = causal_conv1d_ref(x_ref, weight_ref, bias_ref, activation=activation)
58 |
59 | print(f"Output max diff: {(out - out_ref).abs().max().item()}")
60 | print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
61 | assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
62 |
63 | g = torch.randn_like(out)
64 | out_ref.backward(g)
65 | out.backward(g)
66 |
67 | print(f"dx max diff: {(x.grad - x_ref.grad).abs().max().item()}")
68 | print(f"dweight max diff: {(weight.grad - weight_ref.grad).abs().max().item()}")
69 | if has_bias:
70 | print(f"dbias max diff: {(bias.grad - bias_ref.grad).abs().max().item()}")
71 |
72 | assert torch.allclose(x.grad, x_ref.grad.to(dtype=itype), rtol=rtol, atol=atol)
73 | assert torch.allclose(weight.grad, weight_ref.grad, rtol=rtolw, atol=atolw)
74 | if has_bias:
75 | assert torch.allclose(bias.grad, bias_ref.grad, rtol=rtolw, atol=atolw)
76 |
77 |
78 | @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
79 | # @pytest.mark.parametrize('itype', [torch.float16])
80 | @pytest.mark.parametrize("silu_activation", [False, True])
81 | # @pytest.mark.parametrize('silu_activation', [False])
82 | @pytest.mark.parametrize("has_bias", [False, True])
83 | # @pytest.mark.parametrize('has_bias', [True])
84 | @pytest.mark.parametrize("width", [2, 3, 4])
85 | # @pytest.mark.parametrize('width', [2])
86 | @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
87 | # @pytest.mark.parametrize("dim", [2048])
88 | def test_causal_conv1d_update(dim, width, has_bias, silu_activation, itype):
89 | device = "cuda"
90 | rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
91 | if itype == torch.bfloat16:
92 | rtol, atol = 1e-2, 5e-2
93 | rtolw, atolw = (1e-3, 1e-3)
94 | # set seed
95 | torch.random.manual_seed(0)
96 | batch_size = 2
97 | # batch_size = 1
98 | # dim = 64
99 | x = torch.randn(batch_size, dim, device=device, dtype=itype)
100 | conv_state = torch.randn(batch_size, dim, width, device=device, dtype=itype)
101 | weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
102 | if has_bias:
103 | bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
104 | else:
105 | bias = None
106 | conv_state_ref = conv_state.detach().clone()
107 | activation = None if not silu_activation else "silu"
108 | out = causal_conv1d_update(x, conv_state, weight, bias, activation=activation)
109 | out_ref = causal_conv1d_update_ref(x, conv_state_ref, weight, bias, activation=activation)
110 |
111 | print(f"Output max diff: {(out - out_ref).abs().max().item()}")
112 | print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
113 | assert torch.equal(conv_state, conv_state_ref)
114 | assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
115 |
116 |
117 | # @pytest.mark.parametrize("channel_last", [False, True])
118 | @pytest.mark.parametrize('channel_last', [True])
119 | # @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
120 | @pytest.mark.parametrize('itype', [torch.bfloat16])
121 | # @pytest.mark.parametrize("silu_activation", [False, True])
122 | @pytest.mark.parametrize('silu_activation', [True])
123 | # @pytest.mark.parametrize("has_bias", [False, True])
124 | @pytest.mark.parametrize('has_bias', [True])
125 | # @pytest.mark.parametrize("width", [2, 3, 4])
126 | @pytest.mark.parametrize('width', [4])
127 | @pytest.mark.parametrize(
128 | # "seqlen", [8, 16, 32, 64, 128, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096]
129 | "seqlen", [2048]
130 | )
131 | # @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096])
132 | # @pytest.mark.parametrize('seqlen', [128])
133 | def test_causal_conv1d_race_condition(seqlen, width, has_bias, silu_activation, itype, channel_last):
134 | device = "cuda"
135 | # set seed
136 | torch.random.manual_seed(0)
137 | batch_size = 2
138 | # batch_size = 1
139 | dim = 4096 + 32 # Try dim not divisible by 64
140 | # dim = 64
141 | if not channel_last:
142 | x = torch.randn(batch_size, 4096 + dim + 64, seqlen, device=device, dtype=itype)[:, 4096:4096 + dim, :].requires_grad_()
143 | else:
144 | x = rearrange(
145 | torch.randn(batch_size, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s"
146 | ).requires_grad_()
147 | weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
148 | if has_bias:
149 | bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
150 | else:
151 | bias = None
152 | activation = None if not silu_activation else "silu"
153 | out0 = causal_conv1d_fn(x, weight, bias, activation=activation)
154 | g = torch.randn_like(out0)
155 | dx0, dw0, db0 = torch.autograd.grad(out0, (x, weight, bias), g)
156 | dw_atol = 1e-4
157 | db_atol = 1e-4
158 |
159 | for i in range(10000):
160 | out = causal_conv1d_fn(x, weight, bias, activation=activation)
161 | dx, dw, db = torch.autograd.grad(out, (x, weight, bias), g)
162 | dw_equal = torch.allclose(dw, dw0, atol=dw_atol)
163 | # if not dw_equal:
164 | # breakpoint()
165 | if has_bias:
166 | db_equal = torch.allclose(db, db0, atol=db_atol)
167 | # if not db_equal:
168 | # breakpoint()
169 | assert torch.equal(out, out0)
170 | assert torch.equal(dx, dx0)
171 | assert dw_equal
172 | if has_bias:
173 | assert db_equal
174 |
--------------------------------------------------------------------------------
/code/augmentations/__init__.py:
--------------------------------------------------------------------------------
1 | import json
2 | from collections import OrderedDict
3 |
4 | from augmentations.ctaugment import *
5 |
6 |
7 | class StorableCTAugment(CTAugment):
8 | def load_state_dict(self, state):
9 | for k in ["decay", "depth", "th", "rates"]:
10 | assert k in state, "{} not in {}".format(k, state.keys())
11 | setattr(self, k, state[k])
12 |
13 | def state_dict(self):
14 | return OrderedDict(
15 | [(k, getattr(self, k)) for k in ["decay", "depth", "th", "rates"]]
16 | )
17 |
18 |
19 | def get_default_cta():
20 | return StorableCTAugment()
21 |
22 |
23 | def cta_apply(pil_img, ops):
24 | if ops is None:
25 | return pil_img
26 | for op, args in ops:
27 | pil_img = OPS[op].f(pil_img, *args)
28 | return pil_img
29 |
30 |
31 | def deserialize(policy_str):
32 | return [OP(f=x[0], bins=x[1]) for x in json.loads(policy_str)]
33 |
34 |
35 | def stats(cta):
36 | return "\n".join(
37 | "%-16s %s"
38 | % (
39 | k,
40 | " / ".join(
41 | " ".join("%.2f" % x for x in cta.rate_to_p(rate))
42 | for rate in cta.rates[k]
43 | ),
44 | )
45 | for k in sorted(OPS.keys())
46 | )
47 |
48 |
49 | def interleave(x, batch, inverse=False):
50 | """
51 | TF code
52 | def interleave(x, batch):
53 | s = x.get_shape().as_list()
54 | return tf.reshape(tf.transpose(tf.reshape(x, [-1, batch] + s[1:]), [1, 0] + list(range(2, 1+len(s)))), [-1] + s[1:])
55 | """
56 | shape = x.shape
57 | axes = [batch, -1] if inverse else [-1, batch]
58 | return x.reshape(*axes, *shape[1:]).transpose(0, 1).reshape(-1, *shape[1:])
59 |
60 |
61 | def deinterleave(x, batch):
62 | return interleave(x, batch, inverse=True)
--------------------------------------------------------------------------------
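`interleave` regroups a stacked batch so that samples from the concatenated sub-batches alternate along the batch dimension (mirroring the TF snippet in its docstring), and `deinterleave` inverts it. A quick round-trip check (editor's sketch; run from the `code` directory so the `augmentations` package is importable):

```python
# Editor's sketch: interleave and deinterleave are mutual inverses.
import torch
from augmentations import interleave, deinterleave

x = torch.arange(12).reshape(12, 1)    # 12 stacked samples, feature dim 1
y = interleave(x, batch=4)             # regroup across the batch dimension
assert torch.equal(deinterleave(y, batch=4), x)
```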
/code/augmentations/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/code/augmentations/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/code/augmentations/__pycache__/ctaugment.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/code/augmentations/__pycache__/ctaugment.cpython-38.pyc
--------------------------------------------------------------------------------
/code/augmentations/ctaugment.py:
--------------------------------------------------------------------------------
1 | # https://raw.githubusercontent.com/google-research/fixmatch/master/libml/ctaugment.py
2 | #
3 | # Copyright 2019 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # https://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """Control Theory based self-augmentation, modified from https://github.com/vfdev-5/FixMatch-pytorch"""
17 | import random
18 | import torch
19 | from collections import namedtuple
20 |
21 | import numpy as np
22 | from scipy.ndimage import zoom
23 | from PIL import Image, ImageOps, ImageEnhance, ImageFilter
24 |
25 |
26 | OPS = {}
27 | OP = namedtuple("OP", ("f", "bins"))
28 | Sample = namedtuple("Sample", ("train", "probe"))
29 |
30 |
31 | def register(*bins):
32 | def wrap(f):
33 | OPS[f.__name__] = OP(f, bins)
34 | return f
35 |
36 | return wrap
37 |
38 |
39 | class CTAugment(object):
40 | def __init__(self, depth=2, th=0.85, decay=0.99):
41 | self.decay = decay
42 | self.depth = depth
43 | self.th = th
44 | self.rates = {}
45 | for k, op in OPS.items():
46 | self.rates[k] = tuple([np.ones(x, "f") for x in op.bins])
47 |
48 | def rate_to_p(self, rate):
49 | p = rate + (1 - self.decay) # Avoid to have all zero.
50 | p = p / p.max()
51 | p[p < self.th] = 0
52 | return p
53 |
54 | def policy(self, probe, weak):
55 | num_strong_ops = 11
56 | kl_weak = list(OPS.keys())[num_strong_ops:]
57 | kl_strong = list(OPS.keys())[:num_strong_ops]
58 |
59 | if weak:
60 | kl = kl_weak
61 | else:
62 | kl = kl_strong
63 |
64 | v = []
65 | if probe:
66 | for _ in range(self.depth):
67 | k = random.choice(kl)
68 | bins = self.rates[k]
69 | rnd = np.random.uniform(0, 1, len(bins))
70 | v.append(OP(k, rnd.tolist()))
71 | return v
72 | for _ in range(self.depth):
73 | vt = []
74 | k = random.choice(kl)
75 | bins = self.rates[k]
76 | rnd = np.random.uniform(0, 1, len(bins))
77 | for r, bin in zip(rnd, bins):
78 | p = self.rate_to_p(bin)
79 | value = np.random.choice(p.shape[0], p=p / p.sum())
80 | vt.append((value + r) / p.shape[0])
81 | v.append(OP(k, vt))
82 | return v
83 |
84 | def update_rates(self, policy, proximity):
85 | for k, bins in policy:
86 | for p, rate in zip(bins, self.rates[k]):
87 | p = int(p * len(rate) * 0.999)
88 | rate[p] = rate[p] * self.decay + proximity * (1 - self.decay)
89 | print(f"\t {k} weights updated")
90 |
91 | def stats(self):
92 | return "\n".join(
93 | "%-16s %s"
94 | % (
95 | k,
96 | " / ".join(
97 | " ".join("%.2f" % x for x in self.rate_to_p(rate))
98 | for rate in self.rates[k]
99 | ),
100 | )
101 | for k in sorted(OPS.keys())
102 | )
103 |
104 |
105 | def _enhance(x, op, level):
106 | return op(x).enhance(0.1 + 1.9 * level)
107 |
108 |
109 | def _imageop(x, op, level):
110 | return Image.blend(x, op(x), level)
111 |
112 |
113 | def _filter(x, op, level):
114 | return Image.blend(x, x.filter(op), level)
115 |
116 |
117 | @register(17)
118 | def autocontrast(x, level):
119 | return _imageop(x, ImageOps.autocontrast, level)
120 |
121 |
122 | @register(17)
123 | def brightness(x, brightness):
124 | return _enhance(x, ImageEnhance.Brightness, brightness)
125 |
126 |
127 | @register(17)
128 | def color(x, color):
129 | return _enhance(x, ImageEnhance.Color, color)
130 |
131 |
132 | @register(17)
133 | def contrast(x, contrast):
134 | return _enhance(x, ImageEnhance.Contrast, contrast)
135 |
136 |
137 | @register(17)
138 | def equalize(x, level):
139 | return _imageop(x, ImageOps.equalize, level)
140 |
141 |
142 | @register(17)
143 | def invert(x, level):
144 | return _imageop(x, ImageOps.invert, level)
145 |
146 |
147 | @register(8)
148 | def posterize(x, level):
149 | level = 1 + int(level * 7.999)
150 | return ImageOps.posterize(x, level)
151 |
152 |
153 | @register(17)
154 | def solarize(x, th):
155 | th = int(th * 255.999)
156 | return ImageOps.solarize(x, th)
157 |
158 |
159 | @register(17)
160 | def smooth(x, level):
161 | return _filter(x, ImageFilter.SMOOTH, level)
162 |
163 |
164 | @register(17)
165 | def blur(x, level):
166 | return _filter(x, ImageFilter.BLUR, level)
167 |
168 |
169 | @register(17)
170 | def sharpness(x, sharpness):
171 | return _enhance(x, ImageEnhance.Sharpness, sharpness)
172 |
173 |
174 | # weak after here
175 |
176 |
177 | @register(17)
178 | def cutout(x, level):
179 | """Apply cutout to pil_img at the specified level."""
180 | size = 1 + int(level * min(x.size) * 0.499)
181 | img_height, img_width = x.size
182 | height_loc = np.random.randint(low=img_height // 2, high=img_height)
183 | width_loc = np.random.randint(low=img_height // 2, high=img_width)
184 | upper_coord = (max(0, height_loc - size // 2), max(0, width_loc - size // 2))
185 | lower_coord = (
186 | min(img_height, height_loc + size // 2),
187 | min(img_width, width_loc + size // 2),
188 | )
189 | pixels = x.load() # create the pixel map
190 | for i in range(upper_coord[0], lower_coord[0]): # for every col:
191 | for j in range(upper_coord[1], lower_coord[1]): # For every row
192 | x.putpixel((i, j), 0) # set the color accordingly
193 | return x
194 |
195 |
196 | @register()
197 | def identity(x):
198 | return x
199 |
200 |
201 | @register(17, 6)
202 | def rescale(x, scale, method):
203 | s = x.size
204 | scale *= 0.25
205 | crop = (scale * s[0], scale * s[1], s[0] * (1 - scale), s[1] * (1 - scale))
206 | methods = (
207 | Image.ANTIALIAS,
208 | Image.BICUBIC,
209 | Image.BILINEAR,
210 | Image.BOX,
211 | Image.HAMMING,
212 | Image.NEAREST,
213 | )
214 | method = methods[int(method * 5.99)]
215 | return x.crop(crop).resize(x.size, method)
216 |
217 |
218 | @register(17)
219 | def rotate(x, angle):
220 | angle = int(np.round((2 * angle - 1) * 45))
221 | return x.rotate(angle)
222 |
223 |
224 | @register(17)
225 | def shear_x(x, shear):
226 | shear = (2 * shear - 1) * 0.3
227 | return x.transform(x.size, Image.AFFINE, (1, shear, 0, 0, 1, 0))
228 |
229 |
230 | @register(17)
231 | def shear_y(x, shear):
232 | shear = (2 * shear - 1) * 0.3
233 | return x.transform(x.size, Image.AFFINE, (1, 0, 0, shear, 1, 0))
234 |
235 |
236 | @register(17)
237 | def translate_x(x, delta):
238 | delta = (2 * delta - 1) * 0.3
239 | return x.transform(x.size, Image.AFFINE, (1, 0, delta, 0, 1, 0))
240 |
241 |
242 | @register(17)
243 | def translate_y(x, delta):
244 | delta = (2 * delta - 1) * 0.3
245 | return x.transform(x.size, Image.AFFINE, (1, 0, 0, 0, 1, delta))
--------------------------------------------------------------------------------
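`CTAugment` keeps a per-bin success rate for every registered op; `rate_to_p` turns those rates into a sampling distribution and masks out bins below the threshold `th`, while `update_rates` nudges the bins used by a policy toward the observed model agreement. One cycle in a minimal sketch (editor's addition; the proximity value 0.9 is a hypothetical agreement score):

```python
# Editor's sketch of one CTAugment sample/update cycle.
from augmentations.ctaugment import CTAugment

cta = CTAugment(depth=2, th=0.85, decay=0.99)
policy = cta.policy(probe=False, weak=False)   # sample 2 "strong" ops with bin values
cta.update_rates(policy, proximity=0.9)        # push the used bins toward 0.9
print(cta.stats())                             # per-op sampling probabilities
```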
/code/config.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Swin Transformer
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # --------------------------------------------------------'
7 |
8 | import os
9 | import yaml
10 | from yacs.config import CfgNode as CN
11 |
12 | _C = CN()
13 |
14 | # Base config files
15 | _C.BASE = ['']
16 |
17 | # -----------------------------------------------------------------------------
18 | # Data settings
19 | # -----------------------------------------------------------------------------
20 | _C.DATA = CN()
21 | # Batch size for a single GPU, could be overwritten by command line argument
22 | _C.DATA.BATCH_SIZE = 128
23 | # Path to dataset, could be overwritten by command line argument
24 | _C.DATA.DATA_PATH = ''
25 | # Dataset name
26 | _C.DATA.DATASET = 'imagenet'
27 | # Input image size
28 | _C.DATA.IMG_SIZE = 224
29 | # Interpolation to resize image (random, bilinear, bicubic)
30 | _C.DATA.INTERPOLATION = 'bicubic'
31 | # Use zipped dataset instead of folder dataset
32 | # could be overwritten by command line argument
33 | _C.DATA.ZIP_MODE = False
34 | # Cache Data in Memory, could be overwritten by command line argument
35 | _C.DATA.CACHE_MODE = 'part'
36 | # Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.
37 | _C.DATA.PIN_MEMORY = True
38 | # Number of data loading threads
39 | _C.DATA.NUM_WORKERS = 8
40 |
41 | # -----------------------------------------------------------------------------
42 | # Model settings
43 | # -----------------------------------------------------------------------------
44 | _C.MODEL = CN()
45 | # Model type
46 | _C.MODEL.TYPE = 'swin'
47 | # Model name
48 | _C.MODEL.NAME = 'swin_tiny_patch4_window7_224'
49 | # Checkpoint to resume, could be overwritten by command line argument
50 | # _C.MODEL.PRETRAIN_CKPT = './pretrained_ckpt/swin_tiny_patch4_window7_224.pth'
51 | _C.MODEL.PRETRAIN_CKPT = './pretrained_ckpt/vmamba_tiny_e292.pth'
52 | _C.MODEL.RESUME = ''
53 | # Number of classes, overwritten in data preparation
54 | _C.MODEL.NUM_CLASSES = 1000
55 | # Dropout rate
56 | _C.MODEL.DROP_RATE = 0.0
57 | # Drop path rate
58 | _C.MODEL.DROP_PATH_RATE = 0.1
59 | # Label Smoothing
60 | _C.MODEL.LABEL_SMOOTHING = 0.1
61 |
62 | # VSSM parameters
63 | _C.MODEL.VSSM = CN()
64 | _C.MODEL.VSSM.PATCH_SIZE = 4
65 | _C.MODEL.VSSM.IN_CHANS = 3
66 | _C.MODEL.VSSM.EMBED_DIM = 96
67 | _C.MODEL.VSSM.DEPTHS = [2, 2, 9, 2]
68 | _C.MODEL.VSSM.MLP_RATIO = 4.
69 | _C.MODEL.VSSM.PATCH_NORM = True
70 |
71 | # Swin Transformer parameters
72 | _C.MODEL.SWIN = CN()
73 | _C.MODEL.SWIN.PATCH_SIZE = 4
74 | _C.MODEL.SWIN.IN_CHANS = 3
75 | _C.MODEL.SWIN.EMBED_DIM = 96
76 | _C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
77 | _C.MODEL.SWIN.DECODER_DEPTHS = [2, 2, 6, 2]
78 | _C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
79 | _C.MODEL.SWIN.WINDOW_SIZE = 7
80 | _C.MODEL.SWIN.MLP_RATIO = 4.
81 | _C.MODEL.SWIN.QKV_BIAS = True
82 | _C.MODEL.SWIN.QK_SCALE = False
83 | _C.MODEL.SWIN.APE = False
84 | _C.MODEL.SWIN.PATCH_NORM = True
85 | _C.MODEL.SWIN.FINAL_UPSAMPLE= "expand_first"
86 |
87 | # -----------------------------------------------------------------------------
88 | # Training settings
89 | # -----------------------------------------------------------------------------
90 | _C.TRAIN = CN()
91 | _C.TRAIN.START_EPOCH = 0
92 | _C.TRAIN.EPOCHS = 300
93 | _C.TRAIN.WARMUP_EPOCHS = 20
94 | _C.TRAIN.WEIGHT_DECAY = 0.05
95 | _C.TRAIN.BASE_LR = 5e-4
96 | _C.TRAIN.WARMUP_LR = 5e-7
97 | _C.TRAIN.MIN_LR = 5e-6
98 | # Clip gradient norm
99 | _C.TRAIN.CLIP_GRAD = 5.0
100 | # Auto resume from latest checkpoint
101 | _C.TRAIN.AUTO_RESUME = True
102 | # Gradient accumulation steps
103 | # could be overwritten by command line argument
104 | _C.TRAIN.ACCUMULATION_STEPS = 0
105 | # Whether to use gradient checkpointing to save memory
106 | # could be overwritten by command line argument
107 | _C.TRAIN.USE_CHECKPOINT = False
108 |
109 | # LR scheduler
110 | _C.TRAIN.LR_SCHEDULER = CN()
111 | _C.TRAIN.LR_SCHEDULER.NAME = 'cosine'
112 | # Epoch interval to decay LR, used in StepLRScheduler
113 | _C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30
114 | # LR decay rate, used in StepLRScheduler
115 | _C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1
116 |
117 | # Optimizer
118 | _C.TRAIN.OPTIMIZER = CN()
119 | _C.TRAIN.OPTIMIZER.NAME = 'adamw'
120 | # Optimizer Epsilon
121 | _C.TRAIN.OPTIMIZER.EPS = 1e-8
122 | # Optimizer Betas
123 | _C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999)
124 | # SGD momentum
125 | _C.TRAIN.OPTIMIZER.MOMENTUM = 0.9
126 |
127 | # -----------------------------------------------------------------------------
128 | # Augmentation settings
129 | # -----------------------------------------------------------------------------
130 | _C.AUG = CN()
131 | # Color jitter factor
132 | _C.AUG.COLOR_JITTER = 0.4
133 | # Use AutoAugment policy. "v0" or "original"
134 | _C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1'
135 | # Random erase prob
136 | _C.AUG.REPROB = 0.25
137 | # Random erase mode
138 | _C.AUG.REMODE = 'pixel'
139 | # Random erase count
140 | _C.AUG.RECOUNT = 1
141 | # Mixup alpha, mixup enabled if > 0
142 | _C.AUG.MIXUP = 0.8
143 | # Cutmix alpha, cutmix enabled if > 0
144 | _C.AUG.CUTMIX = 1.0
145 | # Cutmix min/max ratio, overrides alpha and enables cutmix if set
146 | _C.AUG.CUTMIX_MINMAX = False
147 | # Probability of performing mixup or cutmix when either/both is enabled
148 | _C.AUG.MIXUP_PROB = 1.0
149 | # Probability of switching to cutmix when both mixup and cutmix enabled
150 | _C.AUG.MIXUP_SWITCH_PROB = 0.5
151 | # How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
152 | _C.AUG.MIXUP_MODE = 'batch'
153 |
154 | # -----------------------------------------------------------------------------
155 | # Testing settings
156 | # -----------------------------------------------------------------------------
157 | _C.TEST = CN()
158 | # Whether to use center crop when testing
159 | _C.TEST.CROP = True
160 |
161 | # -----------------------------------------------------------------------------
162 | # Misc
163 | # -----------------------------------------------------------------------------
164 | # Mixed precision opt level, if O0, no amp is used ('O0', 'O1', 'O2')
165 | # overwritten by command line argument
166 | _C.AMP_OPT_LEVEL = ''
167 | # Path to output folder, overwritten by command line argument
168 | _C.OUTPUT = ''
169 | # Tag of experiment, overwritten by command line argument
170 | _C.TAG = 'default'
171 | # Frequency to save checkpoint
172 | _C.SAVE_FREQ = 1
173 | # Frequency to logging info
174 | _C.PRINT_FREQ = 10
175 | # Fixed random seed
176 | _C.SEED = 0
177 | # Perform evaluation only, overwritten by command line argument
178 | _C.EVAL_MODE = False
179 | # Test throughput only, overwritten by command line argument
180 | _C.THROUGHPUT_MODE = False
181 | # local rank for DistributedDataParallel, given by command line argument
182 | _C.LOCAL_RANK = 0
183 |
184 |
185 | def _update_config_from_file(config, cfg_file):
186 | config.defrost()
187 | with open(cfg_file, 'r') as f:
188 | yaml_cfg = yaml.load(f, Loader=yaml.FullLoader)
189 |
190 | for cfg in yaml_cfg.setdefault('BASE', ['']):
191 | if cfg:
192 | _update_config_from_file(
193 | config, os.path.join(os.path.dirname(cfg_file), cfg)
194 | )
195 | print('=> merge config from {}'.format(cfg_file))
196 | config.merge_from_file(cfg_file)
197 | config.freeze()
198 |
199 |
200 | def update_config(config, args):
201 | _update_config_from_file(config, args.cfg)
202 |
203 | config.defrost()
204 | if args.opts:
205 | config.merge_from_list(args.opts)
206 |
207 | # merge from specific arguments
208 | if args.batch_size:
209 | config.DATA.BATCH_SIZE = args.batch_size
210 | if args.zip:
211 | config.DATA.ZIP_MODE = True
212 | if args.cache_mode:
213 | config.DATA.CACHE_MODE = args.cache_mode
214 | if args.resume:
215 | config.MODEL.RESUME = args.resume
216 | if args.accumulation_steps:
217 | config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps
218 | if args.use_checkpoint:
219 | config.TRAIN.USE_CHECKPOINT = True
220 | if args.amp_opt_level:
221 | config.AMP_OPT_LEVEL = args.amp_opt_level
222 | if args.tag:
223 | config.TAG = args.tag
224 | if args.eval:
225 | config.EVAL_MODE = True
226 | if args.throughput:
227 | config.THROUGHPUT_MODE = True
228 |
229 | config.freeze()
230 |
231 |
232 | def get_config(args):
233 | """Get a yacs CfgNode object with default values."""
234 | # Return a clone so that the defaults will not be altered
235 | # This is for the "local variable" use pattern
236 | config = _C.clone()
237 | update_config(config, args)
238 |
239 | return config
240 |
--------------------------------------------------------------------------------
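`get_config` clones the defaults, merges the YAML pointed to by `args.cfg` (recursively following any `BASE` files), then applies command-line overrides. Because `update_config` reads a fixed set of attributes off the args namespace, each must exist. A minimal driving sketch (editor's addition; run from the `code` directory, flag defaults are assumptions):

```python
# Editor's sketch: every attribute read by update_config must exist on args.
import argparse
from config import get_config

parser = argparse.ArgumentParser()
parser.add_argument("--cfg", default="configs/vmamba_tiny.yaml")
parser.add_argument("--opts", nargs="+", default=None)
parser.add_argument("--batch_size", type=int, default=None)
parser.add_argument("--zip", action="store_true")
parser.add_argument("--cache_mode", default=None)
parser.add_argument("--resume", default=None)
parser.add_argument("--accumulation_steps", type=int, default=None)
parser.add_argument("--use_checkpoint", action="store_true")
parser.add_argument("--amp_opt_level", default=None)
parser.add_argument("--tag", default=None)
parser.add_argument("--eval", action="store_true")
parser.add_argument("--throughput", action="store_true")
args = parser.parse_args([])          # defaults only, for illustration

config = get_config(args)             # clone defaults, merge YAML, apply args
print(config.MODEL.TYPE, config.MODEL.VSSM.DEPTHS)
```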
/code/configs/swin_tiny_patch4_window7_224_lite.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: swin
3 | NAME: swin_tiny_patch4_window7_224
4 | DROP_PATH_RATE: 0.2
5 | PRETRAIN_CKPT: "../code/pretrained_ckpt/swin_tiny_patch4_window7_224.pth"
6 | SWIN:
7 | FINAL_UPSAMPLE: "expand_first"
8 | EMBED_DIM: 96
9 | DEPTHS: [ 2, 2, 2, 2 ]
10 | DECODER_DEPTHS: [ 2, 2, 2, 1]
11 | NUM_HEADS: [ 3, 6, 12, 24 ]
12 | WINDOW_SIZE: 7
--------------------------------------------------------------------------------
/code/configs/vmamba_tiny.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: vssm
3 | NAME: vssm_tiny
4 | DROP_PATH_RATE: 0.2
5 | PRETRAIN_CKPT: "../code/pretrained_ckpt/vmamba_tiny_e292.pth"
6 | VSSM:
7 | EMBED_DIM: 96
8 | DEPTHS: [ 2, 2, 2, 2 ]
--------------------------------------------------------------------------------
/code/dataloaders/__pycache__/dataset.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/code/dataloaders/__pycache__/dataset.cpython-310.pyc
--------------------------------------------------------------------------------
/code/dataloaders/__pycache__/dataset.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/code/dataloaders/__pycache__/dataset.cpython-311.pyc
--------------------------------------------------------------------------------
/code/dataloaders/__pycache__/dataset.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/code/dataloaders/__pycache__/dataset.cpython-38.pyc
--------------------------------------------------------------------------------
/code/dataloaders/__pycache__/dataset_s2l.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/code/dataloaders/__pycache__/dataset_s2l.cpython-38.pyc
--------------------------------------------------------------------------------
/code/dataloaders/__pycache__/dataset_semi.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/code/dataloaders/__pycache__/dataset_semi.cpython-38.pyc
--------------------------------------------------------------------------------
/code/dataloaders/__pycache__/utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/code/dataloaders/__pycache__/utils.cpython-310.pyc
--------------------------------------------------------------------------------
/code/dataloaders/__pycache__/utils.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/code/dataloaders/__pycache__/utils.cpython-311.pyc
--------------------------------------------------------------------------------
/code/dataloaders/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/code/dataloaders/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/code/dataloaders/acdc_data_processing.py:
--------------------------------------------------------------------------------
1 | # save images in slice level
2 | import glob
3 | import os
4 |
5 | import h5py
6 | import numpy as np
7 | import SimpleITK as sitk
8 |
9 |
10 | class MedicalImageDeal(object):
11 | def __init__(self, img, percent=1):
12 | self.img = img
13 | self.percent = percent
14 |
15 | @property
16 | def valid_img(self):
17 | from skimage import exposure
18 | cdf = exposure.cumulative_distribution(self.img)
19 | watershed = cdf[1][cdf[0] >= self.percent][0]
20 | return np.clip(self.img, self.img.min(), watershed)
21 |
22 | @property
23 | def norm_img(self):
24 | return (self.img - self.img.min()) / (self.img.max() - self.img.min())
25 |
26 | # saving images in slice level
27 |
28 |
29 | slice_num = 0
30 | mask_path = sorted(
31 | glob.glob("../data/ACDC_training/*_gt.nii.gz"))
32 | for case in mask_path:
33 | label_itk = sitk.ReadImage(case)
34 | label = sitk.GetArrayFromImage(label_itk)
35 |
36 | image_path = case.replace("_gt", "")
37 | image_itk = sitk.ReadImage(image_path)
38 | image = sitk.GetArrayFromImage(image_itk)
39 |
40 | scribble_path = case.replace("_gt", "_scribble")
41 | scribble_itk = sitk.ReadImage(scribble_path)
42 | scribble = sitk.GetArrayFromImage(scribble_itk)
43 |
44 | image = MedicalImageDeal(image, percent=0.99).valid_img
45 | image = (image - image.min()) / (image.max() - image.min())
46 | print(image.shape)
47 | image = image.astype(np.float32)
48 | item = case.split("/")[-1].split(".")[0].replace("_gt", "")
49 | if image.shape != label.shape:
50 |         print("Error: image and label shapes differ")
51 | print(item)
52 | for slice_ind in range(image.shape[0]):
53 | f = h5py.File(
54 | '../data/ACDC_training_slices/{}_slice_{}.h5'.format(item, slice_ind), 'w')
55 | f.create_dataset(
56 | 'image', data=image[slice_ind], compression="gzip")
57 | f.create_dataset('label', data=label[slice_ind], compression="gzip")
58 | f.create_dataset(
59 | 'scribble', data=scribble[slice_ind], compression="gzip")
60 | f.close()
61 | slice_num += 1
62 | print("Converted all ACDC volumes to 2D slices")
63 | print("Total {} slices".format(slice_num))
64 |
65 | # saving images at the volume level
66 |
67 |
68 | class MedicalImageDeal(object):
69 | def __init__(self, img, percent=1):
70 | self.img = img
71 | self.percent = percent
72 |
73 | @property
74 | def valid_img(self):
75 | from skimage import exposure
76 | cdf = exposure.cumulative_distribution(self.img)
77 | watershed = cdf[1][cdf[0] >= self.percent][0]
78 | return np.clip(self.img, self.img.min(), watershed)
79 |
80 | @property
81 | def norm_img(self):
82 | return (self.img - self.img.min()) / (self.img.max() - self.img.min())
83 |
84 |
85 | slice_num = 0
86 | mask_path = sorted(
87 | glob.glob("../data/ACDC_training/*_gt.nii.gz"))
88 | print(mask_path)
89 |
90 | for case in mask_path:
91 | label_itk = sitk.ReadImage(case)
92 | label = sitk.GetArrayFromImage(label_itk)
93 |
94 | image_path = case.replace("_gt", "")
95 | image_itk = sitk.ReadImage(image_path)
96 | image = sitk.GetArrayFromImage(image_itk)
97 |
98 | scribble_path = case.replace("_gt", "_scribble")
99 | scribble_itk = sitk.ReadImage(scribble_path)
100 | scribble = sitk.GetArrayFromImage(scribble_itk)
101 |
102 | image = MedicalImageDeal(image, percent=0.99).valid_img
103 | image = (image - image.min()) / (image.max() - image.min())
104 | print(image.shape)
105 | image = image.astype(np.float32)
106 | item = case.split("/")[-1].split(".")[0].replace("_gt", "")
107 | if image.shape != label.shape:
108 |         print("Error: image and label shapes differ")
109 | print(item)
110 | f = h5py.File(
111 | '../data/ACDC_training_volumes/{}.h5'.format(item), 'w')
112 | f.create_dataset(
113 | 'image', data=image, compression="gzip")
114 | f.create_dataset('label', data=label, compression="gzip")
115 | f.create_dataset('scribble', data=scribble, compression="gzip")
116 | f.close()
117 | slice_num += 1
118 | print("Converted all ACDC training volumes to h5 volumes")
119 | print("Total {} volumes".format(slice_num))
120 |
--------------------------------------------------------------------------------
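A note on the script above: MedicalImageDeal is defined twice verbatim, once for the slice-level pass and once for the volume-level pass, and its valid_img property is in effect a percentile clip. Below is a minimal sketch of that clipping step on synthetic data (NumPy only; the array is made up, not an ACDC volume):

import numpy as np

rng = np.random.default_rng(0)
image = rng.normal(loc=100.0, scale=20.0, size=(8, 64, 64)).astype(np.float32)

# cumulative_distribution(...)[1][cdf >= 0.99][0] is approximately the 99th
# intensity percentile, so clipping there suppresses bright outliers before
# the min-max normalisation that follows in the script.
cutoff = np.percentile(image, 99)
clipped = np.clip(image, image.min(), cutoff)
normalised = (clipped - clipped.min()) / (clipped.max() - clipped.min())
print(normalised.min(), normalised.max())  # 0.0 1.0

--------------------------------------------------------------------------------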
/code/dataloaders/acdc_pseudo_label_random_walker.py:
--------------------------------------------------------------------------------
1 | import SimpleITK as sitk
2 | import glob
3 | import os
4 | import numpy as np
5 | from skimage.exposure import rescale_intensity
6 | from skimage.segmentation import random_walker
7 |
8 |
9 | def pseudo_label_generator_acdc(data, seed):
10 | from skimage.exposure import rescale_intensity
11 | from skimage.segmentation import random_walker
12 | if 1 not in np.unique(seed) or 2 not in np.unique(seed) or 3 not in np.unique(seed):
13 | pseudo_label = np.zeros_like(seed)
14 | else:
15 | markers = np.ones_like(seed)
16 | markers[seed == 4] = 0
17 | markers[seed == 0] = 1
18 | markers[seed == 1] = 2
19 | markers[seed == 2] = 3
20 | markers[seed == 3] = 4
21 | sigma = 0.35
22 | data = rescale_intensity(data, in_range=(-sigma, 1 + sigma),
23 | out_range=(-1, 1))
24 | segmentation = random_walker(data, markers, beta=100, mode='bf')
25 | pseudo_label = segmentation - 1
26 | return pseudo_label
27 |
28 |
29 | def pseudo_label_generator(data, seed):
30 |     # in the seed array, 0 is background, 1-3 are classes 1-3, and 4 marks the unknown region
31 | markers = np.ones_like(seed)
32 | markers[seed == 4] = 0
33 | markers[seed == 0] = 1
34 | markers[seed == 1] = 2
35 | markers[seed == 2] = 3
36 | markers[seed == 3] = 4
37 | sigma = 0.35
38 | data = rescale_intensity(data, in_range=(-sigma, 1 + sigma),
39 | out_range=(-1, 1))
40 | pseudo_label = random_walker(data, markers, beta=100, mode='bf')
41 | return pseudo_label-1
42 |
43 |
44 | for i in sorted(glob.glob("../data/ACDC_training/*_scribble.nii.gz"))[2:]:
45 | print(i.replace("_scribble.nii.gz", ".nii.gz"))
46 | img_itk = sitk.ReadImage(i.replace("_scribble.nii.gz", ".nii.gz"))
47 | image = sitk.GetArrayFromImage(img_itk)
48 | scribble = sitk.GetArrayFromImage(sitk.ReadImage(i))
49 | pseudo_volumes = np.zeros_like(image)
50 | for ind, slice_ind in enumerate(range(image.shape[0])):
51 | if 1 not in np.unique(scribble[ind, ...]) or 2 not in np.unique(scribble[ind, ...]) or 3 not in np.unique(scribble[ind, ...]):
52 | pass
53 | else:
54 | pseudo_volumes[ind, ...] = pseudo_label_generator(
55 | image[ind, ...], scribble[ind, ...])
56 | pseudo_volumes_itk = sitk.GetImageFromArray(pseudo_volumes)
57 | pseudo_volumes_itk.CopyInformation(img_itk)
58 | sitk.WriteImage(pseudo_volumes_itk, i.replace(
59 | "_scribble.nii.gz", "_random_walker.nii.gz"))
--------------------------------------------------------------------------------
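The marker remapping above shifts scribble classes 0-3 up by one so that the value 0 can mark the unknown region (scribble value 4) for random_walker, and the trailing "- 1" undoes the shift afterwards. A toy sketch of the same mapping (the 4x4 seed and stand-in image below are made up):

import numpy as np
from skimage.segmentation import random_walker

seed = np.array([[0, 4, 4, 1],
                 [0, 4, 4, 1],
                 [2, 4, 4, 3],
                 [2, 4, 4, 3]])
data = np.linspace(-1.0, 1.0, 16).reshape(4, 4)  # stand-in image slice

markers = np.where(seed == 4, 0, seed + 1)  # same effect as the masks above
pseudo = random_walker(data, markers, beta=100, mode='bf') - 1
print(np.unique(pseudo))  # pseudo labels are back in the 0-3 class range

--------------------------------------------------------------------------------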
/code/dataloaders/dataset.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | import os
3 | import random
4 | import re
5 | from glob import glob
6 |
7 | import cv2
8 | import h5py
9 | import numpy as np
10 | import torch
11 | from scipy import ndimage
12 | from scipy.ndimage.interpolation import zoom
13 | from torch.utils.data import Dataset
14 | from torch.utils.data.sampler import Sampler
15 |
16 |
17 | def pseudo_label_generator_acdc(data, seed, beta=100, mode='bf'):
18 | from skimage.exposure import rescale_intensity
19 | from skimage.segmentation import random_walker
20 | if 1 not in np.unique(seed) or 2 not in np.unique(seed) or 3 not in np.unique(seed):
21 | pseudo_label = np.zeros_like(seed)
22 | else:
23 | markers = np.ones_like(seed)
24 | markers[seed == 4] = 0
25 | markers[seed == 0] = 1
26 | markers[seed == 1] = 2
27 | markers[seed == 2] = 3
28 | markers[seed == 3] = 4
29 | sigma = 0.35
30 | data = rescale_intensity(data, in_range=(-sigma, 1 + sigma),
31 | out_range=(-1, 1))
32 | segmentation = random_walker(data, markers, beta, mode)
33 | pseudo_label = segmentation - 1
34 | return pseudo_label
35 |
36 |
37 | class BaseDataSets(Dataset):
38 | def __init__(self, base_dir=None, split='train', transform=None, fold="fold1", sup_type="label"):
39 | self._base_dir = base_dir
40 | self.sample_list = []
41 | self.split = split
42 | self.sup_type = sup_type
43 | self.transform = transform
44 | train_ids, test_ids = self._get_fold_ids(fold)
45 | if self.split == 'train':
46 | self.all_slices = os.listdir(
47 | self._base_dir + "/ACDC_training_slices")
48 | self.sample_list = []
49 | for ids in train_ids:
50 | new_data_list = list(filter(lambda x: re.match(
51 |                     '{}.*'.format(ids), x) is not None, self.all_slices))
52 | self.sample_list.extend(new_data_list)
53 |
54 | elif self.split == 'val':
55 | self.all_volumes = os.listdir(
56 | self._base_dir + "/ACDC_training_volumes")
57 | self.sample_list = []
58 | for ids in test_ids:
59 | new_data_list = list(filter(lambda x: re.match(
60 |                     '{}.*'.format(ids), x) is not None, self.all_volumes))
61 | self.sample_list.extend(new_data_list)
62 |
63 | # if num is not None and self.split == "train":
64 | # self.sample_list = self.sample_list[:num]
65 | print("total {} samples".format(len(self.sample_list)))
66 |
67 | def _get_fold_ids(self, fold):
68 | all_cases_set = ["patient{:0>3}".format(i) for i in range(1, 101)]
69 | fold1_testing_set = [
70 | "patient{:0>3}".format(i) for i in range(1, 21)]
71 | fold1_training_set = [
72 | i for i in all_cases_set if i not in fold1_testing_set]
73 |
74 | fold2_testing_set = [
75 | "patient{:0>3}".format(i) for i in range(21, 41)]
76 | fold2_training_set = [
77 | i for i in all_cases_set if i not in fold2_testing_set]
78 |
79 | fold3_testing_set = [
80 | "patient{:0>3}".format(i) for i in range(41, 61)]
81 | fold3_training_set = [
82 | i for i in all_cases_set if i not in fold3_testing_set]
83 |
84 | fold4_testing_set = [
85 | "patient{:0>3}".format(i) for i in range(61, 81)]
86 | fold4_training_set = [
87 | i for i in all_cases_set if i not in fold4_testing_set]
88 |
89 | fold5_testing_set = [
90 | "patient{:0>3}".format(i) for i in range(81, 101)]
91 | fold5_training_set = [
92 | i for i in all_cases_set if i not in fold5_testing_set]
93 | if fold == "fold1":
94 | return [fold1_training_set, fold1_testing_set]
95 | elif fold == "fold2":
96 | return [fold2_training_set, fold2_testing_set]
97 | elif fold == "fold3":
98 | return [fold3_training_set, fold3_testing_set]
99 | elif fold == "fold4":
100 | return [fold4_training_set, fold4_testing_set]
101 | elif fold == "fold5":
102 | return [fold5_training_set, fold5_testing_set]
103 | else:
104 |             raise ValueError("unknown fold: {}".format(fold))
105 |
106 | def __len__(self):
107 | return len(self.sample_list)
108 |
109 | def __getitem__(self, idx):
110 | case = self.sample_list[idx]
111 | if self.split == "train":
112 | h5f = h5py.File(self._base_dir +
113 | "/ACDC_training_slices/{}".format(case), 'r')
114 | else:
115 | h5f = h5py.File(self._base_dir +
116 | "/ACDC_training_volumes/{}".format(case), 'r')
117 | image = h5f['image'][:]
118 | label = h5f['label'][:]
119 | sample = {'image': image, 'label': label}
120 | if self.split == "train":
121 | image = h5f['image'][:]
122 | if self.sup_type == "random_walker":
123 | label = pseudo_label_generator_acdc(image, h5f["scribble"][:])
124 | else:
125 | label = h5f[self.sup_type][:]
126 | sample = {'image': image, 'label': label}
127 | sample = self.transform(sample)
128 | else:
129 | image = h5f['image'][:]
130 | label = h5f['label'][:]
131 | sample = {'image': image, 'label': label}
132 | sample["idx"] = idx
133 | return sample
134 |
135 |
136 | def random_rot_flip(image, label):
137 | k = np.random.randint(0, 4)
138 | image = np.rot90(image, k)
139 | label = np.rot90(label, k)
140 | axis = np.random.randint(0, 2)
141 | image = np.flip(image, axis=axis).copy()
142 | label = np.flip(label, axis=axis).copy()
143 | return image, label
144 |
145 |
146 | def random_rotate(image, label, cval):
147 | angle = np.random.randint(-20, 20)
148 | image = ndimage.rotate(image, angle, order=0, reshape=False)
149 | label = ndimage.rotate(label, angle, order=0,
150 | reshape=False, mode="constant", cval=cval)
151 | return image, label
152 |
153 |
154 | class RandomGenerator(object):
155 | def __init__(self, output_size):
156 | self.output_size = output_size
157 |
158 | def __call__(self, sample):
159 | image, label = sample['image'], sample['label']
160 | # ind = random.randrange(0, img.shape[0])
161 | # image = img[ind, ...]
162 | # label = lab[ind, ...]
163 | if random.random() > 0.5:
164 | image, label = random_rot_flip(image, label)
165 | elif random.random() > 0.5:
166 | if 4 in np.unique(label):
167 | image, label = random_rotate(image, label, cval=4)
168 | else:
169 | image, label = random_rotate(image, label, cval=0)
170 | x, y = image.shape
171 | image = zoom(
172 | image, (self.output_size[0] / x, self.output_size[1] / y), order=0)
173 | label = zoom(
174 | label, (self.output_size[0] / x, self.output_size[1] / y), order=0)
175 | image = torch.from_numpy(
176 | image.astype(np.float32)).unsqueeze(0)
177 | label = torch.from_numpy(label.astype(np.uint8))
178 | sample = {'image': image, 'label': label}
179 | return sample
180 |
181 |
182 | class TwoStreamBatchSampler(Sampler):
183 | """Iterate two sets of indices
184 |
185 | An 'epoch' is one iteration through the primary indices.
186 | During the epoch, the secondary indices are iterated through
187 | as many times as needed.
188 | """
189 |
190 | def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size):
191 | self.primary_indices = primary_indices
192 | self.secondary_indices = secondary_indices
193 | self.secondary_batch_size = secondary_batch_size
194 | self.primary_batch_size = batch_size - secondary_batch_size
195 |
196 | assert len(self.primary_indices) >= self.primary_batch_size > 0
197 | assert len(self.secondary_indices) >= self.secondary_batch_size > 0
198 |
199 | def __iter__(self):
200 | primary_iter = iterate_once(self.primary_indices)
201 | secondary_iter = iterate_eternally(self.secondary_indices)
202 | return (
203 | primary_batch + secondary_batch
204 | for (primary_batch, secondary_batch)
205 | in zip(grouper(primary_iter, self.primary_batch_size),
206 | grouper(secondary_iter, self.secondary_batch_size))
207 | )
208 |
209 | def __len__(self):
210 | return len(self.primary_indices) // self.primary_batch_size
211 |
212 |
213 | def iterate_once(iterable):
214 | return np.random.permutation(iterable)
215 |
216 |
217 | def iterate_eternally(indices):
218 | def infinite_shuffles():
219 | while True:
220 | yield np.random.permutation(indices)
221 | return itertools.chain.from_iterable(infinite_shuffles())
222 |
223 |
224 | def grouper(iterable, n):
225 | "Collect data into fixed-length chunks or blocks"
226 |     # grouper('ABCDEFG', 3) --> ABC DEF
227 | args = [iter(iterable)] * n
228 | return zip(*args)
229 |
--------------------------------------------------------------------------------
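A minimal sketch of how TwoStreamBatchSampler pairs labeled and unlabeled indices (toy index lists; the import path assumes code/ is on sys.path): each batch draws batch_size - secondary_batch_size indices from the primary stream and the rest from the secondary one, and the primary stream alone defines the epoch length.

from dataloaders.dataset import TwoStreamBatchSampler

primary = list(range(6))        # e.g. indices of labeled slices
secondary = list(range(6, 20))  # e.g. indices of unlabeled slices

sampler = TwoStreamBatchSampler(primary, secondary,
                                batch_size=4, secondary_batch_size=2)
for batch in sampler:
    print(batch)                # 2 primary indices followed by 2 secondary
print(len(sampler))             # 6 // 2 == 3 batches per epoch

--------------------------------------------------------------------------------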
/code/dataloaders/dataset_semi.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | import os
3 | import random
4 | import re
5 | from glob import glob
6 |
7 | import cv2
8 | import h5py
9 | import numpy as np
10 | import torch
11 | from scipy import ndimage
12 | from scipy.ndimage.interpolation import zoom
13 | from torch.utils.data import Dataset
14 | from torch.utils.data.sampler import Sampler
15 |
16 |
17 | class BaseDataSets(Dataset):
18 | def __init__(self, base_dir=None, num=4, labeled_type="labeled", split='train', transform=None, fold="fold1", sup_type="label"):
19 | self._base_dir = base_dir
20 | self.sample_list = []
21 | self.split = split
22 | self.sup_type = sup_type
23 | self.transform = transform
24 | self.num = num
25 | self.labeled_type = labeled_type
26 | train_ids, test_ids = self._get_fold_ids(fold)
27 | all_labeled_ids = ["patient{:0>3}".format(
28 | 10 * i) for i in range(1, 11)]
29 | if self.split == 'train':
30 | self.all_slices = os.listdir(
31 | self._base_dir + "/ACDC_training_slices")
32 | self.sample_list = []
33 | labeled_ids = [i for i in all_labeled_ids if i in train_ids]
34 | unlabeled_ids = [i for i in train_ids if i not in labeled_ids]
35 | if self.labeled_type == "labeled":
36 | print("Labeled patients IDs", labeled_ids)
37 | for ids in labeled_ids:
38 | new_data_list = list(filter(lambda x: re.match(
39 |                         '{}.*'.format(ids), x) is not None, self.all_slices))
40 | self.sample_list.extend(new_data_list)
41 | print("total labeled {} samples".format(len(self.sample_list)))
42 | else:
43 | print("Unlabeled patients IDs", unlabeled_ids)
44 | for ids in unlabeled_ids:
45 | new_data_list = list(filter(lambda x: re.match(
46 |                         '{}.*'.format(ids), x) is not None, self.all_slices))
47 | self.sample_list.extend(new_data_list)
48 | print("total unlabeled {} samples".format(len(self.sample_list)))
49 |
50 | elif self.split == 'val':
51 | self.all_volumes = os.listdir(
52 | self._base_dir + "/ACDC_training_volumes")
53 | self.sample_list = []
54 | for ids in test_ids:
55 | new_data_list = list(filter(lambda x: re.match(
56 |                     '{}.*'.format(ids), x) is not None, self.all_volumes))
57 | self.sample_list.extend(new_data_list)
58 |
59 | # if num is not None and self.split == "train":
60 | # self.sample_list = self.sample_list[:num]
61 |
62 | def _get_fold_ids(self, fold):
63 | all_cases_set = ["patient{:0>3}".format(i) for i in range(1, 101)]
64 | fold1_testing_set = [
65 | "patient{:0>3}".format(i) for i in range(1, 21)]
66 | fold1_training_set = [
67 | i for i in all_cases_set if i not in fold1_testing_set]
68 |
69 | fold2_testing_set = [
70 | "patient{:0>3}".format(i) for i in range(21, 41)]
71 | fold2_training_set = [
72 | i for i in all_cases_set if i not in fold2_testing_set]
73 |
74 | fold3_testing_set = [
75 | "patient{:0>3}".format(i) for i in range(41, 61)]
76 | fold3_training_set = [
77 | i for i in all_cases_set if i not in fold3_testing_set]
78 |
79 | fold4_testing_set = [
80 | "patient{:0>3}".format(i) for i in range(61, 81)]
81 | fold4_training_set = [
82 | i for i in all_cases_set if i not in fold4_testing_set]
83 |
84 | fold5_testing_set = [
85 | "patient{:0>3}".format(i) for i in range(81, 101)]
86 | fold5_training_set = [
87 | i for i in all_cases_set if i not in fold5_testing_set]
88 | if fold == "fold1":
89 | return [fold1_training_set, fold1_testing_set]
90 | elif fold == "fold2":
91 | return [fold2_training_set, fold2_testing_set]
92 | elif fold == "fold3":
93 | return [fold3_training_set, fold3_testing_set]
94 | elif fold == "fold4":
95 | return [fold4_training_set, fold4_testing_set]
96 | elif fold == "fold5":
97 | return [fold5_training_set, fold5_testing_set]
98 | else:
99 |             raise ValueError("unknown fold: {}".format(fold))
100 |
101 | def __len__(self):
102 | return len(self.sample_list)
103 |
104 | def __getitem__(self, idx):
105 | case = self.sample_list[idx]
106 | if self.split == "train":
107 | h5f = h5py.File(self._base_dir +
108 | "/ACDC_training_slices/{}".format(case), 'r')
109 | else:
110 | h5f = h5py.File(self._base_dir +
111 | "/ACDC_training_volumes/{}".format(case), 'r')
112 | image = h5f['image'][:]
113 | label = h5f['label'][:]
114 | sample = {'image': image, 'label': label}
115 | if self.split == "train":
116 | image = h5f['image'][:]
117 | label = h5f[self.sup_type][:]
118 | sample = {'image': image, 'label': label}
119 | sample = self.transform(sample)
120 | else:
121 | image = h5f['image'][:]
122 | label = h5f['label'][:]
123 | sample = {'image': image, 'label': label}
124 | sample["idx"] = case.split("_")[0]
125 | return sample
126 |
127 |
128 | def random_rot_flip(image, label):
129 | k = np.random.randint(0, 4)
130 | image = np.rot90(image, k)
131 | label = np.rot90(label, k)
132 | axis = np.random.randint(0, 2)
133 | image = np.flip(image, axis=axis).copy()
134 | label = np.flip(label, axis=axis).copy()
135 | return image, label
136 |
137 |
138 | def random_rotate(image, label, cval):
139 | angle = np.random.randint(-20, 20)
140 | image = ndimage.rotate(image, angle, order=0, reshape=False)
141 | label = ndimage.rotate(label, angle, order=0,
142 | reshape=False, mode="constant", cval=cval)
143 | return image, label
144 |
145 |
146 | class RandomGenerator(object):
147 | def __init__(self, output_size):
148 | self.output_size = output_size
149 |
150 | def __call__(self, sample):
151 | image, label = sample['image'], sample['label']
152 | # ind = random.randrange(0, img.shape[0])
153 | # image = img[ind, ...]
154 | # label = lab[ind, ...]
155 | if random.random() > 0.5:
156 | image, label = random_rot_flip(image, label)
157 | elif random.random() > 0.5:
158 | if 4 in np.unique(label):
159 | image, label = random_rotate(image, label, cval=4)
160 | else:
161 | image, label = random_rotate(image, label, cval=0)
162 | x, y = image.shape
163 | image = zoom(
164 | image, (self.output_size[0] / x, self.output_size[1] / y), order=0)
165 | label = zoom(
166 | label, (self.output_size[0] / x, self.output_size[1] / y), order=0)
167 | image = torch.from_numpy(
168 | image.astype(np.float32)).unsqueeze(0)
169 | label = torch.from_numpy(label.astype(np.uint8))
170 | sample = {'image': image, 'label': label}
171 | return sample
172 |
173 |
174 | class TwoStreamBatchSampler(Sampler):
175 | """Iterate two sets of indices
176 |
177 | An 'epoch' is one iteration through the primary indices.
178 | During the epoch, the secondary indices are iterated through
179 | as many times as needed.
180 | """
181 |
182 | def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size):
183 | self.primary_indices = primary_indices
184 | self.secondary_indices = secondary_indices
185 | self.secondary_batch_size = secondary_batch_size
186 | self.primary_batch_size = batch_size - secondary_batch_size
187 |
188 | assert len(self.primary_indices) >= self.primary_batch_size > 0
189 | assert len(self.secondary_indices) >= self.secondary_batch_size > 0
190 |
191 | def __iter__(self):
192 | primary_iter = iterate_once(self.primary_indices)
193 | secondary_iter = iterate_eternally(self.secondary_indices)
194 | return (
195 | primary_batch + secondary_batch
196 | for (primary_batch, secondary_batch)
197 | in zip(grouper(primary_iter, self.primary_batch_size),
198 | grouper(secondary_iter, self.secondary_batch_size))
199 | )
200 |
201 | def __len__(self):
202 | return len(self.primary_indices) // self.primary_batch_size
203 |
204 |
205 | def iterate_once(iterable):
206 | return np.random.permutation(iterable)
207 |
208 |
209 | def iterate_eternally(indices):
210 | def infinite_shuffles():
211 | while True:
212 | yield np.random.permutation(indices)
213 | return itertools.chain.from_iterable(infinite_shuffles())
214 |
215 |
216 | def grouper(iterable, n):
217 | "Collect data into fixed-length chunks or blocks"
218 |     # grouper('ABCDEFG', 3) --> ABC DEF
219 | args = [iter(iterable)] * n
220 | return zip(*args)
221 |
--------------------------------------------------------------------------------
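The labeled/unlabeled split above is deterministic: every tenth patient, patient010 through patient100, counts as labeled, minus any that fall into the current fold's test split. A quick plain-Python check of that selection:

all_labeled_ids = ["patient{:0>3}".format(10 * i) for i in range(1, 11)]
print(all_labeled_ids)  # patient010, patient020, ..., patient100

fold1_testing = ["patient{:0>3}".format(i) for i in range(1, 21)]
labeled_fold1 = [i for i in all_labeled_ids if i not in fold1_testing]
print(labeled_fold1)    # patient030 ... patient100: 8 labeled cases for fold1

--------------------------------------------------------------------------------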
/code/dataloaders/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import scipy.ndimage as nd
5 | import torch
6 | import torch.nn as nn
7 | # import matplotlib.pyplot as plt
8 | from skimage import measure
9 |
10 |
11 | def recursive_glob(rootdir='.', suffix=''):
12 |     """Performs a recursive glob for files with the given suffix under rootdir
13 |     :param rootdir: the root directory to walk
14 |     :param suffix: the filename suffix to match
15 |     """
16 | return [os.path.join(looproot, filename)
17 | for looproot, _, filenames in os.walk(rootdir)
18 | for filename in filenames if filename.endswith(suffix)]
19 |
20 |
21 | def get_cityscapes_labels():
22 | return np.array([
23 | # [ 0, 0, 0],
24 | [128, 64, 128],
25 | [244, 35, 232],
26 | [70, 70, 70],
27 | [102, 102, 156],
28 | [190, 153, 153],
29 | [153, 153, 153],
30 | [250, 170, 30],
31 | [220, 220, 0],
32 | [107, 142, 35],
33 | [152, 251, 152],
34 | [0, 130, 180],
35 | [220, 20, 60],
36 | [255, 0, 0],
37 | [0, 0, 142],
38 | [0, 0, 70],
39 | [0, 60, 100],
40 | [0, 80, 100],
41 | [0, 0, 230],
42 | [119, 11, 32]])
43 |
44 |
45 | def get_pascal_labels():
46 | """Load the mapping that associates pascal classes with label colors
47 | Returns:
48 | np.ndarray with dimensions (21, 3)
49 | """
50 | return np.asarray([[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
51 | [0, 0, 128], [128, 0, 128], [
52 | 0, 128, 128], [128, 128, 128],
53 | [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
54 | [64, 0, 128], [192, 0, 128], [
55 | 64, 128, 128], [192, 128, 128],
56 | [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
57 | [0, 64, 128]])
58 |
59 |
60 | def encode_segmap(mask):
61 | """Encode segmentation label images as pascal classes
62 | Args:
63 | mask (np.ndarray): raw segmentation label image of dimension
64 | (M, N, 3), in which the Pascal classes are encoded as colours.
65 | Returns:
66 | (np.ndarray): class map with dimensions (M,N), where the value at
67 | a given location is the integer denoting the class index.
68 | """
69 | mask = mask.astype(int)
70 | label_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int16)
71 | for ii, label in enumerate(get_pascal_labels()):
72 | label_mask[np.where(np.all(mask == label, axis=-1))[:2]] = ii
73 | label_mask = label_mask.astype(int)
74 | return label_mask
75 |
76 |
77 | def decode_seg_map_sequence(label_masks, dataset='pascal'):
78 | rgb_masks = []
79 | for label_mask in label_masks:
80 | rgb_mask = decode_segmap(label_mask, dataset)
81 | rgb_masks.append(rgb_mask)
82 | rgb_masks = torch.from_numpy(np.array(rgb_masks).transpose([0, 3, 1, 2]))
83 | return rgb_masks
84 |
85 |
86 | def decode_segmap(label_mask, dataset, plot=False):
87 | """Decode segmentation class labels into a color image
88 | Args:
89 | label_mask (np.ndarray): an (M,N) array of integer values denoting
90 | the class label at each spatial location.
91 | plot (bool, optional): whether to show the resulting color image
92 | in a figure.
93 | Returns:
94 | (np.ndarray, optional): the resulting decoded color image.
95 | """
96 | if dataset == 'pascal':
97 | n_classes = 21
98 | label_colours = get_pascal_labels()
99 | elif dataset == 'cityscapes':
100 | n_classes = 19
101 | label_colours = get_cityscapes_labels()
102 | else:
103 | raise NotImplementedError
104 |
105 | r = label_mask.copy()
106 | g = label_mask.copy()
107 | b = label_mask.copy()
108 | for ll in range(0, n_classes):
109 | r[label_mask == ll] = label_colours[ll, 0]
110 | g[label_mask == ll] = label_colours[ll, 1]
111 | b[label_mask == ll] = label_colours[ll, 2]
112 | rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3))
113 | rgb[:, :, 0] = r / 255.0
114 | rgb[:, :, 1] = g / 255.0
115 | rgb[:, :, 2] = b / 255.0
116 | if plot:
117 |         import matplotlib.pyplot as plt; plt.imshow(rgb)  # lazy import; the top-level import is commented out
118 | plt.show()
119 | else:
120 | return rgb
121 |
122 |
123 | def generate_param_report(logfile, param):
124 | log_file = open(logfile, 'w')
125 | # for key, val in param.items():
126 | # log_file.write(key + ':' + str(val) + '\n')
127 | log_file.write(str(param))
128 | log_file.close()
129 |
130 |
131 | def cross_entropy2d(logit, target, ignore_index=255, weight=None, size_average=True, batch_average=True):
132 | n, c, h, w = logit.size()
133 | # logit = logit.permute(0, 2, 3, 1)
134 | target = target.squeeze(1)
135 |     if weight is None:
136 |         criterion = nn.CrossEntropyLoss(
137 |             weight=weight, ignore_index=ignore_index, reduction='sum')  # size_average=False is deprecated
138 |     else:
139 |         criterion = nn.CrossEntropyLoss(weight=torch.from_numpy(np.array(
140 |             weight)).float().cuda(), ignore_index=ignore_index, reduction='sum')
141 | loss = criterion(logit, target.long())
142 |
143 | if size_average:
144 | loss /= (h * w)
145 |
146 | if batch_average:
147 | loss /= n
148 |
149 | return loss
150 |
151 |
152 | def lr_poly(base_lr, iter_, max_iter=100, power=0.9):
153 | return base_lr * ((1 - float(iter_) / max_iter) ** power)
154 |
155 |
156 | def get_iou(pred, gt, n_classes=21):
157 | total_iou = 0.0
158 | for i in range(len(pred)):
159 | pred_tmp = pred[i]
160 | gt_tmp = gt[i]
161 |
162 | intersect = [0] * n_classes
163 | union = [0] * n_classes
164 | for j in range(n_classes):
165 |             match = (pred_tmp == j).int() + (gt_tmp == j).int()  # int cast so the sum can reach 2 below
166 |
167 | it = torch.sum(match == 2).item()
168 | un = torch.sum(match > 0).item()
169 |
170 | intersect[j] += it
171 | union[j] += un
172 |
173 | iou = []
174 | for k in range(n_classes):
175 | if union[k] == 0:
176 | continue
177 | iou.append(intersect[k] / union[k])
178 |
179 | img_iou = (sum(iou) / len(iou))
180 | total_iou += img_iou
181 |
182 | return total_iou
183 |
184 |
185 | def get_dice(pred, gt):
186 | total_dice = 0.0
187 | pred = pred.long()
188 | gt = gt.long()
189 | for i in range(len(pred)):
190 | pred_tmp = pred[i]
191 | gt_tmp = gt[i]
192 | dice = 2.0*torch.sum(pred_tmp*gt_tmp).item()/(1.0 +
193 | torch.sum(pred_tmp**2)+torch.sum(gt_tmp**2)).item()
194 |         # print(dice)  # debug output silenced
195 | total_dice += dice
196 |
197 | return total_dice
198 |
199 |
200 | def get_mc_dice(pred, gt, num=2):
201 | # num is the total number of classes, include the background
202 | total_dice = np.zeros(num-1)
203 | pred = pred.long()
204 | gt = gt.long()
205 | for i in range(len(pred)):
206 | for j in range(1, num):
207 |             pred_tmp = (pred[i] == j).float()  # float cast keeps the products/squares below well-defined
208 |             gt_tmp = (gt[i] == j).float()
209 | dice = 2.0*torch.sum(pred_tmp*gt_tmp).item()/(1.0 +
210 | torch.sum(pred_tmp**2)+torch.sum(gt_tmp**2)).item()
211 | total_dice[j-1] += dice
212 | return total_dice
213 |
214 |
215 | def post_processing(prediction):
216 | prediction = nd.binary_fill_holes(prediction)
217 | label_cc, num_cc = measure.label(prediction, return_num=True)
218 | total_cc = np.sum(prediction)
219 | measure.regionprops(label_cc)
220 | for cc in range(1, num_cc+1):
221 | single_cc = (label_cc == cc)
222 | single_vol = np.sum(single_cc)
223 | if single_vol/total_cc < 0.2:
224 | prediction[single_cc] = 0
225 |
226 | return prediction
227 |
--------------------------------------------------------------------------------
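A toy demonstration of post_processing above (the import path is an assumption; the 16x16 mask is made up): connected components smaller than 20% of the total foreground are removed, so the single-pixel speck disappears while the large blob survives.

import numpy as np
from dataloaders.utils import post_processing  # assumes code/ on sys.path

pred = np.zeros((16, 16), dtype=bool)
pred[2:10, 2:10] = True  # large component: 64 px
pred[13, 13] = True      # tiny component: 1 px, and 1/65 < 0.2

cleaned = post_processing(pred)
print(int(pred.sum()), int(cleaned.sum()))  # 65 64

--------------------------------------------------------------------------------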
/code/networks/VoxResNet.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from __future__ import print_function, division
4 |
5 | import torch
6 | import torch.nn as nn
7 |
8 |
9 | class SEBlock(nn.Module):
10 | def __init__(self, in_channels, r):
11 | super(SEBlock, self).__init__()
12 |
13 | redu_chns = int(in_channels / r)
14 | self.se_layers = nn.Sequential(
15 | nn.AdaptiveAvgPool3d(1),
16 | nn.Conv3d(in_channels, redu_chns, kernel_size=1, padding=0),
17 | nn.ReLU(),
18 | nn.Conv3d(redu_chns, in_channels, kernel_size=1, padding=0),
19 | nn.ReLU())
20 |
21 | def forward(self, x):
22 | f = self.se_layers(x)
23 | return f * x + x
24 |
25 |
26 | class VoxRex(nn.Module):
27 | def __init__(self, in_channels):
28 | super(VoxRex, self).__init__()
29 | self.block = nn.Sequential(
30 | nn.InstanceNorm3d(in_channels),
31 | nn.ReLU(inplace=True),
32 | nn.Conv3d(in_channels, in_channels,
33 | kernel_size=3, padding=1, bias=False),
34 | nn.InstanceNorm3d(in_channels),
35 | nn.ReLU(inplace=True),
36 | nn.Conv3d(in_channels, in_channels,
37 | kernel_size=3, padding=1, bias=False)
38 | )
39 |
40 | def forward(self, x):
41 | return self.block(x)+x
42 |
43 |
44 | class ConvBlock(nn.Module):
45 | """two convolution layers with batch norm and leaky relu"""
46 |
47 | def __init__(self, in_channels, out_channels):
48 | super(ConvBlock, self).__init__()
49 | self.conv_conv = nn.Sequential(
50 | nn.InstanceNorm3d(in_channels),
51 | nn.ReLU(inplace=True),
52 | nn.Conv3d(in_channels, out_channels,
53 | kernel_size=3, padding=1, bias=False),
54 | nn.InstanceNorm3d(out_channels),
55 | nn.ReLU(inplace=True),
56 | nn.Conv3d(out_channels, out_channels,
57 | kernel_size=3, padding=1, bias=False)
58 | )
59 |
60 | def forward(self, x):
61 | return self.conv_conv(x)
62 |
63 |
64 | class UpBlock(nn.Module):
65 |     """Upsampling followed by ConvBlock"""
66 |
67 | def __init__(self, in_channels, out_channels):
68 | super(UpBlock, self).__init__()
69 | self.up = nn.Upsample(
70 | scale_factor=2, mode='trilinear', align_corners=True)
71 | self.conv = ConvBlock(in_channels, out_channels)
72 |
73 | def forward(self, x1, x2):
74 | x1 = self.up(x1)
75 | x = torch.cat([x2, x1], dim=1)
76 | return self.conv(x)
77 |
78 |
79 | class VoxResNet(nn.Module):
80 | def __init__(self, in_chns=1, feature_chns=64, class_num=2):
81 | super(VoxResNet, self).__init__()
82 | self.in_chns = in_chns
83 | self.ft_chns = feature_chns
84 | self.n_class = class_num
85 |
86 | self.conv1 = nn.Conv3d(in_chns, feature_chns, kernel_size=3, padding=1)
87 | self.res1 = VoxRex(feature_chns)
88 | self.res2 = VoxRex(feature_chns)
89 | self.res3 = VoxRex(feature_chns)
90 | self.res4 = VoxRex(feature_chns)
91 | self.res5 = VoxRex(feature_chns)
92 | self.res6 = VoxRex(feature_chns)
93 |
94 | self.up1 = UpBlock(feature_chns * 2, feature_chns)
95 | self.up2 = UpBlock(feature_chns * 2, feature_chns)
96 |
97 | self.out = nn.Conv3d(feature_chns, self.n_class, kernel_size=1)
98 |
99 | self.maxpool = nn.MaxPool3d(2)
100 | self.upsample = nn.Upsample(
101 | scale_factor=2, mode='trilinear', align_corners=True)
102 |
103 | def forward(self, x):
104 | x = self.maxpool(self.conv1(x))
105 | x1 = self.res1(x)
106 | x2 = self.res2(x1)
107 | x2_pool = self.maxpool(x2)
108 | x3 = self.res3(x2_pool)
109 | x4 = self.maxpool(self.res4(x3))
110 | x5 = self.res5(x4)
111 | x6 = self.res6(x5)
112 | up1 = self.up1(x6, x2_pool)
113 | up2 = self.up2(up1, x)
114 | up = self.upsample(up2)
115 | out = self.out(up)
116 | return out
117 |
--------------------------------------------------------------------------------
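A quick shape sanity check for VoxResNet (hypothetical sizes; any spatial extent divisible by 8 should work, since the network downsamples three times and upsamples back to full resolution):

import torch
from networks.VoxResNet import VoxResNet  # assumes code/ on sys.path

net = VoxResNet(in_chns=1, feature_chns=32, class_num=2)
x = torch.randn(1, 1, 32, 64, 64)  # (batch, channel, D, H, W)
with torch.no_grad():
    y = net(x)
print(y.shape)  # torch.Size([1, 2, 32, 64, 64]): full-resolution logits

--------------------------------------------------------------------------------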
/code/networks/__pycache__/attention.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/code/networks/__pycache__/attention.cpython-310.pyc
--------------------------------------------------------------------------------
/code/networks/__pycache__/attention.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/code/networks/__pycache__/attention.cpython-38.pyc
--------------------------------------------------------------------------------
/code/networks/__pycache__/config.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/code/networks/__pycache__/config.cpython-310.pyc
--------------------------------------------------------------------------------
/code/networks/__pycache__/config.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/code/networks/__pycache__/config.cpython-38.pyc
--------------------------------------------------------------------------------
/code/networks/__pycache__/discriminator.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/code/networks/__pycache__/discriminator.cpython-38.pyc
--------------------------------------------------------------------------------
/code/networks/__pycache__/efficient_encoder.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/code/networks/__pycache__/efficient_encoder.cpython-310.pyc
--------------------------------------------------------------------------------
/code/networks/__pycache__/efficient_encoder.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/code/networks/__pycache__/efficient_encoder.cpython-38.pyc
--------------------------------------------------------------------------------
/code/networks/__pycache__/efficientunet.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/code/networks/__pycache__/efficientunet.cpython-310.pyc
--------------------------------------------------------------------------------
/code/networks/__pycache__/efficientunet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/code/networks/__pycache__/efficientunet.cpython-38.pyc
--------------------------------------------------------------------------------
/code/networks/__pycache__/enet.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/code/networks/__pycache__/enet.cpython-310.pyc
--------------------------------------------------------------------------------
/code/networks/__pycache__/enet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/code/networks/__pycache__/enet.cpython-38.pyc
--------------------------------------------------------------------------------
/code/networks/__pycache__/mamba_sys.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/code/networks/__pycache__/mamba_sys.cpython-310.pyc
--------------------------------------------------------------------------------
/code/networks/__pycache__/net_factory.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/code/networks/__pycache__/net_factory.cpython-310.pyc
--------------------------------------------------------------------------------
/code/networks/__pycache__/net_factory.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/code/networks/__pycache__/net_factory.cpython-38.pyc
--------------------------------------------------------------------------------
/code/networks/__pycache__/neural_network.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/code/networks/__pycache__/neural_network.cpython-310.pyc
--------------------------------------------------------------------------------
/code/networks/__pycache__/neural_network.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/code/networks/__pycache__/neural_network.cpython-38.pyc
--------------------------------------------------------------------------------
/code/networks/__pycache__/nnunet.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/code/networks/__pycache__/nnunet.cpython-310.pyc
--------------------------------------------------------------------------------
/code/networks/__pycache__/nnunet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/code/networks/__pycache__/nnunet.cpython-38.pyc
--------------------------------------------------------------------------------
/code/networks/__pycache__/pnet.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/code/networks/__pycache__/pnet.cpython-310.pyc
--------------------------------------------------------------------------------
/code/networks/__pycache__/pnet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/code/networks/__pycache__/pnet.cpython-38.pyc
--------------------------------------------------------------------------------
/code/networks/__pycache__/swin_transformer_unet_skip_expand_decoder_sys.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/code/networks/__pycache__/swin_transformer_unet_skip_expand_decoder_sys.cpython-310.pyc
--------------------------------------------------------------------------------
/code/networks/__pycache__/swin_transformer_unet_skip_expand_decoder_sys.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/code/networks/__pycache__/swin_transformer_unet_skip_expand_decoder_sys.cpython-38.pyc
--------------------------------------------------------------------------------
/code/networks/__pycache__/unet.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/code/networks/__pycache__/unet.cpython-310.pyc
--------------------------------------------------------------------------------
/code/networks/__pycache__/unet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/code/networks/__pycache__/unet.cpython-38.pyc
--------------------------------------------------------------------------------
/code/networks/__pycache__/vision_mamba.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/code/networks/__pycache__/vision_mamba.cpython-310.pyc
--------------------------------------------------------------------------------
/code/networks/__pycache__/vision_transformer.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/code/networks/__pycache__/vision_transformer.cpython-310.pyc
--------------------------------------------------------------------------------
/code/networks/__pycache__/vision_transformer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/code/networks/__pycache__/vision_transformer.cpython-38.pyc
--------------------------------------------------------------------------------
/code/networks/attention.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | try:
4 | from inplace_abn import InPlaceABN
5 | except ImportError:
6 | InPlaceABN = None
7 |
8 |
9 | class Conv2dReLU(nn.Sequential):
10 | def __init__(
11 | self,
12 | in_channels,
13 | out_channels,
14 | kernel_size,
15 | padding=0,
16 | stride=1,
17 | use_batchnorm=True,
18 | ):
19 |
20 | if use_batchnorm == "inplace" and InPlaceABN is None:
21 | raise RuntimeError(
22 | "In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. "
23 | + "To install see: https://github.com/mapillary/inplace_abn"
24 | )
25 |
26 | super().__init__()
27 |
28 | conv = nn.Conv2d(
29 | in_channels,
30 | out_channels,
31 | kernel_size,
32 | stride=stride,
33 | padding=padding,
34 | bias=not (use_batchnorm),
35 | )
36 | relu = nn.ReLU(inplace=True)
37 |
38 | if use_batchnorm == "inplace":
39 | bn = InPlaceABN(out_channels, activation="leaky_relu", activation_param=0.0)
40 | relu = nn.Identity()
41 |
42 | elif use_batchnorm and use_batchnorm != "inplace":
43 | bn = nn.BatchNorm2d(out_channels)
44 |
45 | else:
46 | bn = nn.Identity()
47 |
48 | super(Conv2dReLU, self).__init__(conv, bn, relu)
49 |
50 |
51 | class SCSEModule(nn.Module):
52 | def __init__(self, in_channels, reduction=16):
53 | super().__init__()
54 | self.cSE = nn.Sequential(
55 | nn.AdaptiveAvgPool2d(1),
56 | nn.Conv2d(in_channels, in_channels // reduction, 1),
57 | nn.ReLU(inplace=True),
58 | nn.Conv2d(in_channels // reduction, in_channels, 1),
59 | nn.Sigmoid(),
60 | )
61 | self.sSE = nn.Sequential(nn.Conv2d(in_channels, 1, 1), nn.Sigmoid())
62 |
63 | def forward(self, x):
64 | return x * self.cSE(x) + x * self.sSE(x)
65 |
66 |
67 | class Activation(nn.Module):
68 |
69 | def __init__(self, name, **params):
70 |
71 | super().__init__()
72 |
73 | if name is None or name == 'identity':
74 | self.activation = nn.Identity(**params)
75 | elif name == 'sigmoid':
76 | self.activation = nn.Sigmoid()
77 | elif name == 'softmax2d':
78 | self.activation = nn.Softmax(dim=1, **params)
79 | elif name == 'softmax':
80 | self.activation = nn.Softmax(**params)
81 | elif name == 'logsoftmax':
82 | self.activation = nn.LogSoftmax(**params)
83 | elif callable(name):
84 | self.activation = name(**params)
85 | else:
86 | raise ValueError('Activation should be callable/sigmoid/softmax/logsoftmax/None; got {}'.format(name))
87 |
88 | def forward(self, x):
89 | return self.activation(x)
90 |
91 |
92 | class Attention(nn.Module):
93 |
94 | def __init__(self, name, **params):
95 | super().__init__()
96 |
97 | if name is None:
98 | self.attention = nn.Identity(**params)
99 | elif name == 'scse':
100 | self.attention = SCSEModule(**params)
101 | else:
102 | raise ValueError("Attention {} is not implemented".format(name))
103 |
104 | def forward(self, x):
105 | return self.attention(x)
106 |
107 |
108 | class Flatten(nn.Module):
109 | def forward(self, x):
110 | return x.view(x.shape[0], -1)
111 |
--------------------------------------------------------------------------------
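A minimal sketch exercising SCSEModule above (import path assumed): the channel gate (cSE) and the spatial gate (sSE) re-weight the input independently and their outputs are summed, so the output shape matches the input.

import torch
from networks.attention import SCSEModule  # assumes code/ on sys.path

m = SCSEModule(in_channels=32, reduction=16)
x = torch.randn(2, 32, 28, 28)
with torch.no_grad():
    print(m(x).shape)  # torch.Size([2, 32, 28, 28])

--------------------------------------------------------------------------------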
/code/networks/attention_unet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | from networks.utils import UnetConv3, UnetUp3_CT, UnetGridGatingSignal3, UnetDsv3
4 | import torch.nn.functional as F
5 | from networks.networks_other import init_weights
6 | from networks.grid_attention_layer import GridAttentionBlock3D
7 |
8 |
9 | class Attention_UNet(nn.Module):
10 |
11 | def __init__(self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3,
12 | nonlocal_mode='concatenation', attention_dsample=(2,2,2), is_batchnorm=True):
13 | super(Attention_UNet, self).__init__()
14 | self.is_deconv = is_deconv
15 | self.in_channels = in_channels
16 | self.is_batchnorm = is_batchnorm
17 | self.feature_scale = feature_scale
18 |
19 | filters = [64, 128, 256, 512, 1024]
20 | filters = [int(x / self.feature_scale) for x in filters]
21 |
22 | # downsampling
23 | self.conv1 = UnetConv3(self.in_channels, filters[0], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1))
24 | self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 2))
25 |
26 | self.conv2 = UnetConv3(filters[0], filters[1], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1))
27 | self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 2))
28 |
29 | self.conv3 = UnetConv3(filters[1], filters[2], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1))
30 | self.maxpool3 = nn.MaxPool3d(kernel_size=(2, 2, 2))
31 |
32 | self.conv4 = UnetConv3(filters[2], filters[3], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1))
33 | self.maxpool4 = nn.MaxPool3d(kernel_size=(2, 2, 2))
34 |
35 | self.center = UnetConv3(filters[3], filters[4], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1))
36 | self.gating = UnetGridGatingSignal3(filters[4], filters[4], kernel_size=(1, 1, 1), is_batchnorm=self.is_batchnorm)
37 |
38 | # attention blocks
39 | self.attentionblock2 = MultiAttentionBlock(in_size=filters[1], gate_size=filters[2], inter_size=filters[1],
40 | nonlocal_mode=nonlocal_mode, sub_sample_factor= attention_dsample)
41 | self.attentionblock3 = MultiAttentionBlock(in_size=filters[2], gate_size=filters[3], inter_size=filters[2],
42 | nonlocal_mode=nonlocal_mode, sub_sample_factor= attention_dsample)
43 | self.attentionblock4 = MultiAttentionBlock(in_size=filters[3], gate_size=filters[4], inter_size=filters[3],
44 | nonlocal_mode=nonlocal_mode, sub_sample_factor= attention_dsample)
45 |
46 | # upsampling
47 | self.up_concat4 = UnetUp3_CT(filters[4], filters[3], is_batchnorm)
48 | self.up_concat3 = UnetUp3_CT(filters[3], filters[2], is_batchnorm)
49 | self.up_concat2 = UnetUp3_CT(filters[2], filters[1], is_batchnorm)
50 | self.up_concat1 = UnetUp3_CT(filters[1], filters[0], is_batchnorm)
51 |
52 | # deep supervision
53 | self.dsv4 = UnetDsv3(in_size=filters[3], out_size=n_classes, scale_factor=8)
54 | self.dsv3 = UnetDsv3(in_size=filters[2], out_size=n_classes, scale_factor=4)
55 | self.dsv2 = UnetDsv3(in_size=filters[1], out_size=n_classes, scale_factor=2)
56 | self.dsv1 = nn.Conv3d(in_channels=filters[0], out_channels=n_classes, kernel_size=1)
57 |
58 | # final conv (without any concat)
59 | self.final = nn.Conv3d(n_classes*4, n_classes, 1)
60 |
61 | # initialise weights
62 | for m in self.modules():
63 | if isinstance(m, nn.Conv3d):
64 | init_weights(m, init_type='kaiming')
65 | elif isinstance(m, nn.BatchNorm3d):
66 | init_weights(m, init_type='kaiming')
67 |
68 | def forward(self, inputs):
69 | # Feature Extraction
70 | conv1 = self.conv1(inputs)
71 | maxpool1 = self.maxpool1(conv1)
72 |
73 | conv2 = self.conv2(maxpool1)
74 | maxpool2 = self.maxpool2(conv2)
75 |
76 | conv3 = self.conv3(maxpool2)
77 | maxpool3 = self.maxpool3(conv3)
78 |
79 | conv4 = self.conv4(maxpool3)
80 | maxpool4 = self.maxpool4(conv4)
81 |
82 | # Gating Signal Generation
83 | center = self.center(maxpool4)
84 | gating = self.gating(center)
85 |
86 | # Attention Mechanism
87 | # Upscaling Part (Decoder)
88 | g_conv4, att4 = self.attentionblock4(conv4, gating)
89 | up4 = self.up_concat4(g_conv4, center)
90 | g_conv3, att3 = self.attentionblock3(conv3, up4)
91 | up3 = self.up_concat3(g_conv3, up4)
92 | g_conv2, att2 = self.attentionblock2(conv2, up3)
93 | up2 = self.up_concat2(g_conv2, up3)
94 | up1 = self.up_concat1(conv1, up2)
95 |
96 | # Deep Supervision
97 | dsv4 = self.dsv4(up4)
98 | dsv3 = self.dsv3(up3)
99 | dsv2 = self.dsv2(up2)
100 | dsv1 = self.dsv1(up1)
101 | final = self.final(torch.cat([dsv1,dsv2,dsv3,dsv4], dim=1))
102 |
103 | return final
104 |
105 |
106 | @staticmethod
107 | def apply_argmax_softmax(pred):
108 | log_p = F.softmax(pred, dim=1)
109 |
110 | return log_p
111 |
112 |
113 | class MultiAttentionBlock(nn.Module):
114 | def __init__(self, in_size, gate_size, inter_size, nonlocal_mode, sub_sample_factor):
115 | super(MultiAttentionBlock, self).__init__()
116 | self.gate_block_1 = GridAttentionBlock3D(in_channels=in_size, gating_channels=gate_size,
117 | inter_channels=inter_size, mode=nonlocal_mode,
118 | sub_sample_factor= sub_sample_factor)
119 | self.gate_block_2 = GridAttentionBlock3D(in_channels=in_size, gating_channels=gate_size,
120 | inter_channels=inter_size, mode=nonlocal_mode,
121 | sub_sample_factor=sub_sample_factor)
122 | self.combine_gates = nn.Sequential(nn.Conv3d(in_size*2, in_size, kernel_size=1, stride=1, padding=0),
123 | nn.BatchNorm3d(in_size),
124 | nn.ReLU(inplace=True)
125 | )
126 |
127 | # initialise the blocks
128 | for m in self.children():
129 | if m.__class__.__name__.find('GridAttentionBlock3D') != -1: continue
130 | init_weights(m, init_type='kaiming')
131 |
132 | def forward(self, input, gating_signal):
133 | gate_1, attention_1 = self.gate_block_1(input, gating_signal)
134 | gate_2, attention_2 = self.gate_block_2(input, gating_signal)
135 |
136 | return self.combine_gates(torch.cat([gate_1, gate_2], 1)), torch.cat([attention_1, attention_2], 1)
--------------------------------------------------------------------------------
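An untested smoke-test sketch for Attention_UNet (sizes and import path are assumptions; spatial dimensions should be divisible by 16 to survive the four pooling stages):

import torch
from networks.attention_unet import Attention_UNet  # assumes code/ on sys.path

net = Attention_UNet(feature_scale=4, n_classes=4, in_channels=1)
x = torch.randn(1, 1, 16, 64, 64)  # (batch, channel, D, H, W)
with torch.no_grad():
    y = net(x)
print(y.shape)  # expected: torch.Size([1, 4, 16, 64, 64])

--------------------------------------------------------------------------------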
/code/networks/config.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Swin Transformer
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # --------------------------------------------------------
7 |
8 | import os
9 | import yaml
10 | from yacs.config import CfgNode as CN
11 |
12 | _C = CN()
13 |
14 | # Base config files
15 | _C.BASE = ['']
16 |
17 | # -----------------------------------------------------------------------------
18 | # Data settings
19 | # -----------------------------------------------------------------------------
20 | _C.DATA = CN()
21 | # Batch size for a single GPU, could be overwritten by command line argument
22 | _C.DATA.BATCH_SIZE = 128
23 | # Path to dataset, could be overwritten by command line argument
24 | _C.DATA.DATA_PATH = ''
25 | # Dataset name
26 | _C.DATA.DATASET = 'imagenet'
27 | # Input image size
28 | _C.DATA.IMG_SIZE = 224
29 | # Interpolation to resize image (random, bilinear, bicubic)
30 | _C.DATA.INTERPOLATION = 'bicubic'
31 | # Use zipped dataset instead of folder dataset
32 | # could be overwritten by command line argument
33 | _C.DATA.ZIP_MODE = False
34 | # Cache Data in Memory, could be overwritten by command line argument
35 | _C.DATA.CACHE_MODE = 'part'
36 | # Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.
37 | _C.DATA.PIN_MEMORY = True
38 | # Number of data loading threads
39 | _C.DATA.NUM_WORKERS = 8
40 |
41 | # -----------------------------------------------------------------------------
42 | # Model settings
43 | # -----------------------------------------------------------------------------
44 | _C.MODEL = CN()
45 | # Model type
46 | _C.MODEL.TYPE = 'swin'
47 | # Model name
48 | _C.MODEL.NAME = 'swin_tiny_patch4_window7_224'
49 | # Checkpoint to resume, could be overwritten by command line argument
50 | _C.MODEL.PRETRAIN_CKPT = './pretrained_ckpt/swin_tiny_patch4_window7_224.pth'
51 | _C.MODEL.RESUME = ''
52 | # Number of classes, overwritten in data preparation
53 | _C.MODEL.NUM_CLASSES = 1000
54 | # Dropout rate
55 | _C.MODEL.DROP_RATE = 0.0
56 | # Drop path rate
57 | _C.MODEL.DROP_PATH_RATE = 0.1
58 | # Label Smoothing
59 | _C.MODEL.LABEL_SMOOTHING = 0.1
60 |
61 | # Swin Transformer parameters
62 | _C.MODEL.SWIN = CN()
63 | _C.MODEL.SWIN.PATCH_SIZE = 4
64 | _C.MODEL.SWIN.IN_CHANS = 3
65 | _C.MODEL.SWIN.EMBED_DIM = 96
66 | _C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
67 | _C.MODEL.SWIN.DECODER_DEPTHS = [2, 2, 6, 2]
68 | _C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
69 | _C.MODEL.SWIN.WINDOW_SIZE = 7
70 | _C.MODEL.SWIN.MLP_RATIO = 4.
71 | _C.MODEL.SWIN.QKV_BIAS = True
72 | _C.MODEL.SWIN.QK_SCALE = None
73 | _C.MODEL.SWIN.APE = False
74 | _C.MODEL.SWIN.PATCH_NORM = True
75 | _C.MODEL.SWIN.FINAL_UPSAMPLE = "expand_first"
76 |
77 | # -----------------------------------------------------------------------------
78 | # Training settings
79 | # -----------------------------------------------------------------------------
80 | _C.TRAIN = CN()
81 | _C.TRAIN.START_EPOCH = 0
82 | _C.TRAIN.EPOCHS = 300
83 | _C.TRAIN.WARMUP_EPOCHS = 20
84 | _C.TRAIN.WEIGHT_DECAY = 0.05
85 | _C.TRAIN.BASE_LR = 5e-4
86 | _C.TRAIN.WARMUP_LR = 5e-7
87 | _C.TRAIN.MIN_LR = 5e-6
88 | # Clip gradient norm
89 | _C.TRAIN.CLIP_GRAD = 5.0
90 | # Auto resume from latest checkpoint
91 | _C.TRAIN.AUTO_RESUME = True
92 | # Gradient accumulation steps
93 | # could be overwritten by command line argument
94 | _C.TRAIN.ACCUMULATION_STEPS = 0
95 | # Whether to use gradient checkpointing to save memory
96 | # could be overwritten by command line argument
97 | _C.TRAIN.USE_CHECKPOINT = False
98 |
99 | # LR scheduler
100 | _C.TRAIN.LR_SCHEDULER = CN()
101 | _C.TRAIN.LR_SCHEDULER.NAME = 'cosine'
102 | # Epoch interval to decay LR, used in StepLRScheduler
103 | _C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30
104 | # LR decay rate, used in StepLRScheduler
105 | _C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1
106 |
107 | # Optimizer
108 | _C.TRAIN.OPTIMIZER = CN()
109 | _C.TRAIN.OPTIMIZER.NAME = 'adamw'
110 | # Optimizer Epsilon
111 | _C.TRAIN.OPTIMIZER.EPS = 1e-8
112 | # Optimizer Betas
113 | _C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999)
114 | # SGD momentum
115 | _C.TRAIN.OPTIMIZER.MOMENTUM = 0.9
116 |
117 | # -----------------------------------------------------------------------------
118 | # Augmentation settings
119 | # -----------------------------------------------------------------------------
120 | _C.AUG = CN()
121 | # Color jitter factor
122 | _C.AUG.COLOR_JITTER = 0.4
123 | # Use AutoAugment policy. "v0" or "original"
124 | _C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1'
125 | # Random erase prob
126 | _C.AUG.REPROB = 0.25
127 | # Random erase mode
128 | _C.AUG.REMODE = 'pixel'
129 | # Random erase count
130 | _C.AUG.RECOUNT = 1
131 | # Mixup alpha, mixup enabled if > 0
132 | _C.AUG.MIXUP = 0.8
133 | # Cutmix alpha, cutmix enabled if > 0
134 | _C.AUG.CUTMIX = 1.0
135 | # Cutmix min/max ratio, overrides alpha and enables cutmix if set
136 | _C.AUG.CUTMIX_MINMAX = None
137 | # Probability of performing mixup or cutmix when either/both is enabled
138 | _C.AUG.MIXUP_PROB = 1.0
139 | # Probability of switching to cutmix when both mixup and cutmix enabled
140 | _C.AUG.MIXUP_SWITCH_PROB = 0.5
141 | # How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
142 | _C.AUG.MIXUP_MODE = 'batch'
143 |
144 | # -----------------------------------------------------------------------------
145 | # Testing settings
146 | # -----------------------------------------------------------------------------
147 | _C.TEST = CN()
148 | # Whether to use center crop when testing
149 | _C.TEST.CROP = True
150 |
151 | # -----------------------------------------------------------------------------
152 | # Misc
153 | # -----------------------------------------------------------------------------
154 | # Mixed precision opt level, if O0, no amp is used ('O0', 'O1', 'O2')
155 | # overwritten by command line argument
156 | _C.AMP_OPT_LEVEL = ''
157 | # Path to output folder, overwritten by command line argument
158 | _C.OUTPUT = ''
159 | # Tag of experiment, overwritten by command line argument
160 | _C.TAG = 'default'
161 | # Frequency to save checkpoint
162 | _C.SAVE_FREQ = 1
163 | # Frequency to log info
164 | _C.PRINT_FREQ = 10
165 | # Fixed random seed
166 | _C.SEED = 0
167 | # Perform evaluation only, overwritten by command line argument
168 | _C.EVAL_MODE = False
169 | # Test throughput only, overwritten by command line argument
170 | _C.THROUGHPUT_MODE = False
171 | # local rank for DistributedDataParallel, given by command line argument
172 | _C.LOCAL_RANK = 0
173 |
174 |
175 | def _update_config_from_file(config, cfg_file):
176 | config.defrost()
177 | with open(cfg_file, 'r') as f:
178 | yaml_cfg = yaml.load(f, Loader=yaml.FullLoader)
179 |
180 | for cfg in yaml_cfg.setdefault('BASE', ['']):
181 | if cfg:
182 | _update_config_from_file(
183 | config, os.path.join(os.path.dirname(cfg_file), cfg)
184 | )
185 | print('=> merge config from {}'.format(cfg_file))
186 | config.merge_from_file(cfg_file)
187 | config.freeze()
188 |
189 |
190 | def update_config(config, args):
191 | _update_config_from_file(config, args.cfg)
192 |
193 | config.defrost()
194 | if args.opts:
195 | config.merge_from_list(args.opts)
196 |
197 | # merge from specific arguments
198 | if args.batch_size:
199 | config.DATA.BATCH_SIZE = args.batch_size
200 | if args.zip:
201 | config.DATA.ZIP_MODE = True
202 | if args.cache_mode:
203 | config.DATA.CACHE_MODE = args.cache_mode
204 | if args.resume:
205 | config.MODEL.RESUME = args.resume
206 | if args.accumulation_steps:
207 | config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps
208 | if args.use_checkpoint:
209 | config.TRAIN.USE_CHECKPOINT = True
210 | if args.amp_opt_level:
211 | config.AMP_OPT_LEVEL = args.amp_opt_level
212 | if args.tag:
213 | config.TAG = args.tag
214 | if args.eval:
215 | config.EVAL_MODE = True
216 | if args.throughput:
217 | config.THROUGHPUT_MODE = True
218 |
219 | config.freeze()
220 |
221 |
222 | def get_config(args):
223 | """Get a yacs CfgNode object with default values."""
224 | # Return a clone so that the defaults will not be altered
225 | # This is for the "local variable" use pattern
226 | config = _C.clone()
227 | update_config(config, args)
228 |
229 | return config
230 |
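A minimal usage sketch for the config machinery above (hedged: assumes yacs/PyYAML are installed and the YAML path exists; the attribute names mirror the argparse flags read by update_config):

    import argparse
    # Build a Namespace carrying every attribute update_config() reads;
    # falsy values leave the corresponding defaults untouched.
    args = argparse.Namespace(
        cfg='../code/configs/swin_tiny_patch4_window7_224_lite.yaml',
        opts=None, batch_size=None, zip=False, cache_mode=None,
        resume=None, accumulation_steps=None, use_checkpoint=False,
        amp_opt_level=None, tag=None, eval=False, throughput=False)
    config = get_config(args)
    print(config.MODEL.SWIN.EMBED_DIM)  # 96 unless the YAML overrides it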
--------------------------------------------------------------------------------
/code/networks/discriminator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class FC3DDiscriminator(nn.Module):
7 |
8 | def __init__(self, num_classes, ndf=64, n_channel=1):
9 | super(FC3DDiscriminator, self).__init__()
10 | # downsample 16
11 | self.conv0 = nn.Conv3d(
12 | num_classes, ndf, kernel_size=4, stride=2, padding=1)
13 | self.conv1 = nn.Conv3d(
14 | n_channel, ndf, kernel_size=4, stride=2, padding=1)
15 |
16 | self.conv2 = nn.Conv3d(ndf, ndf*2, kernel_size=4, stride=2, padding=1)
17 | self.conv3 = nn.Conv3d(
18 | ndf*2, ndf*4, kernel_size=4, stride=2, padding=1)
19 | self.conv4 = nn.Conv3d(
20 | ndf*4, ndf*8, kernel_size=4, stride=2, padding=1)
21 | self.avgpool = nn.AvgPool3d((6, 6, 6)) # (D/16, W/16, H/16)
22 | self.classifier = nn.Linear(ndf*8, 2)
23 |
24 | self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
25 | self.dropout = nn.Dropout3d(0.5)
26 |         self.Softmax = nn.Softmax(dim=1)  # unused in forward
27 |
28 | def forward(self, map, image):
29 | batch_size = map.shape[0]
30 | map_feature = self.conv0(map)
31 | image_feature = self.conv1(image)
32 | x = torch.add(map_feature, image_feature)
33 | x = self.leaky_relu(x)
34 | x = self.dropout(x)
35 |
36 | x = self.conv2(x)
37 | x = self.leaky_relu(x)
38 | x = self.dropout(x)
39 |
40 | x = self.conv3(x)
41 | x = self.leaky_relu(x)
42 | x = self.dropout(x)
43 |
44 | x = self.conv4(x)
45 | x = self.leaky_relu(x)
46 |
47 | x = self.avgpool(x)
48 |
49 | x = x.view(batch_size, -1)
50 |
51 | x = self.classifier(x)
52 | x = x.reshape((batch_size, 2))
53 | # x = self.Softmax(x)
54 |
55 | return x
56 |
57 |
58 | class FCDiscriminator(nn.Module):
59 |
60 | def __init__(self, num_classes, ndf=64, n_channel=1):
61 | super(FCDiscriminator, self).__init__()
62 | self.conv0 = nn.Conv2d(
63 | num_classes, ndf, kernel_size=4, stride=2, padding=1)
64 | self.conv1 = nn.Conv2d(
65 | n_channel, ndf, kernel_size=4, stride=2, padding=1)
66 | self.conv2 = nn.Conv2d(ndf, ndf*2, kernel_size=4, stride=2, padding=1)
67 | self.conv3 = nn.Conv2d(
68 | ndf*2, ndf*4, kernel_size=4, stride=2, padding=1)
69 | self.conv4 = nn.Conv2d(
70 | ndf*4, ndf*8, kernel_size=4, stride=2, padding=1)
71 | self.classifier = nn.Linear(ndf*32, 2)
72 | self.avgpool = nn.AvgPool2d((7, 7))
73 |
74 | self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
75 | self.dropout = nn.Dropout2d(0.5)
76 | # self.up_sample = nn.Upsample(scale_factor=32, mode='bilinear')
77 | # self.sigmoid = nn.Sigmoid()
78 |
79 | def forward(self, map, feature):
80 | map_feature = self.conv0(map)
81 | image_feature = self.conv1(feature)
82 |
83 | x = torch.add(map_feature, image_feature)
84 |
85 | x = self.conv2(x)
86 | x = self.leaky_relu(x)
87 | x = self.dropout(x)
88 |
89 | x = self.conv3(x)
90 | x = self.leaky_relu(x)
91 | x = self.dropout(x)
92 |
93 | x = self.conv4(x)
94 | x = self.leaky_relu(x)
95 | x = self.avgpool(x)
96 | x = x.view(x.size(0), -1)
97 | x = self.classifier(x)
98 | # x = self.up_sample(x)
99 | # x = self.sigmoid(x)
100 |
101 | return x
102 |
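A quick shape check for FCDiscriminator (a sketch assuming 224x224 inputs; the ndf*32 classifier width is derived from that resolution, so other input sizes will not match the Linear layer):

    import torch
    disc = FCDiscriminator(num_classes=4)          # ndf=64 by default
    seg_map = torch.randn(2, 4, 224, 224)          # e.g. softmaxed 4-class prediction
    image = torch.randn(2, 1, 224, 224)
    out = disc(seg_map, image)                     # 224 -> 112 -> 56 -> 28 -> 14 -> avgpool(7) -> 2x2
    print(out.shape)                               # torch.Size([2, 2]) real/fake logits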
--------------------------------------------------------------------------------
/code/networks/efficientunet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from networks.attention import *
6 | from networks.efficient_encoder import get_encoder
7 |
8 |
9 | def initialize_decoder(module):
10 | for m in module.modules():
11 |
12 | if isinstance(m, nn.Conv2d):
13 | nn.init.kaiming_uniform_(m.weight, mode="fan_in", nonlinearity="relu")
14 | if m.bias is not None:
15 | nn.init.constant_(m.bias, 0)
16 |
17 | elif isinstance(m, nn.BatchNorm2d):
18 | nn.init.constant_(m.weight, 1)
19 | nn.init.constant_(m.bias, 0)
20 |
21 | elif isinstance(m, nn.Linear):
22 | nn.init.xavier_uniform_(m.weight)
23 | if m.bias is not None:
24 | nn.init.constant_(m.bias, 0)
25 |
26 |
27 | class DecoderBlock(nn.Module):
28 | def __init__(
29 | self,
30 | in_channels,
31 | skip_channels,
32 | out_channels,
33 | use_batchnorm=True,
34 | attention_type=None,
35 | ):
36 | super().__init__()
37 | self.conv1 = Conv2dReLU(
38 | in_channels + skip_channels,
39 | out_channels,
40 | kernel_size=3,
41 | padding=1,
42 | use_batchnorm=use_batchnorm,
43 | )
44 | self.attention1 = Attention(attention_type, in_channels=in_channels + skip_channels)
45 | self.conv2 = Conv2dReLU(
46 | out_channels,
47 | out_channels,
48 | kernel_size=3,
49 | padding=1,
50 | use_batchnorm=use_batchnorm,
51 | )
52 | self.attention2 = Attention(attention_type, in_channels=out_channels)
53 |
54 | def forward(self, x, skip=None):
55 | x = F.interpolate(x, scale_factor=2, mode="nearest")
56 | if skip is not None:
57 | x = torch.cat([x, skip], dim=1)
58 | x = self.attention1(x)
59 | x = self.conv1(x)
60 | x = self.conv2(x)
61 | x = self.attention2(x)
62 | return x
63 |
64 |
65 | class CenterBlock(nn.Sequential):
66 | def __init__(self, in_channels, out_channels, use_batchnorm=True):
67 | conv1 = Conv2dReLU(
68 | in_channels,
69 | out_channels,
70 | kernel_size=3,
71 | padding=1,
72 | use_batchnorm=use_batchnorm,
73 | )
74 | conv2 = Conv2dReLU(
75 | out_channels,
76 | out_channels,
77 | kernel_size=3,
78 | padding=1,
79 | use_batchnorm=use_batchnorm,
80 | )
81 | super().__init__(conv1, conv2)
82 |
83 |
84 | class UnetDecoder(nn.Module):
85 | def __init__(
86 | self,
87 | encoder_channels,
88 | decoder_channels,
89 | n_blocks=5,
90 | use_batchnorm=True,
91 | attention_type=None,
92 | center=False,
93 | ):
94 | super().__init__()
95 |
96 | if n_blocks != len(decoder_channels):
97 | raise ValueError(
98 | "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format(
99 | n_blocks, len(decoder_channels)
100 | )
101 | )
102 |
103 | encoder_channels = encoder_channels[1:] # remove first skip with same spatial resolution
104 | encoder_channels = encoder_channels[::-1] # reverse channels to start from head of encoder
105 |
106 | # computing blocks input and output channels
107 | head_channels = encoder_channels[0]
108 | in_channels = [head_channels] + list(decoder_channels[:-1])
109 | skip_channels = list(encoder_channels[1:]) + [0]
110 | out_channels = decoder_channels
111 |
112 | if center:
113 | self.center = CenterBlock(
114 | head_channels, head_channels, use_batchnorm=use_batchnorm
115 | )
116 | else:
117 | self.center = nn.Identity()
118 |
119 | # combine decoder keyword arguments
120 | kwargs = dict(use_batchnorm=use_batchnorm, attention_type=attention_type)
121 | blocks = [
122 | DecoderBlock(in_ch, skip_ch, out_ch, **kwargs)
123 | for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels)
124 | ]
125 | self.blocks = nn.ModuleList(blocks)
126 |
127 | def forward(self, *features):
128 |
129 | features = features[1:] # remove first skip with same spatial resolution
130 | features = features[::-1] # reverse channels to start from head of encoder
131 |
132 | head = features[0]
133 | skips = features[1:]
134 |
135 | x = self.center(head)
136 | for i, decoder_block in enumerate(self.blocks):
137 | skip = skips[i] if i < len(skips) else None
138 | x = decoder_block(x, skip)
139 |
140 | return x
141 |
142 |
143 | class Effi_UNet(nn.Module):
144 | """Unet_ is a fully convolution neural network for image semantic segmentation
145 |
146 | Args:
147 | encoder_name: name of classification model (without last dense layers) used as feature
148 | extractor to build segmentation model.
149 |         encoder_depth (int): number of stages used in the encoder; a larger depth yields more features.
150 |             e.g. for depth=3 the encoder will generate a list of features with the following spatial shapes
151 |             [(H, W), (H/2, W/2), (H/4, W/4), (H/8, W/8)], so in general the deepest feature tensor has
152 |             spatial resolution (H/(2^depth), W/(2^depth))
153 | encoder_weights: one of ``None`` (random initialization), ``imagenet`` (pre-training on ImageNet).
154 | decoder_channels: list of numbers of ``Conv2D`` layer filters in decoder blocks
155 |         decoder_use_batchnorm: if ``True``, a ``BatchNormalisation`` layer is used between ``Conv2D`` and
156 |             ``Activation`` layers. If ``'inplace'``, InplaceABN will be used, which decreases memory consumption.
157 |             One of [True, False, 'inplace']
158 | decoder_attention_type: attention module used in decoder of the model
159 | One of [``None``, ``scse``]
160 | in_channels: number of input channels for model, default is 3.
161 |         classes: number of classes for the output (output shape - ``(batch, classes, h, w)``).
162 | activation: activation function to apply after final convolution;
163 | One of [``sigmoid``, ``softmax``, ``logsoftmax``, ``identity``, callable, None]
164 |         aux_params: if specified, the model will have an additional classification auxiliary output
165 |             built on top of the encoder; supported params:
166 | - classes (int): number of classes
167 | - pooling (str): one of 'max', 'avg'. Default is 'avg'.
168 | - dropout (float): dropout factor in [0, 1)
169 | - activation (str): activation function to apply "sigmoid"/"softmax" (could be None to return logits)
170 |
171 | Returns:
172 | ``torch.nn.Module``: **Unet**
173 |
174 | .. _Unet:
175 | https://arxiv.org/pdf/1505.04597
176 |
177 | """
178 |
179 | def __init__(
180 | self,
181 | encoder_name: str = "resnet34",
182 | encoder_depth: int = 5,
183 | encoder_weights: str = "imagenet",
184 | decoder_use_batchnorm=True,
185 | decoder_channels=(256, 128, 64, 32, 16),
186 | decoder_attention_type=None,
187 | in_channels: int = 3,
188 | classes: int = 1):
189 | super().__init__()
190 |
191 | self.encoder = get_encoder(
192 | encoder_name,
193 | in_channels=in_channels,
194 | depth=encoder_depth,
195 | weights=encoder_weights,
196 | )
197 |
198 | self.decoder = UnetDecoder(
199 | encoder_channels=self.encoder.out_channels,
200 | decoder_channels=decoder_channels,
201 | n_blocks=encoder_depth,
202 | use_batchnorm=decoder_use_batchnorm,
203 | center=True if encoder_name.startswith("vgg") else False,
204 | attention_type=decoder_attention_type,
205 | )
206 | initialize_decoder(self.decoder)
207 | self.classifier = nn.Conv2d(decoder_channels[-1], classes, 1)
208 |
209 | def forward(self, x):
210 | """Sequentially pass `x` trough model`s encoder, decoder and heads"""
211 | features = self.encoder(x)
212 | decoder_output = self.decoder(*features)
213 | output = self.classifier(decoder_output)
214 |
215 | return output
216 |
217 |
218 | # unet = Effi_UNet('efficientnet-b3', encoder_weights='imagenet', in_channels=1, classes=1, decoder_attention_type="scse")
219 | # t = torch.rand(2, 1, 224, 224)
220 | # print(unet)
221 | # print(unet(t).shape)
222 |
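A runnable variant of the commented example above (hedged: requires efficientnet_pytorch; encoder_weights=None avoids downloading ImageNet weights):

    import torch
    net = Effi_UNet('efficientnet-b0', encoder_weights=None,
                    in_channels=1, classes=4)
    x = torch.rand(2, 1, 224, 224)
    print(net(x).shape)  # torch.Size([2, 4, 224, 224]) -- the decoder restores full resolution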
--------------------------------------------------------------------------------
/code/networks/encoder_tool.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.utils.model_zoo as model_zoo
6 | from efficientnet_pytorch import EfficientNet
7 | from efficientnet_pytorch.utils import get_model_params, url_map
8 |
9 |
10 | class EncoderMixin:
11 | """Add encoder functionality such as:
12 | - output channels specification of feature tensors (produced by encoder)
13 | - patching first convolution for arbitrary input channels
14 | """
15 |
16 | @property
17 | def out_channels(self) -> List:
18 | """Return channels dimensions for each tensor of forward output of encoder"""
19 | return self._out_channels[: self._depth + 1]
20 |
21 | def set_in_channels(self, in_channels):
22 | """Change first convolution chennels"""
23 | if in_channels == 3:
24 | return
25 |
26 | self._in_channels = in_channels
27 | if self._out_channels[0] == 3:
28 | self._out_channels = tuple([in_channels] + list(self._out_channels)[1:])
29 |
30 | patch_first_conv(model=self, in_channels=in_channels)
31 |
32 |
33 | def patch_first_conv(model, in_channels):
34 | """Change first convolution layer input channels.
35 | In case:
36 | in_channels == 1 or in_channels == 2 -> reuse original weights
37 | in_channels > 3 -> make random kaiming normal initialization
38 | """
39 |
40 | # get first conv
41 | for module in model.modules():
42 | if isinstance(module, nn.Conv2d):
43 | break
44 |
45 | # change input channels for first conv
46 | module.in_channels = in_channels
47 | weight = module.weight.detach()
48 | reset = False
49 |
50 | if in_channels == 1:
51 | weight = weight.sum(1, keepdim=True)
52 | elif in_channels == 2:
53 | weight = weight[:, :2] * (3.0 / 2.0)
54 | else:
55 | reset = True
56 | weight = torch.Tensor(
57 | module.out_channels,
58 | module.in_channels // module.groups,
59 | *module.kernel_size
60 | )
61 |
62 | module.weight = nn.parameter.Parameter(weight)
63 | if reset:
64 | module.reset_parameters()
65 |
66 |
67 | class EfficientNetEncoder(EfficientNet, EncoderMixin):
68 | def __init__(self, stage_idxs, out_channels, model_name, depth=5):
69 |
70 | blocks_args, global_params = get_model_params(model_name, override_params=None)
71 | super().__init__(blocks_args, global_params)
72 |
73 | self._stage_idxs = list(stage_idxs) + [len(self._blocks)]
74 | self._out_channels = out_channels
75 | self._depth = depth
76 | self._in_channels = 3
77 |
78 | del self._fc
79 |
80 | def forward(self, x):
81 |
82 | features = [x]
83 |
84 | if self._depth > 0:
85 | x = self._swish(self._bn0(self._conv_stem(x)))
86 | features.append(x)
87 |
88 | if self._depth > 1:
89 | skip_connection_idx = 0
90 | for idx, block in enumerate(self._blocks):
91 | drop_connect_rate = self._global_params.drop_connect_rate
92 | if drop_connect_rate:
93 | drop_connect_rate *= float(idx) / len(self._blocks)
94 | x = block(x, drop_connect_rate=drop_connect_rate)
95 | if idx == self._stage_idxs[skip_connection_idx] - 1:
96 | skip_connection_idx += 1
97 | features.append(x)
98 | if skip_connection_idx + 1 == self._depth:
99 | break
100 | return features
101 |
102 | def load_state_dict(self, state_dict, **kwargs):
103 | state_dict.pop("_fc.bias")
104 | state_dict.pop("_fc.weight")
105 | super().load_state_dict(state_dict, **kwargs)
106 |
107 |
108 | def _get_pretrained_settings(encoder):
109 | pretrained_settings = {
110 | "imagenet": {
111 | "mean": [0.485, 0.456, 0.406],
112 | "std": [0.229, 0.224, 0.225],
113 | "url": url_map[encoder],
114 | "input_space": "RGB",
115 | "input_range": [0, 1],
116 | }
117 | }
118 | return pretrained_settings
119 |
120 |
121 | efficient_net_encoders = {
122 | "efficientnet-b0": {
123 | "encoder": EfficientNetEncoder,
124 | "pretrained_settings": _get_pretrained_settings("efficientnet-b0"),
125 | "params": {
126 | "out_channels": (3, 32, 24, 40, 112, 320),
127 | "stage_idxs": (3, 5, 9),
128 | "model_name": "efficientnet-b0",
129 | },
130 | },
131 | "efficientnet-b1": {
132 | "encoder": EfficientNetEncoder,
133 | "pretrained_settings": _get_pretrained_settings("efficientnet-b1"),
134 | "params": {
135 | "out_channels": (3, 32, 24, 40, 112, 320),
136 | "stage_idxs": (5, 8, 16),
137 | "model_name": "efficientnet-b1",
138 | },
139 | },
140 | "efficientnet-b2": {
141 | "encoder": EfficientNetEncoder,
142 | "pretrained_settings": _get_pretrained_settings("efficientnet-b2"),
143 | "params": {
144 | "out_channels": (3, 32, 24, 48, 120, 352),
145 | "stage_idxs": (5, 8, 16),
146 | "model_name": "efficientnet-b2",
147 | },
148 | },
149 | "efficientnet-b3": {
150 | "encoder": EfficientNetEncoder,
151 | "pretrained_settings": _get_pretrained_settings("efficientnet-b3"),
152 | "params": {
153 | "out_channels": (3, 40, 32, 48, 136, 384),
154 | "stage_idxs": (5, 8, 18),
155 | "model_name": "efficientnet-b3",
156 | },
157 | },
158 | "efficientnet-b4": {
159 | "encoder": EfficientNetEncoder,
160 | "pretrained_settings": _get_pretrained_settings("efficientnet-b4"),
161 | "params": {
162 | "out_channels": (3, 48, 32, 56, 160, 448),
163 | "stage_idxs": (6, 10, 22),
164 | "model_name": "efficientnet-b4",
165 | },
166 | },
167 | "efficientnet-b5": {
168 | "encoder": EfficientNetEncoder,
169 | "pretrained_settings": _get_pretrained_settings("efficientnet-b5"),
170 | "params": {
171 | "out_channels": (3, 48, 40, 64, 176, 512),
172 | "stage_idxs": (8, 13, 27),
173 | "model_name": "efficientnet-b5",
174 | },
175 | },
176 | "efficientnet-b6": {
177 | "encoder": EfficientNetEncoder,
178 | "pretrained_settings": _get_pretrained_settings("efficientnet-b6"),
179 | "params": {
180 | "out_channels": (3, 56, 40, 72, 200, 576),
181 | "stage_idxs": (9, 15, 31),
182 | "model_name": "efficientnet-b6",
183 | },
184 | },
185 | "efficientnet-b7": {
186 | "encoder": EfficientNetEncoder,
187 | "pretrained_settings": _get_pretrained_settings("efficientnet-b7"),
188 | "params": {
189 | "out_channels": (3, 64, 48, 80, 224, 640),
190 | "stage_idxs": (11, 18, 38),
191 | "model_name": "efficientnet-b7",
192 | },
193 | },
194 | }
195 |
196 | encoders = {}
197 | encoders.update(efficient_net_encoders)
198 |
199 |
200 | def get_encoder(name, in_channels=3, depth=5, weights=None):
201 | Encoder = encoders[name]["encoder"]
202 | params = encoders[name]["params"]
203 | params.update(depth=depth)
204 | encoder = Encoder(**params)
205 |
206 | if weights is not None:
207 | settings = encoders[name]["pretrained_settings"][weights]
208 | encoder.load_state_dict(model_zoo.load_url(settings["url"]))
209 |
210 | encoder.set_in_channels(in_channels)
211 |
212 | return encoder
213 |
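A short sketch of the encoder in isolation (assumes efficientnet_pytorch is installed; weights=None skips the pretrained download). With in_channels=1 the first returned feature is the input itself, so the channel list starts at 1:

    import torch
    enc = get_encoder('efficientnet-b0', in_channels=1, depth=5, weights=None)
    feats = enc(torch.rand(1, 1, 224, 224))
    print([f.shape[1] for f in feats])  # [1, 32, 24, 40, 112, 320] for b0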
--------------------------------------------------------------------------------
/code/networks/net_factory.py:
--------------------------------------------------------------------------------
1 | from networks.efficientunet import Effi_UNet
2 | from networks.enet import ENet
3 | from networks.pnet import PNet2D
4 | from networks.unet import UNet, UNet_DS, UNet_URPC, UNet_CCT
5 | import argparse
6 | from networks.vision_transformer import SwinUnet as ViT_seg
7 | from networks.config import get_config
8 | from networks.nnunet import initialize_network
9 |
10 |
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument('--root_path', type=str,
13 |                     default='../data/ACDC', help='path to the dataset root')
14 | parser.add_argument('--exp', type=str,
15 | default='ACDC/Cross_Supervision_CNN_Trans2D', help='experiment_name')
16 | parser.add_argument('--fold', type=str,
17 | default='fold1', help='cross validation')
18 | parser.add_argument('--sup_type', type=str,
19 | default='scribble', help='supervision type')
20 |
21 | parser.add_argument('--model', type=str,
22 | default='unet', help='model_name')
23 | parser.add_argument('--max_iterations', type=int,
24 |                     default=30000, help='maximum iteration number to train')
25 | parser.add_argument('--batch_size', type=int, default=8,
26 | help='batch_size per gpu')
27 | parser.add_argument('--deterministic', type=int, default=1,
28 | help='whether use deterministic training')
29 | parser.add_argument('--base_lr', type=float, default=0.01,
30 | help='segmentation network learning rate')
31 | parser.add_argument('--patch_size', type=list, default=[224, 224],
32 | help='patch size of network input')
33 | parser.add_argument('--seed', type=int, default=1337, help='random seed')
34 | parser.add_argument('--num_classes', type=int, default=4,
35 | help='output channel of network')
36 | parser.add_argument(
37 | '--cfg', type=str, default="../code/configs/swin_tiny_patch4_window7_224_lite.yaml", help='path to config file', )
38 | parser.add_argument(
39 | "--opts",
40 | help="Modify config options by adding 'KEY VALUE' pairs. ",
41 | default=None,
42 | nargs='+',
43 | )
44 | parser.add_argument('--zip', action='store_true',
45 | help='use zipped dataset instead of folder dataset')
46 | parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'],
47 | help='no: no cache, '
48 | 'full: cache all data, '
49 | 'part: sharding the dataset into nonoverlapping pieces and only cache one piece')
50 | parser.add_argument('--resume', help='resume from checkpoint')
51 | parser.add_argument('--accumulation-steps', type=int,
52 | help="gradient accumulation steps")
53 | parser.add_argument('--use-checkpoint', action='store_true',
54 | help="whether to use gradient checkpointing to save memory")
55 | parser.add_argument('--amp-opt-level', type=str, default='O1', choices=['O0', 'O1', 'O2'],
56 | help='mixed precision opt level, if O0, no amp is used')
57 | parser.add_argument('--tag', help='tag of experiment')
58 | parser.add_argument('--eval', action='store_true',
59 | help='Perform evaluation only')
60 | parser.add_argument('--throughput', action='store_true',
61 | help='Test throughput only')
62 |
63 | # costs
64 | parser.add_argument('--ema_decay', type=float, default=0.99, help='ema_decay')
65 | parser.add_argument('--consistency_type', type=str,
66 | default="mse", help='consistency_type')
67 | parser.add_argument('--consistency', type=float,
68 | default=0.1, help='consistency')
69 | parser.add_argument('--consistency_rampup', type=float,
70 | default=200.0, help='consistency_rampup')
71 | args = parser.parse_args()
72 | config = get_config(args)
73 |
74 |
75 | def net_factory(net_type="unet", in_chns=1, class_num=3):
76 | if net_type == "unet":
77 | net = UNet(in_chns=in_chns, class_num=class_num).cuda()
78 | elif net_type == "enet":
79 | net = ENet(in_channels=in_chns, num_classes=class_num).cuda()
80 | elif net_type == "vnet":
81 | net = ENet(in_channels=in_chns, num_classes=class_num).cuda()
82 | elif net_type == "unet_ds":
83 | net = UNet_DS(in_chns=in_chns, class_num=class_num).cuda()
84 | elif net_type == "unet_cct":
85 | net = UNet_CCT(in_chns=in_chns, class_num=class_num).cuda()
86 | elif net_type == "unet_urpc":
87 | net = UNet_URPC(in_chns=in_chns, class_num=class_num).cuda()
88 | elif net_type == "efficient_unet":
89 | net = Effi_UNet('efficientnet-b3', encoder_weights='imagenet',
90 | in_channels=in_chns, classes=class_num).cuda()
91 | elif net_type == "ViT_Seg":
92 | net = ViT_seg(config, img_size=args.patch_size,
93 | num_classes=args.num_classes).cuda()
94 | elif net_type == "pnet":
95 | net = PNet2D(in_chns, class_num, 64, [1, 2, 4, 8, 16]).cuda()
96 | elif net_type == "nnUNet":
97 | net = initialize_network(num_classes=class_num).cuda()
98 | else:
99 | net = None
100 | return net
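Usage sketch (hedged: importing this module parses sys.argv at import time via argparse, and every branch calls .cuda(), so a GPU is assumed):

    import torch
    model = net_factory(net_type="unet", in_chns=1, class_num=4)
    x = torch.rand(2, 1, 256, 256).cuda()
    print(model(x).shape)  # torch.Size([2, 4, 256, 256])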
--------------------------------------------------------------------------------
/code/networks/net_factory_3d.py:
--------------------------------------------------------------------------------
1 | from networks.unet_3D import unet_3D
2 | from networks.vnet import VNet
3 | from networks.VoxResNet import VoxResNet
4 | from networks.attention_unet import Attention_UNet
5 |
6 |
7 | def net_factory_3d(net_type="unet_3D", in_chns=1, class_num=2):
8 | if net_type == "unet_3D":
9 | net = unet_3D(n_classes=class_num, in_channels=in_chns).cuda()
10 | elif net_type == "attention_unet":
11 | net = Attention_UNet(n_classes=class_num, in_channels=in_chns).cuda()
12 | elif net_type == "voxresnet":
13 | net = VoxResNet(in_chns=in_chns, feature_chns=64,
14 | class_num=class_num).cuda()
15 | elif net_type == "vnet":
16 | net = VNet(n_channels=in_chns, n_classes=class_num,
17 | normalization='batchnorm', has_dropout=True).cuda()
18 | else:
19 | net = None
20 | return net
21 |
--------------------------------------------------------------------------------
/code/networks/pnet.py:
--------------------------------------------------------------------------------
1 |
2 | # -*- coding: utf-8 -*-
3 | """
4 | A PyTorch implementation of the DeepIGeoS paper:
5 | Wang, Guotai and Zuluaga, Maria A and Li, Wenqi and Pratt, Rosalind and Patel, Premal A and Aertsen, Michael and Doel, Tom and David, Anna L and Deprest, Jan and Ourselin, Sébastien and others:
6 | DeepIGeoS: a deep interactive geodesic framework for medical image segmentation.
7 | TPAMI (7) 2018: 1559--1572
8 | Note that there are some modifications from the original paper, such as
9 | the use of leaky relu here.
10 | """
11 | from __future__ import division, print_function
12 |
13 | import torch
14 | import torch.nn as nn
15 |
16 |
17 | class PNetBlock(nn.Module):
18 | def __init__(self, in_channels, out_channels, dilation, padding):
19 | super(PNetBlock, self).__init__()
20 |
21 | self.in_chns = in_channels
22 | self.out_chns = out_channels
23 | self.dilation = dilation
24 | self.padding = padding
25 |
26 | self.conv1 = nn.Conv2d(self.in_chns, self.out_chns, kernel_size=3,
27 | padding=self.padding, dilation=self.dilation, groups=1, bias=True)
28 | self.conv2 = nn.Conv2d(self.out_chns, self.out_chns, kernel_size=3,
29 | padding=self.padding, dilation=self.dilation, groups=1, bias=True)
30 | self.in1 = nn.BatchNorm2d(self.out_chns)
31 | self.in2 = nn.BatchNorm2d(self.out_chns)
32 | self.ac1 = nn.LeakyReLU()
33 | self.ac2 = nn.LeakyReLU()
34 |
35 | def forward(self, x):
36 | x = self.conv1(x)
37 | x = self.in1(x)
38 | x = self.ac1(x)
39 | x = self.conv2(x)
40 | x = self.in2(x)
41 | x = self.ac2(x)
42 | return x
43 |
44 |
45 | class ConcatBlock(nn.Module):
46 | def __init__(self, in_channels, out_channels):
47 | super(ConcatBlock, self).__init__()
48 | self.in_chns = in_channels
49 | self.out_chns = out_channels
50 | self.conv1 = nn.Conv2d(
51 | self.in_chns, self.in_chns, kernel_size=1, padding=0)
52 | self.conv2 = nn.Conv2d(
53 | self.in_chns, self.out_chns, kernel_size=1, padding=0)
54 | self.ac1 = nn.LeakyReLU()
55 | self.ac2 = nn.LeakyReLU()
56 |
57 | def forward(self, x):
58 | x = self.conv1(x)
59 | x = self.ac1(x)
60 | x = self.conv2(x)
61 | x = self.ac2(x)
62 | return x
63 |
64 |
65 | class OutPutBlock(nn.Module):
66 | def __init__(self, in_channels, out_channels):
67 | super(OutPutBlock, self).__init__()
68 | self.in_chns = in_channels
69 | self.out_chns = out_channels
70 | self.conv1 = nn.Conv2d(
71 | self.in_chns, self.in_chns // 2, kernel_size=1, padding=0)
72 | self.conv2 = nn.Conv2d(
73 | self.in_chns // 2, self.out_chns, kernel_size=1, padding=0)
74 | self.drop1 = nn.Dropout2d(0.3)
75 | self.drop2 = nn.Dropout2d(0.3)
76 | self.ac1 = nn.LeakyReLU()
77 |
78 | def forward(self, x):
79 | x = self.drop1(x)
80 | x = self.conv1(x)
81 | x = self.ac1(x)
82 | x = self.drop2(x)
83 | x = self.conv2(x)
84 | return x
85 |
86 |
87 | class PNet2D(nn.Module):
88 | def __init__(self, in_chns, out_chns, num_filters, ratios):
89 | super(PNet2D, self).__init__()
90 |
91 | self.in_chns = in_chns
92 | self.out_chns = out_chns
93 | self.ratios = ratios
94 | self.num_filters = num_filters
95 |
96 | self.block1 = PNetBlock(
97 | self.in_chns, self.num_filters, self.ratios[0], padding=self.ratios[0])
98 |
99 | self.block2 = PNetBlock(
100 | self.num_filters, self.num_filters, self.ratios[1], padding=self.ratios[1])
101 |
102 | self.block3 = PNetBlock(
103 | self.num_filters, self.num_filters, self.ratios[2], padding=self.ratios[2])
104 |
105 | self.block4 = PNetBlock(
106 | self.num_filters, self.num_filters, self.ratios[3], padding=self.ratios[3])
107 |
108 | self.block5 = PNetBlock(
109 | self.num_filters, self.num_filters, self.ratios[4], padding=self.ratios[4])
110 | self.catblock = ConcatBlock(self.num_filters * 5, self.num_filters * 2)
111 | self.out = OutPutBlock(self.num_filters * 2, self.out_chns)
112 |
113 | def forward(self, x):
114 | x1 = self.block1(x)
115 | x2 = self.block2(x1)
116 | x3 = self.block3(x2)
117 | x4 = self.block4(x3)
118 | x5 = self.block5(x4)
119 | conx = torch.cat([x1, x2, x3, x4, x5], dim=1)
120 | conx = self.catblock(conx)
121 | out = self.out(conx)
122 | return out
123 |
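A shape check for PNet2D (sketch; because each 3x3 conv uses padding equal to its dilation, every block preserves the spatial size, so the five dilated blocks can be concatenated channel-wise):

    import torch
    net = PNet2D(in_chns=1, out_chns=4, num_filters=64, ratios=[1, 2, 4, 8, 16])
    x = torch.rand(1, 1, 224, 224)
    print(net(x).shape)  # torch.Size([1, 4, 224, 224])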
--------------------------------------------------------------------------------
/code/networks/unet_3D.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | An implementation of the 3D U-Net paper:
4 | Özgün Çiçek, Ahmed Abdulkadir, Soeren S. Lienkamp, Thomas Brox, Olaf Ronneberger:
5 | 3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation.
6 | MICCAI (2) 2016: 424-432
7 | Note that there are some modifications from the original paper, such as
8 | the use of batch normalization, dropout, and leaky relu here.
9 | The implementation is borrowed from: https://github.com/ozan-oktay/Attention-Gated-Networks
10 | """
11 | import math
12 |
13 | import torch.nn as nn
14 | import torch.nn.functional as F
15 |
16 | from networks.networks_other import init_weights
17 | from networks.utils import UnetConv3, UnetUp3, UnetUp3_CT
18 |
19 |
20 | class unet_3D(nn.Module):
21 |
22 | def __init__(self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3, is_batchnorm=True):
23 | super(unet_3D, self).__init__()
24 | self.is_deconv = is_deconv
25 | self.in_channels = in_channels
26 | self.is_batchnorm = is_batchnorm
27 | self.feature_scale = feature_scale
28 |
29 | filters = [64, 128, 256, 512, 1024]
30 | filters = [int(x / self.feature_scale) for x in filters]
31 |
32 | # downsampling
33 | self.conv1 = UnetConv3(self.in_channels, filters[0], self.is_batchnorm, kernel_size=(
34 | 3, 3, 3), padding_size=(1, 1, 1))
35 | self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 2))
36 |
37 | self.conv2 = UnetConv3(filters[0], filters[1], self.is_batchnorm, kernel_size=(
38 | 3, 3, 3), padding_size=(1, 1, 1))
39 | self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 2))
40 |
41 | self.conv3 = UnetConv3(filters[1], filters[2], self.is_batchnorm, kernel_size=(
42 | 3, 3, 3), padding_size=(1, 1, 1))
43 | self.maxpool3 = nn.MaxPool3d(kernel_size=(2, 2, 2))
44 |
45 | self.conv4 = UnetConv3(filters[2], filters[3], self.is_batchnorm, kernel_size=(
46 | 3, 3, 3), padding_size=(1, 1, 1))
47 | self.maxpool4 = nn.MaxPool3d(kernel_size=(2, 2, 2))
48 |
49 | self.center = UnetConv3(filters[3], filters[4], self.is_batchnorm, kernel_size=(
50 | 3, 3, 3), padding_size=(1, 1, 1))
51 |
52 | # upsampling
53 | self.up_concat4 = UnetUp3_CT(filters[4], filters[3], is_batchnorm)
54 | self.up_concat3 = UnetUp3_CT(filters[3], filters[2], is_batchnorm)
55 | self.up_concat2 = UnetUp3_CT(filters[2], filters[1], is_batchnorm)
56 | self.up_concat1 = UnetUp3_CT(filters[1], filters[0], is_batchnorm)
57 |
58 | # final conv (without any concat)
59 | self.final = nn.Conv3d(filters[0], n_classes, 1)
60 |
61 | self.dropout1 = nn.Dropout(p=0.3)
62 | self.dropout2 = nn.Dropout(p=0.3)
63 |
64 | # initialise weights
65 | for m in self.modules():
66 | if isinstance(m, nn.Conv3d):
67 | init_weights(m, init_type='kaiming')
68 | elif isinstance(m, nn.BatchNorm3d):
69 | init_weights(m, init_type='kaiming')
70 |
71 | def forward(self, inputs):
72 | conv1 = self.conv1(inputs)
73 | maxpool1 = self.maxpool1(conv1)
74 |
75 | conv2 = self.conv2(maxpool1)
76 | maxpool2 = self.maxpool2(conv2)
77 |
78 | conv3 = self.conv3(maxpool2)
79 | maxpool3 = self.maxpool3(conv3)
80 |
81 | conv4 = self.conv4(maxpool3)
82 | maxpool4 = self.maxpool4(conv4)
83 |
84 | center = self.center(maxpool4)
85 | center = self.dropout1(center)
86 | up4 = self.up_concat4(conv4, center)
87 | up3 = self.up_concat3(conv3, up4)
88 | up2 = self.up_concat2(conv2, up3)
89 | up1 = self.up_concat1(conv1, up2)
90 | up1 = self.dropout2(up1)
91 |
92 | final = self.final(up1)
93 |
94 | return final
95 |
96 | @staticmethod
97 | def apply_argmax_softmax(pred):
98 |         prob = F.softmax(pred, dim=1)
99 |
100 |         return prob
101 |
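A forward-pass sketch (assumes the repo's networks package is importable; four 2x max-pools mean each spatial dimension must be divisible by 16):

    import torch
    net = unet_3D(feature_scale=4, n_classes=4, in_channels=1)
    vol = torch.rand(1, 1, 96, 96, 96)
    print(net(vol).shape)  # torch.Size([1, 4, 96, 96, 96])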
--------------------------------------------------------------------------------
/code/networks/vision_mamba.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 |
6 | import copy
7 | import logging
8 | import math
9 |
10 | from os.path import join as pjoin
11 |
12 | import torch
13 | import torch.nn as nn
14 | import numpy as np
15 |
16 | from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
17 | from torch.nn.modules.utils import _pair
18 | from scipy import ndimage
19 | from .mamba_sys import VSSM
20 |
21 | logger = logging.getLogger(__name__)
22 |
23 | class MambaUnet(nn.Module):
24 | def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False):
25 | super(MambaUnet, self).__init__()
26 | self.num_classes = num_classes
27 | self.zero_head = zero_head
28 | self.config = config
29 |
30 | self.mamba_unet = VSSM(
31 | patch_size=config.MODEL.VSSM.PATCH_SIZE,
32 | in_chans=config.MODEL.VSSM.IN_CHANS,
33 | num_classes=self.num_classes,
34 | embed_dim=config.MODEL.VSSM.EMBED_DIM,
35 | depths=config.MODEL.VSSM.DEPTHS,
36 | mlp_ratio=config.MODEL.SWIN.MLP_RATIO,
37 | drop_rate=config.MODEL.DROP_RATE,
38 | drop_path_rate=config.MODEL.DROP_PATH_RATE,
39 | patch_norm=config.MODEL.SWIN.PATCH_NORM,
40 | use_checkpoint=config.TRAIN.USE_CHECKPOINT)
41 |
42 | def forward(self, x):
43 | if x.size()[1] == 1:
44 | x = x.repeat(1,3,1,1)
45 | logits = self.mamba_unet(x)
46 | return logits
47 |
48 | def load_from(self, config):
49 | pretrained_path = config.MODEL.PRETRAIN_CKPT
50 | if pretrained_path is not None:
51 | print("pretrained_path:{}".format(pretrained_path))
52 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
53 | pretrained_dict = torch.load(pretrained_path, map_location=device)
54 | if "model" not in pretrained_dict:
55 | print("---start load pretrained modle by splitting---")
56 | pretrained_dict = {k[17:]:v for k,v in pretrained_dict.items()}
57 | for k in list(pretrained_dict.keys()):
58 | if "output" in k:
59 | print("delete key:{}".format(k))
60 | del pretrained_dict[k]
61 | msg = self.mamba_unet.load_state_dict(pretrained_dict,strict=False)
62 | # print(msg)
63 | return
64 | pretrained_dict = pretrained_dict['model']
65 | print("---start load pretrained modle of swin encoder---")
66 |
67 | model_dict = self.mamba_unet.state_dict()
68 | full_dict = copy.deepcopy(pretrained_dict)
69 | for k, v in pretrained_dict.items():
70 | if "layers." in k:
71 | current_layer_num = 3-int(k[7:8])
72 | current_k = "layers_up." + str(current_layer_num) + k[8:]
73 | full_dict.update({current_k:v})
74 | for k in list(full_dict.keys()):
75 | if k in model_dict:
76 | if full_dict[k].shape != model_dict[k].shape:
77 | print("delete:{};shape pretrain:{};shape model:{}".format(k,v.shape,model_dict[k].shape))
78 | del full_dict[k]
79 |
80 | msg = self.mamba_unet.load_state_dict(full_dict, strict=False)
81 | # print(msg)
82 | else:
83 | print("none pretrain")
--------------------------------------------------------------------------------
/code/networks/vision_transformer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # This file borrowed from Swin-UNet: https://github.com/HuCaoFighting/Swin-Unet
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import copy
8 | import logging
9 | import math
10 |
11 | from os.path import join as pjoin
12 |
13 | import torch
14 | import torch.nn as nn
15 | import numpy as np
16 |
17 | from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
18 | from torch.nn.modules.utils import _pair
19 | from scipy import ndimage
20 | from networks.swin_transformer_unet_skip_expand_decoder_sys import SwinTransformerSys
21 |
22 | logger = logging.getLogger(__name__)
23 |
24 | class SwinUnet(nn.Module):
25 | def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False):
26 | super(SwinUnet, self).__init__()
27 | self.num_classes = num_classes
28 | self.zero_head = zero_head
29 | self.config = config
30 |
31 | self.swin_unet = SwinTransformerSys(img_size=config.DATA.IMG_SIZE,
32 | patch_size=config.MODEL.SWIN.PATCH_SIZE,
33 | in_chans=config.MODEL.SWIN.IN_CHANS,
34 | num_classes=self.num_classes,
35 | embed_dim=config.MODEL.SWIN.EMBED_DIM,
36 | depths=config.MODEL.SWIN.DEPTHS,
37 | num_heads=config.MODEL.SWIN.NUM_HEADS,
38 | window_size=config.MODEL.SWIN.WINDOW_SIZE,
39 | mlp_ratio=config.MODEL.SWIN.MLP_RATIO,
40 | qkv_bias=config.MODEL.SWIN.QKV_BIAS,
41 | qk_scale=config.MODEL.SWIN.QK_SCALE,
42 | drop_rate=config.MODEL.DROP_RATE,
43 | drop_path_rate=config.MODEL.DROP_PATH_RATE,
44 | ape=config.MODEL.SWIN.APE,
45 | patch_norm=config.MODEL.SWIN.PATCH_NORM,
46 | use_checkpoint=config.TRAIN.USE_CHECKPOINT)
47 |
48 | def forward(self, x):
49 | if x.size()[1] == 1:
50 | x = x.repeat(1,3,1,1)
51 | logits = self.swin_unet(x)
52 | return logits
53 |
54 | def load_from(self, config):
55 | pretrained_path = config.MODEL.PRETRAIN_CKPT
56 | if pretrained_path is not None:
57 | print("pretrained_path:{}".format(pretrained_path))
58 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
59 | pretrained_dict = torch.load(pretrained_path, map_location=device)
60 | if "model" not in pretrained_dict:
61 | print("---start load pretrained modle by splitting---")
62 | pretrained_dict = {k[17:]:v for k,v in pretrained_dict.items()}
63 | for k in list(pretrained_dict.keys()):
64 | if "output" in k:
65 | print("delete key:{}".format(k))
66 | del pretrained_dict[k]
67 | msg = self.swin_unet.load_state_dict(pretrained_dict,strict=False)
68 | # print(msg)
69 | return
70 | pretrained_dict = pretrained_dict['model']
71 | print("---start load pretrained modle of swin encoder---")
72 |
73 | model_dict = self.swin_unet.state_dict()
74 | full_dict = copy.deepcopy(pretrained_dict)
75 | for k, v in pretrained_dict.items():
76 | if "layers." in k:
77 | current_layer_num = 3-int(k[7:8])
78 | current_k = "layers_up." + str(current_layer_num) + k[8:]
79 | full_dict.update({current_k:v})
80 | for k in list(full_dict.keys()):
81 | if k in model_dict:
82 | if full_dict[k].shape != model_dict[k].shape:
83 | print("delete:{};shape pretrain:{};shape model:{}".format(k,v.shape,model_dict[k].shape))
84 | del full_dict[k]
85 |
86 | msg = self.swin_unet.load_state_dict(full_dict, strict=False)
87 | # print(msg)
88 | else:
89 | print("none pretrain")
90 |
--------------------------------------------------------------------------------
/code/pretrained_ckpt/readme.txt:
--------------------------------------------------------------------------------
1 | Download the pre-trained models to this folder:
2 |
3 |
4 | https://drive.google.com/file/d/14RzbbBDjbKbgr0ordKlWbb69EFkHuplr/view?usp=sharing
5 |
6 | https://drive.google.com/file/d/1uUPsr7XeqayCxlspqBHbg5zIWx0JYtSX/view?usp=sharing
7 |
--------------------------------------------------------------------------------
/code/test_2D.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import re
4 | import shutil
5 |
6 | import h5py
7 | import nibabel as nib
8 | import numpy as np
9 | import SimpleITK as sitk
10 | import torch
11 | from medpy import metric
12 | from scipy.ndimage import zoom
13 |
14 | from tqdm import tqdm
15 |
16 | # from networks.efficientunet import UNet
17 | from networks.net_factory import net_factory, config, args
18 |
19 | from config import get_config
20 | from networks.vision_transformer import SwinUnet as ViT_seg
21 |
22 |
23 |
24 | parser = argparse.ArgumentParser()
25 | parser.add_argument('--root_path', type=str,
26 |                     default='../data/ACDC', help='path to the dataset root')
27 | parser.add_argument('--exp', type=str,
28 | default='semi_mamba_unet_3unet', help='experiment_name')
29 | parser.add_argument('--model', type=str,
30 | default='unet', help='model_name')
31 | parser.add_argument('--fold', type=str,
32 | default='fold1', help='fold')
33 | parser.add_argument('--num_classes', type=int, default=4,
34 | help='output channel of network')
35 | # parser.add_argument('--sup_type', type=str, default="label",help='label')
36 | parser.add_argument('--sup_type', type=str, default="scribble",help='label/scribble')
37 |
38 | def get_fold_ids(fold):
39 | all_cases_set = ["patient{:0>3}".format(i) for i in range(1, 101)]
40 | fold1_testing_set = [
41 | "patient{:0>3}".format(i) for i in range(1, 21)]
42 | fold1_training_set = [
43 | i for i in all_cases_set if i not in fold1_testing_set]
44 |
45 | fold2_testing_set = [
46 | "patient{:0>3}".format(i) for i in range(21, 41)]
47 | fold2_training_set = [
48 | i for i in all_cases_set if i not in fold2_testing_set]
49 |
50 | fold3_testing_set = [
51 | "patient{:0>3}".format(i) for i in range(41, 61)]
52 | fold3_training_set = [
53 | i for i in all_cases_set if i not in fold3_testing_set]
54 |
55 | fold4_testing_set = [
56 | "patient{:0>3}".format(i) for i in range(61, 81)]
57 | fold4_training_set = [
58 | i for i in all_cases_set if i not in fold4_testing_set]
59 |
60 | fold5_testing_set = [
61 | "patient{:0>3}".format(i) for i in range(81, 101)]
62 | fold5_training_set = [
63 | i for i in all_cases_set if i not in fold5_testing_set]
64 | if fold == "fold1":
65 | return [fold1_training_set, fold1_testing_set]
66 | elif fold == "fold2":
67 | return [fold2_training_set, fold2_testing_set]
68 | elif fold == "fold3":
69 | return [fold3_training_set, fold3_testing_set]
70 | elif fold == "fold4":
71 | return [fold4_training_set, fold4_testing_set]
72 | elif fold == "fold5":
73 | return [fold5_training_set, fold5_testing_set]
74 |     else:
75 |         raise ValueError("unknown fold: {}".format(fold))
76 |
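Each fold holds out a contiguous block of 20 patients and trains on the other 80; a quick check (sketch):

    train_ids, test_ids = get_fold_ids("fold1")
    print(len(train_ids), len(test_ids))   # 80 20
    print(test_ids[0], test_ids[-1])       # patient001 patient020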
77 |
78 | def calculate_metric_percase(pred, gt, spacing):
79 | pred[pred > 0] = 1
80 | gt[gt > 0] = 1
81 | dice = metric.binary.dc(pred, gt)
82 | asd = metric.binary.asd(pred, gt, voxelspacing=spacing)
83 | hd95 = metric.binary.hd95(pred, gt, voxelspacing=spacing)
84 | return dice, hd95, asd
85 |
86 |
87 | def test_single_volume(case, net, test_save_path, FLAGS):
88 | h5f = h5py.File(FLAGS.root_path +
89 | "/ACDC_training_volumes/{}".format(case), 'r')
90 | image = h5f['image'][:]
91 | label = h5f['label'][:]
92 | prediction = np.zeros_like(label)
93 | for ind in range(image.shape[0]):
94 | slice = image[ind, :, :]
95 | x, y = slice.shape[0], slice.shape[1]
96 | slice = zoom(slice, (224 / x, 224 / y), order=0)
97 | input = torch.from_numpy(slice).unsqueeze(
98 | 0).unsqueeze(0).float().cuda()
99 | net.eval()
100 | with torch.no_grad():
101 | out_main = net(input)
102 | out = torch.argmax(torch.softmax(
103 | out_main, dim=1), dim=1).squeeze(0)
104 | out = out.cpu().detach().numpy()
105 | pred = zoom(out, (x / 224, y / 224), order=0)
106 | prediction[ind] = pred
107 | case = case.replace(".h5", "")
108 | org_img_path = "../data/ACDC_training/{}.nii.gz".format(case)
109 | org_img_itk = sitk.ReadImage(org_img_path)
110 | spacing = org_img_itk.GetSpacing()
111 |
112 | first_metric = calculate_metric_percase(
113 | prediction == 1, label == 1, (spacing[2], spacing[0], spacing[1]))
114 | second_metric = calculate_metric_percase(
115 | prediction == 2, label == 2, (spacing[2], spacing[0], spacing[1]))
116 | third_metric = calculate_metric_percase(
117 | prediction == 3, label == 3, (spacing[2], spacing[0], spacing[1]))
118 |
119 | img_itk = sitk.GetImageFromArray(image.astype(np.float32))
120 | img_itk.CopyInformation(org_img_itk)
121 | prd_itk = sitk.GetImageFromArray(prediction.astype(np.float32))
122 | prd_itk.CopyInformation(org_img_itk)
123 | lab_itk = sitk.GetImageFromArray(label.astype(np.float32))
124 | lab_itk.CopyInformation(org_img_itk)
125 | sitk.WriteImage(prd_itk, test_save_path + case + "_pred.nii.gz")
126 | sitk.WriteImage(img_itk, test_save_path + case + "_img.nii.gz")
127 | sitk.WriteImage(lab_itk, test_save_path + case + "_gt.nii.gz")
128 | return first_metric, second_metric, third_metric
129 |
130 |
131 | def Inference(FLAGS):
132 | train_ids, test_ids = get_fold_ids(FLAGS.fold)
133 | all_volumes = os.listdir(
134 | FLAGS.root_path + "/ACDC_training_volumes")
135 | image_list = []
136 | for ids in test_ids:
137 |         new_data_list = list(filter(lambda x: re.match(
138 |             '{}.*'.format(ids), x) is not None, all_volumes))
139 | image_list.extend(new_data_list)
140 | snapshot_path = "../model/{}_{}/{}".format(
141 | FLAGS.exp, FLAGS.fold, FLAGS.sup_type)
142 | test_save_path = "../model/{}_{}/{}/{}_predictions/".format(
143 | FLAGS.exp, FLAGS.fold, FLAGS.sup_type, FLAGS.model)
144 | if os.path.exists(test_save_path):
145 | shutil.rmtree(test_save_path)
146 | os.makedirs(test_save_path)
147 |
148 |     # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
149 | net = net_factory(net_type=FLAGS.model, in_chns=1,
150 | class_num=FLAGS.num_classes)
151 |
152 | # net = ViT_seg(config, img_size=[224, 224], num_classes=args.num_classes).cuda()
153 | # net.load_from(config)
154 |
155 | # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
156 | save_mode_path = os.path.join(
157 | snapshot_path, 'ViT_best_model.pth')
158 | net.load_state_dict(torch.load(save_mode_path))
159 | print("init weight from {}".format(save_mode_path))
160 | net.eval()
161 |
162 | first_total = 0.0
163 | second_total = 0.0
164 | third_total = 0.0
165 | for case in tqdm(image_list):
166 | print(case)
167 | first_metric, second_metric, third_metric = test_single_volume(
168 | case, net, test_save_path, FLAGS)
169 | first_total += np.asarray(first_metric)
170 | second_total += np.asarray(second_metric)
171 | third_total += np.asarray(third_metric)
172 | avg_metric = [first_total / len(image_list), second_total /
173 | len(image_list), third_total / len(image_list)]
174 | print(avg_metric)
175 | print((avg_metric[0] + avg_metric[1] + avg_metric[2]) / 3)
176 | return ((avg_metric[0] + avg_metric[1] + avg_metric[2]) / 3)[0]
177 |
178 |
179 | if __name__ == '__main__':
180 | FLAGS = parser.parse_args()
181 | total = 0.0
182 | # for i in [5]:
183 | # for i in [5]:
184 | # FLAGS.fold = "fold{}".format(i)
185 | # print("Inference fold{}".format(i))
186 | mean_dice = Inference(FLAGS)
187 | total += mean_dice
188 | # print(total/5.0)
189 |
--------------------------------------------------------------------------------
/code/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from medpy import metric
3 |
4 |
5 | def cal_dice(prediction, label, num=2):
6 | total_dice = np.zeros(num-1)
7 | for i in range(1, num):
8 | prediction_tmp = (prediction == i)
9 | label_tmp = (label == i)
10 |         prediction_tmp = prediction_tmp.astype(np.float64)
11 |         label_tmp = label_tmp.astype(np.float64)
12 |
13 | dice = 2 * np.sum(prediction_tmp * label_tmp) / (np.sum(prediction_tmp) + np.sum(label_tmp))
14 | total_dice[i - 1] += dice
15 |
16 | return total_dice
17 |
18 |
19 | def calculate_metric_percase(pred, gt):
20 | dc = metric.binary.dc(pred, gt)
21 | jc = metric.binary.jc(pred, gt)
22 | hd = metric.binary.hd95(pred, gt)
23 | asd = metric.binary.asd(pred, gt)
24 |
25 | return dc, jc, hd, asd
26 |
27 |
28 | def dice(input, target, ignore_index=None):
29 | smooth = 1.
30 |     # clone so that the ignore_index masking below does not modify the original tensors.
31 | iflat = input.clone().view(-1)
32 | tflat = target.clone().view(-1)
33 | if ignore_index is not None:
34 | mask = tflat == ignore_index
35 | tflat[mask] = 0
36 | iflat[mask] = 0
37 | intersection = (iflat * tflat).sum()
38 |
39 | return (2. * intersection + smooth) / (iflat.sum() + tflat.sum() + smooth)
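A worked example for dice() (sketch; plain torch tensors, no ignore_index):

    import torch
    pred = torch.tensor([0., 1., 1., 0., 1.])
    target = torch.tensor([0., 1., 0., 0., 1.])
    # intersection = 2, sums = 3 + 2, smooth = 1 -> (2*2 + 1) / (3 + 2 + 1)
    print(dice(pred, target).item())  # 0.8333...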
--------------------------------------------------------------------------------
/code/utils/ramps.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018, Curious AI Ltd. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons Attribution-NonCommercial
4 | # 4.0 International License. To view a copy of this license, visit
5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7 |
8 | """Functions for ramping hyperparameters up or down
9 |
10 | Each function takes the current training step or epoch, and the
11 | ramp length in the same format, and returns a multiplier between
12 | 0 and 1.
13 | """
14 |
15 |
16 | import numpy as np
17 |
18 |
19 | def sigmoid_rampup(current, rampup_length):
20 | """Exponential rampup from https://arxiv.org/abs/1610.02242"""
21 | if rampup_length == 0:
22 | return 1.0
23 | else:
24 | current = np.clip(current, 0.0, rampup_length)
25 | phase = 1.0 - current / rampup_length
26 | return float(np.exp(-5.0 * phase * phase))
27 |
28 |
29 | def linear_rampup(current, rampup_length):
30 | """Linear rampup"""
31 | assert current >= 0 and rampup_length >= 0
32 | if current >= rampup_length:
33 | return 1.0
34 | else:
35 | return current / rampup_length
36 |
37 |
38 | def cosine_rampdown(current, rampdown_length):
39 | """Cosine rampdown from https://arxiv.org/abs/1608.03983"""
40 | assert 0 <= current <= rampdown_length
41 | return float(.5 * (np.cos(np.pi * current / rampdown_length) + 1))
42 |
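A few sampled values of sigmoid_rampup(current, 200), the schedule typically used to ramp up a consistency weight:

    for step in (0, 50, 100, 200):
        print(step, round(sigmoid_rampup(step, 200), 4))
    # 0 0.0067, 50 0.0601, 100 0.2865, 200 1.0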
--------------------------------------------------------------------------------
/code/utils/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import numpy as np
4 | from scipy.ndimage import distance_transform_edt as distance
5 | from skimage import segmentation as skimage_seg
6 | import torch
7 | from torch.utils.data.sampler import Sampler
8 |
9 | import networks
10 |
11 | def load_model(path):
12 | """Loads model and return it without DataParallel table."""
13 | if os.path.isfile(path):
14 | print("=> loading checkpoint '{}'".format(path))
15 | checkpoint = torch.load(path)
16 |
17 | # size of the top layer
18 | N = checkpoint['state_dict']['top_layer.bias'].size()
19 |
20 | # build skeleton of the model
21 | sob = 'sobel.0.weight' in checkpoint['state_dict'].keys()
22 |         model = networks.__dict__[checkpoint['arch']](sobel=sob, out=int(N[0]))
23 |
24 | # deal with a dataparallel table
25 | def rename_key(key):
26 | if not 'module' in key:
27 | return key
28 | return ''.join(key.split('.module'))
29 |
30 | checkpoint['state_dict'] = {rename_key(key): val
31 | for key, val
32 | in checkpoint['state_dict'].items()}
33 |
34 | # load weights
35 | model.load_state_dict(checkpoint['state_dict'])
36 | print("Loaded")
37 | else:
38 | model = None
39 | print("=> no checkpoint found at '{}'".format(path))
40 | return model
41 |
42 |
43 | class UnifLabelSampler(Sampler):
44 | """Samples elements uniformely accross pseudolabels.
45 | Args:
46 | N (int): size of returned iterator.
47 | images_lists: dict of key (target), value (list of data with this target)
48 | """
49 |
50 | def __init__(self, N, images_lists):
51 | self.N = N
52 | self.images_lists = images_lists
53 | self.indexes = self.generate_indexes_epoch()
54 |
55 | def generate_indexes_epoch(self):
56 | size_per_pseudolabel = int(self.N / len(self.images_lists)) + 1
57 | res = np.zeros(size_per_pseudolabel * len(self.images_lists))
58 |
59 | for i in range(len(self.images_lists)):
60 | indexes = np.random.choice(
61 | self.images_lists[i],
62 | size_per_pseudolabel,
63 | replace=(len(self.images_lists[i]) <= size_per_pseudolabel)
64 | )
65 | res[i * size_per_pseudolabel: (i + 1) * size_per_pseudolabel] = indexes
66 |
67 | np.random.shuffle(res)
68 | return res[:self.N].astype('int')
69 |
70 | def __iter__(self):
71 | return iter(self.indexes)
72 |
73 | def __len__(self):
74 | return self.N
75 |
76 |
77 | class AverageMeter(object):
78 | """Computes and stores the average and current value"""
79 | def __init__(self):
80 | self.reset()
81 |
82 | def reset(self):
83 | self.val = 0
84 | self.avg = 0
85 | self.sum = 0
86 | self.count = 0
87 |
88 | def update(self, val, n=1):
89 | self.val = val
90 | self.sum += val * n
91 | self.count += n
92 | self.avg = self.sum / self.count
93 |
94 |
95 | def learning_rate_decay(optimizer, t, lr_0):
96 | for param_group in optimizer.param_groups:
97 | lr = lr_0 / np.sqrt(1 + lr_0 * param_group['weight_decay'] * t)
98 | param_group['lr'] = lr
99 |
100 |
101 | class Logger():
102 | """ Class to update every epoch to keep trace of the results
103 | Methods:
104 | - log() log and save
105 | """
106 |
107 | def __init__(self, path):
108 | self.path = path
109 | self.data = []
110 |
111 | def log(self, train_point):
112 | self.data.append(train_point)
113 | with open(os.path.join(self.path), 'wb') as fp:
114 | pickle.dump(self.data, fp, -1)
115 |
116 |
117 | def compute_sdf(img_gt, out_shape):
118 | """
119 | compute the signed distance map of binary mask
120 | input: segmentation, shape = (batch_size, x, y, z)
121 | output: the Signed Distance Map (SDM)
122 | sdf(x) = 0,             x on the segmentation boundary
123 | sdf(x) = -inf_y |x-y|,  x inside the segmentation (infimum over boundary points y)
124 | sdf(x) = +inf_y |x-y|,  x outside the segmentation
125 | the sdf is then normalized to [-1, 1]
126 | """
127 |
128 | img_gt = img_gt.astype(np.uint8)
129 | normalized_sdf = np.zeros(out_shape)
130 |
131 | for b in range(out_shape[0]): # batch size
132 | posmask = img_gt[b].astype(bool)
133 | if posmask.any():
134 | negmask = ~posmask
135 | posdis = distance(posmask)
136 | negdis = distance(negmask)
137 | boundary = skimage_seg.find_boundaries(posmask, mode='inner').astype(np.uint8)
138 | sdf = (negdis-np.min(negdis))/(np.max(negdis)-np.min(negdis)) - (posdis-np.min(posdis))/(np.max(posdis)-np.min(posdis))
139 | sdf[boundary==1] = 0
140 | normalized_sdf[b] = sdf
141 | # assert np.min(sdf) == -1.0, print(np.min(posdis), np.max(posdis), np.min(negdis), np.max(negdis))
142 | # assert np.max(sdf) == 1.0, print(np.min(posdis), np.min(negdis), np.max(posdis), np.max(negdis))
143 |
144 | return normalized_sdf
--------------------------------------------------------------------------------
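`compute_sdf` builds, per sample, a signed distance map that is negative inside the mask, zero on the boundary, and positive outside, with each side min-max normalized. A small sketch that inlines the same per-sample steps on a toy 2D mask; `compute_sdf` itself expects a `(batch, x, y, z)` array, so the 5x5 mask here is purely illustrative:

```
import numpy as np
from scipy.ndimage import distance_transform_edt as distance
from skimage import segmentation as skimage_seg

# Toy check of the per-sample logic in compute_sdf above.
posmask = np.zeros((5, 5), dtype=bool)
posmask[1:4, 1:4] = True
negmask = ~posmask
posdis, negdis = distance(posmask), distance(negmask)
boundary = skimage_seg.find_boundaries(posmask, mode='inner').astype(np.uint8)
sdf = (negdis - negdis.min()) / (negdis.max() - negdis.min()) \
    - (posdis - posdis.min()) / (posdis.max() - posdis.min())
sdf[boundary == 1] = 0
print(sdf.round(2))  # negative inside, zero on the boundary, positive outside
```
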
/code/val_2D.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from medpy import metric
4 | from scipy.ndimage import zoom
5 |
6 |
7 | def calculate_metric_percase(pred, gt):
8 | pred[pred > 0] = 1
9 | gt[gt > 0] = 1
10 | if pred.sum() > 0:
11 | dice = metric.binary.dc(pred, gt)
12 | hd95 = metric.binary.hd95(pred, gt)
13 | return dice, hd95
14 | else:
15 | return 0, 0
16 |
17 |
18 | def test_single_volume(image, label, net, classes, patch_size=[256, 256]):
19 | image, label = image.squeeze(0).cpu().detach(
20 | ).numpy(), label.squeeze(0).cpu().detach().numpy()
21 | if len(image.shape) == 3:
22 | prediction = np.zeros_like(label)
23 | for ind in range(image.shape[0]):
24 | slice = image[ind, :, :]
25 | x, y = slice.shape[0], slice.shape[1]
26 | slice = zoom(
27 | slice, (patch_size[0] / x, patch_size[1] / y), order=0)
28 | input = torch.from_numpy(slice).unsqueeze(
29 | 0).unsqueeze(0).float().cuda()
30 | net.eval()
31 | with torch.no_grad():
32 | out = torch.argmax(torch.softmax(
33 | net(input), dim=1), dim=1).squeeze(0)
34 | out = out.cpu().detach().numpy()
35 | pred = zoom(
36 | out, (x / patch_size[0], y / patch_size[1]), order=0)
37 | prediction[ind] = pred
38 | else:
39 | input = torch.from_numpy(image).unsqueeze(
40 | 0).unsqueeze(0).float().cuda()
41 | net.eval()
42 | with torch.no_grad():
43 | out = torch.argmax(torch.softmax(
44 | net(input), dim=1), dim=1).squeeze(0)
45 | prediction = out.cpu().detach().numpy()
46 | metric_list = []
47 | for i in range(1, classes):
48 | metric_list.append(calculate_metric_percase(
49 | prediction == i, label == i))
50 | return metric_list
51 |
52 |
53 | def test_single_volume_ds(image, label, net, classes, patch_size=[256, 256]):
54 | image, label = image.squeeze(0).cpu().detach(
55 | ).numpy(), label.squeeze(0).cpu().detach().numpy()
56 | if len(image.shape) == 3:
57 | prediction = np.zeros_like(label)
58 | for ind in range(image.shape[0]):
59 | slice = image[ind, :, :]
60 | x, y = slice.shape[0], slice.shape[1]
61 | slice = zoom(
62 | slice, (patch_size[0] / x, patch_size[1] / y), order=0)
63 | input = torch.from_numpy(slice).unsqueeze(
64 | 0).unsqueeze(0).float().cuda()
65 | net.eval()
66 | with torch.no_grad():
67 | output_main, _, _, _ = net(input)
68 | out = torch.argmax(torch.softmax(
69 | output_main, dim=1), dim=1).squeeze(0)
70 | out = out.cpu().detach().numpy()
71 | pred = zoom(
72 | out, (x / patch_size[0], y / patch_size[1]), order=0)
73 | prediction[ind] = pred
74 | else:
75 | input = torch.from_numpy(image).unsqueeze(
76 | 0).unsqueeze(0).float().cuda()
77 | net.eval()
78 | with torch.no_grad():
79 | output_main, _, _, _ = net(input)
80 | out = torch.argmax(torch.softmax(
81 | output_main, dim=1), dim=1).squeeze(0)
82 | prediction = out.cpu().detach().numpy()
83 | metric_list = []
84 | for i in range(1, classes):
85 | metric_list.append(calculate_metric_percase(
86 | prediction == i, label == i))
87 | return metric_list
88 |
89 |
90 | def test_single_volume_cct(image, label, net, classes, patch_size=[256, 256]):
91 | image, label = image.squeeze(0).cpu().detach(
92 | ).numpy(), label.squeeze(0).cpu().detach().numpy()
93 | if len(image.shape) == 3:
94 | prediction = np.zeros_like(label)
95 | for ind in range(image.shape[0]):
96 | slice = image[ind, :, :]
97 | x, y = slice.shape[0], slice.shape[1]
98 | slice = zoom(
99 | slice, (patch_size[0] / x, patch_size[1] / y), order=0)
100 | input = torch.from_numpy(slice).unsqueeze(
101 | 0).unsqueeze(0).float().cuda()
102 | net.eval()
103 | with torch.no_grad():
104 | output_main = net(input)[0]
105 | out = torch.argmax(torch.softmax(
106 | output_main, dim=1), dim=1).squeeze(0)
107 | out = out.cpu().detach().numpy()
108 | pred = zoom(
109 | out, (x / patch_size[0], y / patch_size[1]), order=0)
110 | prediction[ind] = pred
111 | else:
112 | input = torch.from_numpy(image).unsqueeze(
113 | 0).unsqueeze(0).float().cuda()
114 | net.eval()
115 | with torch.no_grad():
116 | output_main, _, _, _ = net(input)
117 | out = torch.argmax(torch.softmax(
118 | output_main, dim=1), dim=1).squeeze(0)
119 | prediction = out.cpu().detach().numpy()
120 | metric_list = []
121 | for i in range(1, classes):
122 | metric_list.append(calculate_metric_percase(
123 | prediction == i, label == i))
124 | return metric_list
125 |
--------------------------------------------------------------------------------
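`test_single_volume` zooms each slice to `patch_size`, runs the network under `torch.no_grad()`, zooms the argmax prediction back, and returns per-class (Dice, HD95) pairs for classes 1..classes-1. A hedged sketch of a typical validation loop over volumes; `valloader`, `model`, and the batch keys are assumptions, not names taken from this excerpt:

```
import numpy as np
from val_2D import test_single_volume

def validate(valloader, model, num_classes=4):
    # Hypothetical validation loop: average per-class (dice, hd95) over volumes.
    metric_sum = 0.0
    for batch in valloader:
        metric_i = test_single_volume(
            batch["image"], batch["label"], model, classes=num_classes)
        metric_sum = metric_sum + np.array(metric_i)  # shape (num_classes - 1, 2)
    metric_mean = metric_sum / len(valloader)
    mean_dice, mean_hd95 = np.mean(metric_mean, axis=0)
    return mean_dice, mean_hd95
```
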
/data/readme.txt:
--------------------------------------------------------------------------------
1 |
2 | Google Drive:
3 |
4 | https://drive.google.com/file/d/1XR_Id0wdvXY9QeKtdOdgJHKVJ-nVr2j1/view?usp=sharing
5 |
6 | or
7 |
8 | Baidu Netdisk:
9 |
10 | https://pan.baidu.com/s/1dHkp9daqE3kLEbAP6zl7Jw with passcode: 'rwv2'
11 |
--------------------------------------------------------------------------------
/img/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/img/results.png
--------------------------------------------------------------------------------
/img/wslframework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/img/wslframework.png
--------------------------------------------------------------------------------
/img/wslintro.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/img/wslintro.png
--------------------------------------------------------------------------------
/mamba/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "3rdparty/lm-evaluation-harness"]
2 | path = 3rdparty/lm-evaluation-harness
3 | url = https://github.com/EleutherAI/lm-evaluation-harness/
4 |
--------------------------------------------------------------------------------
/mamba/AUTHORS:
--------------------------------------------------------------------------------
1 | Tri Dao, tri@tridao.me
2 | Albert Gu, agu@andrew.cmu.edu
3 |
--------------------------------------------------------------------------------
/mamba/README.md:
--------------------------------------------------------------------------------
1 | # Mamba
2 |
3 | 
4 | > **Mamba: Linear-Time Sequence Modeling with Selective State Spaces**\
5 | > Albert Gu*, Tri Dao*\
6 | > Paper: https://arxiv.org/abs/2312.00752
7 |
8 | ## About
9 |
10 | Mamba is a new state space model architecture showing promising performance on information-dense data such as language modeling, where previous subquadratic models fall short of Transformers.
11 | It is based on the line of progress on [structured state space models](https://github.com/state-spaces/s4),
12 | with an efficient hardware-aware design and implementation in the spirit of [FlashAttention](https://github.com/Dao-AILab/flash-attention).
13 |
14 | ## Installation
15 |
16 | - `pip install causal-conv1d`: an efficient implementation of a simple causal Conv1d layer used inside the Mamba block.
17 | - `pip install mamba-ssm`: the core Mamba package.
18 |
19 | It can also be built from source with `pip install .` from this repository.
20 |
21 | If `pip` complains about PyTorch versions, try passing `--no-build-isolation` to `pip`.
22 |
23 | Other requirements:
24 | - Linux
25 | - NVIDIA GPU
26 | - PyTorch 1.12+
27 | - CUDA 11.6+
28 |
29 | ## Usage
30 |
31 | We expose several levels of interface with the Mamba model.
32 |
33 | ### Selective SSM
34 |
35 | Mamba is based on a selective SSM layer, which is the focus of the paper (Section 3; Algorithm 2).
36 |
37 | Source: [ops/selective_scan_interface.py](mamba_ssm/ops/selective_scan_interface.py).
38 |
39 | ### Mamba Block
40 |
41 | The main module of this repository is the Mamba architecture block wrapping the selective SSM.
42 |
43 | Source: [modules/mamba_simple.py](mamba_ssm/modules/mamba_simple.py).
44 |
45 | Usage:
46 | ```
47 | import torch
48 | from mamba_ssm import Mamba
49 |
50 | batch, length, dim = 2, 64, 16
51 | x = torch.randn(batch, length, dim).to("cuda")
52 | model = Mamba(
53 | # This module uses roughly 3 * expand * d_model^2 parameters
54 | d_model=dim, # Model dimension d_model
55 | d_state=16, # SSM state expansion factor
56 | d_conv=4, # Local convolution width
57 | expand=2, # Block expansion factor
58 | ).to("cuda")
59 | y = model(x)
60 | assert y.shape == x.shape
61 | ```
61 |
62 | ### Mamba Language Model
63 |
64 | Finally, we provide an example of a complete language model: a deep sequence model backbone (with repeating Mamba blocks) + language model head.
65 |
66 | Source: [models/mixer_seq_simple.py](mamba_ssm/models/mixer_seq_simple.py).
67 |
68 | This is an example of how to integrate Mamba into an end-to-end neural network.
69 | This example is used in the generation scripts below.
70 |
71 |
72 |
73 | ## Pretrained Models
74 |
75 | Pretrained models are uploaded to
76 | [HuggingFace](https://huggingface.co/state-spaces): `mamba-130m`, `mamba-370m`,
77 | `mamba-790m`, `mamba-1.4b`, `mamba-2.8b`.
78 |
79 | The models will be autodownloaded by the generation script below.
80 |
81 | These models were trained on the [Pile](https://huggingface.co/datasets/EleutherAI/pile), and follow the standard model dimensions described by GPT-3 and followed by many open source models:
82 |
83 | | Parameters | Layers | Model dim. |
84 | |------------|--------|------------|
85 | | 130M | 12 | 768 |
86 | | 370M | 24 | 1024 |
87 | | 790M | 24 | 1536 |
88 | | 1.4B | 24 | 2048 |
89 | | 2.8B | 32 | 2560 |
90 |
91 | (The layer count of Mamba should be doubled, as two Mamba blocks are needed for each "layer" (MHA block + MLP block) of a Transformer.)
92 |
93 | Note: these are base models trained only for 300B tokens, without any form of downstream modification (instruction tuning, etc.).
94 | Performance is expected to be comparable or better than other architectures trained on similar data, but not to match larger or fine-tuned models.
95 |
96 |
97 | ## Evaluations
98 |
99 | To run zero-shot evaluations of models (corresponding to Table 3 of the paper),
100 | we use the
101 | [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/big-refactor)
102 | library.
103 |
104 | 1. Pull the `lm-evaluation-harness` repo by `git submodule update --init
105 | --recursive`. We use the `big-refactor` branch.
106 | 2. Install `lm-evaluation-harness`: `pip install -e 3rdparty/lm-evaluation-harness`
107 | 3. Run evaluation with (more documentation at the [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/big-refactor) repo):
108 | ```
109 | python evals/lm_harness_eval.py --model mamba --model_args pretrained=state-spaces/mamba-130m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64
110 | python evals/lm_harness_eval.py --model hf --model_args pretrained=EleutherAI/pythia-160m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64
111 | ```
112 |
113 | Note that the result of each task might differ from reported values by 0.1-0.3 due to noise in the evaluation process.
114 |
115 | ## Inference
116 |
117 | The script [benchmarks/benchmark_generation_mamba_simple.py](benchmarks/benchmark_generation_mamba_simple.py)
118 | 1. autoloads a model from the HuggingFace Hub,
119 | 2. generates completions of a user-specified prompt,
120 | 3. benchmarks the inference speed of this generation.
121 |
122 | Other configurable options include the top-p (nucleus sampling) probability, and the softmax temperature.
123 |
124 | ### Examples
125 |
126 | To test generation latency (e.g. batch size = 1) with different sampling strategies:
127 |
128 | ```
129 | python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.5
130 | python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.5
131 | ```
132 |
133 | To test generation throughput with random prompts (e.g. large batch size):
134 | ```
135 | python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --batch 128
136 | python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --batch 128
137 | ```
138 |
139 | ## Citation
140 |
141 | If you use this codebase, or otherwise find our work valuable, please cite Mamba:
142 | ```
143 | @article{mamba,
144 | title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces},
145 | author={Gu, Albert and Dao, Tri},
146 | journal={arXiv preprint arXiv:2312.00752},
147 | year={2023}
148 | }
149 | ```
150 |
--------------------------------------------------------------------------------
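The parameter estimate in the README's usage snippet can be checked directly: the block's input projection maps d_model to 2 * expand * d_model and the output projection maps expand * d_model back, giving 3 * expand * d_model^2 weights, with the conv/SSM parameters adding a lower-order term. A hedged sketch, assuming `mamba-ssm` is installed:

```
from mamba_ssm import Mamba

d_model, expand = 768, 2
model = Mamba(d_model=d_model, d_state=16, d_conv=4, expand=expand)
n_params = sum(p.numel() for p in model.parameters())
approx = 3 * expand * d_model ** 2  # in_proj (2x) + out_proj (1x)
print(n_params, approx, n_params / approx)  # ratio slightly above 1
```
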
/mamba/assets/selection.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/mamba/assets/selection.png
--------------------------------------------------------------------------------
/mamba/benchmarks/benchmark_generation_mamba_simple.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, Tri Dao, Albert Gu.
2 |
3 | import argparse
4 | import time
5 | import json
6 |
7 | import torch
8 | import torch.nn.functional as F
9 |
10 | from einops import rearrange
11 |
12 | from transformers import AutoTokenizer, AutoModelForCausalLM
13 |
14 | from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
15 |
16 |
17 | parser = argparse.ArgumentParser(description="Generation benchmarking")
18 | parser.add_argument("--model-name", type=str, default="state-spaces/mamba-130m")
19 | parser.add_argument("--prompt", type=str, default=None)
20 | parser.add_argument("--promptlen", type=int, default=100)
21 | parser.add_argument("--genlen", type=int, default=100)
22 | parser.add_argument("--temperature", type=float, default=1.0)
23 | parser.add_argument("--topk", type=int, default=1)
24 | parser.add_argument("--topp", type=float, default=1.0)
25 | parser.add_argument("--batch", type=int, default=1)
26 | args = parser.parse_args()
27 |
28 | repeats = 3
29 | device = "cuda"
30 | dtype = torch.float16
31 |
32 | print(f"Loading model {args.model_name}")
33 | is_mamba = args.model_name.startswith("state-spaces/mamba-") or "mamba" in args.model_name
34 |
35 | if is_mamba:
36 | tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
37 | model = MambaLMHeadModel.from_pretrained(args.model_name, device=device, dtype=dtype)
38 | else:
39 | tokenizer = AutoTokenizer.from_pretrained(args.model_name)
40 | model = AutoModelForCausalLM.from_pretrained(args.model_name, device_map={"": device}, torch_dtype=dtype)
41 | model.eval()
42 | print(f"Number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
43 |
44 | torch.random.manual_seed(0)
45 | if args.prompt is None:
46 | input_ids = torch.randint(1, 1000, (args.batch, args.promptlen), dtype=torch.long, device="cuda")
47 | attn_mask = torch.ones_like(input_ids, dtype=torch.long, device="cuda")
48 | else:
49 | tokens = tokenizer(args.prompt, return_tensors="pt")
50 | input_ids = tokens.input_ids.to(device=device)
51 | attn_mask = tokens.attention_mask.to(device=device)
52 | max_length = input_ids.shape[1] + args.genlen
53 |
54 | if is_mamba:
55 | fn = lambda: model.generate(
56 | input_ids=input_ids,
57 | max_length=max_length,
58 | cg=True,
59 | return_dict_in_generate=True,
60 | output_scores=True,
61 | enable_timing=False,
62 | temperature=args.temperature,
63 | top_k=args.topk,
64 | top_p=args.topp,
65 | )
66 | else:
67 | fn = lambda: model.generate(
68 | input_ids=input_ids,
69 | attention_mask=attn_mask,
70 | max_length=max_length,
71 | return_dict_in_generate=True,
72 | pad_token_id=tokenizer.eos_token_id,
73 | do_sample=True,
74 | temperature=args.temperature,
75 | top_k=args.topk,
76 | top_p=args.topp,
77 | )
78 | out = fn()
79 | if args.prompt is not None:
80 | print(tokenizer.batch_decode(out.sequences.tolist()))
81 |
82 | torch.cuda.synchronize()
83 | start = time.time()
84 | for _ in range(repeats):
85 | fn()
86 | torch.cuda.synchronize()
87 | print(f"Prompt length: {len(input_ids[0])}, generation length: {len(out.sequences[0]) - len(input_ids[0])}")
88 | print(f"{args.model_name} prompt processing + decoding time: {(time.time() - start) / repeats * 1000:.0f}ms")
89 |
--------------------------------------------------------------------------------
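The script prints an averaged wall-clock time for prompt processing plus decoding; throughput follows from the batch size and generation length. A small arithmetic sketch with a hypothetical measured latency:

```
# Converting the benchmark's printed latency into throughput (illustrative numbers).
batch, genlen = 128, 100  # as in the throughput example above
elapsed_ms = 2500.0       # hypothetical value printed by the script
tokens_per_second = batch * genlen / (elapsed_ms / 1000.0)
print(f"{tokens_per_second:.0f} generated tokens/s")  # 5120 tokens/s here
```
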
/mamba/csrc/selective_scan/selective_scan.h:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | #pragma once
6 |
7 | ////////////////////////////////////////////////////////////////////////////////////////////////////
8 |
9 | struct SSMScanParamsBase {
10 | using index_t = uint32_t;
11 |
12 | int batch, seqlen, n_chunks;
13 | index_t a_batch_stride;
14 | index_t b_batch_stride;
15 | index_t out_batch_stride;
16 |
17 | // Common data pointers.
18 | void *__restrict__ a_ptr;
19 | void *__restrict__ b_ptr;
20 | void *__restrict__ out_ptr;
21 | void *__restrict__ x_ptr;
22 | };
23 |
24 | ////////////////////////////////////////////////////////////////////////////////////////////////////
25 |
26 | struct SSMParamsBase {
27 | using index_t = uint32_t;
28 |
29 | int batch, dim, seqlen, dstate, n_groups, n_chunks;
30 | int dim_ngroups_ratio;
31 | bool is_variable_B;
32 | bool is_variable_C;
33 |
34 | bool delta_softplus;
35 |
36 | index_t A_d_stride;
37 | index_t A_dstate_stride;
38 | index_t B_batch_stride;
39 | index_t B_d_stride;
40 | index_t B_dstate_stride;
41 | index_t B_group_stride;
42 | index_t C_batch_stride;
43 | index_t C_d_stride;
44 | index_t C_dstate_stride;
45 | index_t C_group_stride;
46 | index_t u_batch_stride;
47 | index_t u_d_stride;
48 | index_t delta_batch_stride;
49 | index_t delta_d_stride;
50 | index_t z_batch_stride;
51 | index_t z_d_stride;
52 | index_t out_batch_stride;
53 | index_t out_d_stride;
54 | index_t out_z_batch_stride;
55 | index_t out_z_d_stride;
56 |
57 | // Common data pointers.
58 | void *__restrict__ A_ptr;
59 | void *__restrict__ B_ptr;
60 | void *__restrict__ C_ptr;
61 | void *__restrict__ D_ptr;
62 | void *__restrict__ u_ptr;
63 | void *__restrict__ delta_ptr;
64 | void *__restrict__ delta_bias_ptr;
65 | void *__restrict__ out_ptr;
66 | void *__restrict__ x_ptr;
67 | void *__restrict__ z_ptr;
68 | void *__restrict__ out_z_ptr;
69 | };
70 |
71 | struct SSMParamsBwd: public SSMParamsBase {
72 | index_t dout_batch_stride;
73 | index_t dout_d_stride;
74 | index_t dA_d_stride;
75 | index_t dA_dstate_stride;
76 | index_t dB_batch_stride;
77 | index_t dB_group_stride;
78 | index_t dB_d_stride;
79 | index_t dB_dstate_stride;
80 | index_t dC_batch_stride;
81 | index_t dC_group_stride;
82 | index_t dC_d_stride;
83 | index_t dC_dstate_stride;
84 | index_t du_batch_stride;
85 | index_t du_d_stride;
86 | index_t dz_batch_stride;
87 | index_t dz_d_stride;
88 | index_t ddelta_batch_stride;
89 | index_t ddelta_d_stride;
90 |
91 | // Common data pointers.
92 | void *__restrict__ dout_ptr;
93 | void *__restrict__ dA_ptr;
94 | void *__restrict__ dB_ptr;
95 | void *__restrict__ dC_ptr;
96 | void *__restrict__ dD_ptr;
97 | void *__restrict__ du_ptr;
98 | void *__restrict__ dz_ptr;
99 | void *__restrict__ ddelta_ptr;
100 | void *__restrict__ ddelta_bias_ptr;
101 | };
102 |
--------------------------------------------------------------------------------
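The structs above describe a scan over `(batch, dim, seqlen)` inputs `u` with a `dstate`-dimensional hidden state per channel and a per-step discretization `delta`. A hedged pure-PyTorch reference of the recurrence the CUDA kernels implement, covering only the variable-B/C, real-valued case; grouping, `z`-gating, `delta_softplus`, and the complex variants are omitted:

```
import torch

def selective_scan_ref(u, delta, A, B, C, D=None):
    # u, delta: (batch, dim, seqlen); A: (dim, dstate)
    # B, C: (batch, dstate, seqlen) -- the variable-B/C case from the struct above
    batch, dim, seqlen = u.shape
    dstate = A.shape[1]
    x = torch.zeros(batch, dim, dstate, dtype=u.dtype, device=u.device)
    ys = []
    for t in range(seqlen):
        dA = torch.exp(delta[:, :, t, None] * A)              # discretized state matrix
        dB = delta[:, :, t, None] * B[:, None, :, t]          # discretized input matrix
        x = dA * x + dB * u[:, :, t, None]                    # recurrent state update
        ys.append(torch.einsum("bdn,bn->bd", x, C[:, :, t]))  # per-step readout
    y = torch.stack(ys, dim=-1)                               # (batch, dim, seqlen)
    return y + D[None, :, None] * u if D is not None else y

u, delta = torch.randn(2, 4, 8), torch.rand(2, 4, 8)
A, B, C = -torch.rand(4, 16), torch.randn(2, 16, 8), torch.randn(2, 16, 8)
print(selective_scan_ref(u, delta, A, B, C).shape)  # torch.Size([2, 4, 8])
```
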
/mamba/csrc/selective_scan/selective_scan_bwd_bf16_complex.cu:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | // Split into multiple files to compile in parallel
6 |
7 | #include "selective_scan_bwd_kernel.cuh"
8 |
9 | template void selective_scan_bwd_cuda<at::BFloat16, complex_t>(SSMParamsBwd &params, cudaStream_t stream);
--------------------------------------------------------------------------------
/mamba/csrc/selective_scan/selective_scan_bwd_bf16_real.cu:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | // Split into multiple files to compile in parallel
6 |
7 | #include "selective_scan_bwd_kernel.cuh"
8 |
9 | template void selective_scan_bwd_cuda<at::BFloat16, float>(SSMParamsBwd &params, cudaStream_t stream);
--------------------------------------------------------------------------------
/mamba/csrc/selective_scan/selective_scan_bwd_fp16_complex.cu:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | // Split into multiple files to compile in parallel
6 |
7 | #include "selective_scan_bwd_kernel.cuh"
8 |
9 | template void selective_scan_bwd_cuda<at::Half, complex_t>(SSMParamsBwd &params, cudaStream_t stream);
--------------------------------------------------------------------------------
/mamba/csrc/selective_scan/selective_scan_bwd_fp16_real.cu:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | // Split into multiple files to compile in parallel
6 |
7 | #include "selective_scan_bwd_kernel.cuh"
8 |
9 | template void selective_scan_bwd_cuda<at::Half, float>(SSMParamsBwd &params, cudaStream_t stream);
--------------------------------------------------------------------------------
/mamba/csrc/selective_scan/selective_scan_bwd_fp32_complex.cu:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | // Split into multiple files to compile in parallel
6 |
7 | #include "selective_scan_bwd_kernel.cuh"
8 |
9 | template void selective_scan_bwd_cuda<float, complex_t>(SSMParamsBwd &params, cudaStream_t stream);
--------------------------------------------------------------------------------
/mamba/csrc/selective_scan/selective_scan_bwd_fp32_real.cu:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | // Split into multiple files to compile in parallel
6 |
7 | #include "selective_scan_bwd_kernel.cuh"
8 |
9 | template void selective_scan_bwd_cuda<float, float>(SSMParamsBwd &params, cudaStream_t stream);
--------------------------------------------------------------------------------
/mamba/csrc/selective_scan/selective_scan_fwd_bf16.cu:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | // Split into multiple files to compile in parallel
6 |
7 | #include "selective_scan_fwd_kernel.cuh"
8 |
9 | template void selective_scan_fwd_cuda<at::BFloat16, float>(SSMParamsBase &params, cudaStream_t stream);
10 | template void selective_scan_fwd_cuda<at::BFloat16, complex_t>(SSMParamsBase &params, cudaStream_t stream);
--------------------------------------------------------------------------------
/mamba/csrc/selective_scan/selective_scan_fwd_fp16.cu:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | // Split into multiple files to compile in parallel
6 |
7 | #include "selective_scan_fwd_kernel.cuh"
8 |
9 | template void selective_scan_fwd_cuda<at::Half, float>(SSMParamsBase &params, cudaStream_t stream);
10 | template void selective_scan_fwd_cuda<at::Half, complex_t>(SSMParamsBase &params, cudaStream_t stream);
--------------------------------------------------------------------------------
/mamba/csrc/selective_scan/selective_scan_fwd_fp32.cu:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | // Split into multiple files to compile in parallel
6 |
7 | #include "selective_scan_fwd_kernel.cuh"
8 |
9 | template void selective_scan_fwd_cuda<float, float>(SSMParamsBase &params, cudaStream_t stream);
10 | template void selective_scan_fwd_cuda<float, complex_t>(SSMParamsBase &params, cudaStream_t stream);
--------------------------------------------------------------------------------
/mamba/csrc/selective_scan/static_switch.h:
--------------------------------------------------------------------------------
1 | // Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
2 | // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
3 |
4 | #pragma once
5 |
6 | /// @param COND - a boolean expression to switch by
7 | /// @param CONST_NAME - a name given for the constexpr bool variable.
8 | /// @param ... - code to execute for true and false
9 | ///
10 | /// Usage:
11 | /// ```
12 | /// BOOL_SWITCH(flag, BoolConst, [&] {
13 | /// some_function(...);
14 | /// });
15 | /// ```
16 | #define BOOL_SWITCH(COND, CONST_NAME, ...) \
17 | [&] { \
18 | if (COND) { \
19 | constexpr bool CONST_NAME = true; \
20 | return __VA_ARGS__(); \
21 | } else { \
22 | constexpr bool CONST_NAME = false; \
23 | return __VA_ARGS__(); \
24 | } \
25 | }()
26 |
--------------------------------------------------------------------------------
/mamba/csrc/selective_scan/uninitialized_copy.cuh:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2011-2022, NVIDIA CORPORATION. All rights reserved.
3 | *
4 | * Redistribution and use in source and binary forms, with or without
5 | * modification, are permitted provided that the following conditions are met:
6 | * * Redistributions of source code must retain the above copyright
7 | * notice, this list of conditions and the following disclaimer.
8 | * * Redistributions in binary form must reproduce the above copyright
9 | * notice, this list of conditions and the following disclaimer in the
10 | * documentation and/or other materials provided with the distribution.
11 | * * Neither the name of the NVIDIA CORPORATION nor the
12 | * names of its contributors may be used to endorse or promote products
13 | * derived from this software without specific prior written permission.
14 | *
15 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
16 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
18 | * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
19 | * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
22 | * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25 | *
26 | ******************************************************************************/
27 |
28 | #pragma once
29 |
30 | #include <cub/config.cuh>
31 |
32 | #include <cuda/std/type_traits>
33 |
34 |
35 | namespace detail
36 | {
37 |
38 | #if defined(_NVHPC_CUDA)
39 | template <typename T, typename U>
40 | __host__ __device__ void uninitialized_copy(T *ptr, U &&val)
41 | {
42 | // NVBug 3384810
43 | new (ptr) T(::cuda::std::forward<U>(val));
44 | }
45 | #else
46 | template <typename T,
47 |           typename U,
48 |           typename ::cuda::std::enable_if<
49 |             ::cuda::std::is_trivially_copyable<T>::value,
50 |             int
51 |           >::type = 0>
52 | __host__ __device__ void uninitialized_copy(T *ptr, U &&val)
53 | {
54 | *ptr = ::cuda::std::forward<U>(val);
55 | }
56 |
57 | template <typename T,
58 |           typename U,
59 |           typename ::cuda::std::enable_if<
60 |             !::cuda::std::is_trivially_copyable<T>::value,
61 |             int
62 |           >::type = 0>
63 | __host__ __device__ void uninitialized_copy(T *ptr, U &&val)
64 | {
65 | new (ptr) T(::cuda::std::forward<U>(val));
66 | }
67 | #endif
68 |
69 | } // namespace detail
70 |
--------------------------------------------------------------------------------
/mamba/evals/lm_harness_eval.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import transformers
4 | from transformers import AutoTokenizer
5 |
6 | from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
7 |
8 | from lm_eval.api.model import LM
9 | from lm_eval.models.huggingface import HFLM
10 | from lm_eval.api.registry import register_model
11 | from lm_eval.__main__ import cli_evaluate
12 |
13 |
14 | @register_model("mamba")
15 | class MambaEvalWrapper(HFLM):
16 |
17 | AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
18 |
19 | def __init__(self, pretrained="state-spaces/mamba-2.8b", max_length=2048, batch_size=None, device="cuda",
20 | dtype=torch.float16):
21 | LM.__init__(self)
22 | self._model = MambaLMHeadModel.from_pretrained(pretrained, device=device, dtype=dtype)
23 | self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
24 | self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
25 | self.vocab_size = self.tokenizer.vocab_size
26 | self._batch_size = int(batch_size) if batch_size is not None else 64
27 | self._max_length = max_length
28 | self._device = torch.device(device)
29 |
30 | @property
31 | def batch_size(self):
32 | return self._batch_size
33 |
34 | def _model_generate(self, context, max_length, stop, **generation_kwargs):
35 | raise NotImplementedError()
36 |
37 |
38 | if __name__ == "__main__":
39 | cli_evaluate()
40 |
--------------------------------------------------------------------------------
/mamba/mamba_ssm/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "1.0.1"
2 |
3 | from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn, bimamba_inner_fn
4 | from mamba_ssm.modules.mamba_simple import Mamba
5 | from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
6 |
--------------------------------------------------------------------------------
/mamba/mamba_ssm/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/mamba/mamba_ssm/models/__init__.py
--------------------------------------------------------------------------------
/mamba/mamba_ssm/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/mamba/mamba_ssm/modules/__init__.py
--------------------------------------------------------------------------------
/mamba/mamba_ssm/ops/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/mamba/mamba_ssm/ops/__init__.py
--------------------------------------------------------------------------------
/mamba/mamba_ssm/ops/triton/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/mamba/mamba_ssm/ops/triton/__init__.py
--------------------------------------------------------------------------------
/mamba/mamba_ssm/ops/triton/selective_state_update.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, Tri Dao.
2 |
3 | """We want triton==2.1.0 for this
4 | """
5 |
6 | import math
7 | import torch
8 | import torch.nn.functional as F
9 |
10 | import triton
11 | import triton.language as tl
12 |
13 | from einops import rearrange, repeat
14 |
15 |
16 | @triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
17 | @triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
18 | @triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
19 | @triton.heuristics({"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])})
20 | @triton.jit
21 | def _selective_scan_update_kernel(
22 | # Pointers to matrices
23 | state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr,
24 | # Matrix dimensions
25 | batch, dim, dstate,
26 | # Strides
27 | stride_state_batch, stride_state_dim, stride_state_dstate,
28 | stride_x_batch, stride_x_dim,
29 | stride_dt_batch, stride_dt_dim,
30 | stride_dt_bias_dim,
31 | stride_A_dim, stride_A_dstate,
32 | stride_B_batch, stride_B_dstate,
33 | stride_C_batch, stride_C_dstate,
34 | stride_D_dim,
35 | stride_z_batch, stride_z_dim,
36 | stride_out_batch, stride_out_dim,
37 | # Meta-parameters
38 | DT_SOFTPLUS: tl.constexpr,
39 | BLOCK_SIZE_M: tl.constexpr,
40 | HAS_DT_BIAS: tl.constexpr,
41 | HAS_D: tl.constexpr,
42 | HAS_Z: tl.constexpr,
43 | BLOCK_SIZE_DSTATE: tl.constexpr,
44 | ):
45 | pid_m = tl.program_id(axis=0)
46 | pid_b = tl.program_id(axis=1)
47 | state_ptr += pid_b * stride_state_batch
48 | x_ptr += pid_b * stride_x_batch
49 | dt_ptr += pid_b * stride_dt_batch
50 | B_ptr += pid_b * stride_B_batch
51 | C_ptr += pid_b * stride_C_batch
52 | if HAS_Z:
53 | z_ptr += pid_b * stride_z_batch
54 | out_ptr += pid_b * stride_out_batch
55 |
56 | offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
57 | offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
58 | state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate)
59 | x_ptrs = x_ptr + offs_m * stride_x_dim
60 | dt_ptrs = dt_ptr + offs_m * stride_dt_dim
61 | if HAS_DT_BIAS:
62 | dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
63 | A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate)
64 | B_ptrs = B_ptr + offs_n * stride_B_dstate
65 | C_ptrs = C_ptr + offs_n * stride_C_dstate
66 | if HAS_D:
67 | D_ptrs = D_ptr + offs_m * stride_D_dim
68 | if HAS_Z:
69 | z_ptrs = z_ptr + offs_m * stride_z_dim
70 | out_ptrs = out_ptr + offs_m * stride_out_dim
71 |
72 | state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0)
73 | x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
74 | dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
75 | if HAS_DT_BIAS:
76 | dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
77 | if DT_SOFTPLUS:
78 | dt = tl.log(1.0 + tl.exp(dt))
79 | A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
80 | dA = tl.exp(A * dt[:, None])
81 | B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
82 | C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
83 | if HAS_D:
84 | D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
85 | if HAS_Z:
86 | z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
87 |
88 | dB = B[None, :] * dt[:, None]
89 | state = state * dA + dB * x[:, None]
90 | tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))
91 | out = tl.sum(state * C[None, :], axis=1)
92 | if HAS_D:
93 | out += x * D
94 | if HAS_Z:
95 | out *= z * tl.sigmoid(z)
96 | tl.store(out_ptrs, out, mask=offs_m < dim)
97 |
98 |
99 | def selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):
100 | """
101 | Argument:
102 | state: (batch, dim, dstate)
103 | x: (batch, dim)
104 | dt: (batch, dim)
105 | A: (dim, dstate)
106 | B: (batch, dstate)
107 | C: (batch, dstate)
108 | D: (dim,)
109 | z: (batch, dim)
110 | dt_bias: (dim,)
111 | Return:
112 | out: (batch, dim)
113 | """
114 | batch, dim, dstate = state.shape
115 | assert x.shape == (batch, dim)
116 | assert dt.shape == x.shape
117 | assert A.shape == (dim, dstate)
118 | assert B.shape == (batch, dstate)
119 | assert C.shape == B.shape
120 | if D is not None:
121 | assert D.shape == (dim,)
122 | if z is not None:
123 | assert z.shape == x.shape
124 | if dt_bias is not None:
125 | assert dt_bias.shape == (dim,)
126 | out = torch.empty_like(x)
127 | grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch)
128 | z_strides = ((z.stride(0), z.stride(1)) if z is not None else (0, 0))
129 | # We don't want autotune since it will overwrite the state
130 | # We instead tune by hand.
131 | BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16
132 | else ((16, 4) if dstate <= 32 else
133 | ((8, 4) if dstate <= 64 else
134 | ((4, 4) if dstate <= 128 else
135 | ((4, 8))))))
136 | with torch.cuda.device(x.device.index):
137 | _selective_scan_update_kernel[grid](
138 | state, x, dt, dt_bias, A, B, C, D, z, out,
139 | batch, dim, dstate,
140 | state.stride(0), state.stride(1), state.stride(2),
141 | x.stride(0), x.stride(1),
142 | dt.stride(0), dt.stride(1),
143 | dt_bias.stride(0) if dt_bias is not None else 0,
144 | A.stride(0), A.stride(1),
145 | B.stride(0), B.stride(1),
146 | C.stride(0), C.stride(1),
147 | D.stride(0) if D is not None else 0,
148 | z_strides[0], z_strides[1],
149 | out.stride(0), out.stride(1),
150 | dt_softplus,
151 | BLOCK_SIZE_M,
152 | num_warps=num_warps,
153 | )
154 | return out
155 |
156 |
157 | def selective_state_update_ref(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):
158 | """
159 | Argument:
160 | state: (batch, dim, dstate)
161 | x: (batch, dim)
162 | dt: (batch, dim)
163 | A: (dim, dstate)
164 | B: (batch, dstate)
165 | C: (batch, dstate)
166 | D: (dim,)
167 | z: (batch, dim)
168 | dt_bias: (dim,)
169 | Return:
170 | out: (batch, dim)
171 | """
172 | batch, dim, dstate = state.shape
173 | assert x.shape == (batch, dim)
174 | assert dt.shape == x.shape
175 | assert A.shape == (dim, dstate)
176 | assert B.shape == (batch, dstate)
177 | assert C.shape == B.shape
178 | if D is not None:
179 | assert D.shape == (dim,)
180 | if z is not None:
181 | assert z.shape == x.shape
182 | if dt_bias is not None:
183 | assert dt_bias.shape == (dim,)
184 | dt = dt + dt_bias if dt_bias is not None else dt
185 | dt = F.softplus(dt) if dt_softplus else dt
186 | dA = torch.exp(rearrange(dt, "b d -> b d 1") * A) # (batch, dim, dstate)
187 | dB = rearrange(dt, "b d -> b d 1") * rearrange(B, "b n -> b 1 n") # (batch, dim, dstate)
188 | state.copy_(state * dA + dB * rearrange(x, "b d -> b d 1"))  # (batch, dim, dstate)
189 | out = torch.einsum("bdn,bn->bd", state.to(C.dtype), C)
190 | if D is not None:
191 | out += (x * D).to(out.dtype)
192 | return (out if z is None else out * F.silu(z)).to(x.dtype)
193 |
--------------------------------------------------------------------------------
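`selective_state_update` is the single-token decoding counterpart of the full scan: it advances `state` in place by one step and returns the readout. The pure-PyTorch `selective_state_update_ref` above runs on CPU tensors, so the shapes can be checked without a GPU; a hedged sketch, assuming the package and its `triton` dependency import cleanly:

```
import torch
from mamba_ssm.ops.triton.selective_state_update import selective_state_update_ref

batch, dim, dstate = 2, 8, 16
state = torch.randn(batch, dim, dstate)  # advanced in place by the call
x = torch.randn(batch, dim)
dt = torch.randn(batch, dim)
A = -torch.rand(dim, dstate)             # negative entries keep the update stable
B = torch.randn(batch, dstate)
C = torch.randn(batch, dstate)
dt_bias = torch.rand(dim) - 4.0          # mirrors the test setup below
out = selective_state_update_ref(state, x, dt, A, B, C,
                                 dt_bias=dt_bias, dt_softplus=True)
print(out.shape)                         # torch.Size([2, 8])
```
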
/mamba/mamba_ssm/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyangwang007/Weak-Mamba-UNet/8290f2de2bee85357be1c250be6a9ceeae14e612/mamba/mamba_ssm/utils/__init__.py
--------------------------------------------------------------------------------
/mamba/mamba_ssm/utils/hf.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | import torch
4 |
5 | from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
6 | from transformers.utils.hub import cached_file
7 |
8 |
9 | def load_config_hf(model_name):
10 | resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False)
11 | return json.load(open(resolved_archive_file))
12 |
13 |
14 | def load_state_dict_hf(model_name, device=None, dtype=None):
15 | # If not fp32, then we don't want to load directly to the GPU
16 | mapped_device = "cpu" if dtype not in [torch.float32, None] else device
17 | resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
18 | state_dict = torch.load(resolved_archive_file, map_location=mapped_device)
19 | # Convert dtype before moving to GPU to save memory
20 | if dtype is not None:
21 | state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
22 | state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
23 | return state_dict
24 |
--------------------------------------------------------------------------------
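These two helpers are what `MambaLMHeadModel.from_pretrained` uses to fetch a checkpoint's config and weights from the HuggingFace Hub. A hedged direct-usage sketch; the `d_model` and `n_layer` keys are assumed to be present in the checkpoint's config JSON:

```
import torch
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf

config = load_config_hf("state-spaces/mamba-130m")
print(config["d_model"], config["n_layer"])  # assumed keys in the config JSON
state_dict = load_state_dict_hf("state-spaces/mamba-130m",
                                device="cpu", dtype=torch.float16)
print(len(state_dict), "tensors loaded")
```
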
/mamba/test_mamba_module.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from mamba_ssm import Mamba
3 |
4 | batch, length, dim = 2, 64, 768
5 | x = torch.randn(batch, length, dim).to("cuda")
6 | model = Mamba(
7 | # This module uses roughly 3 * expand * d_model^2 parameters
8 | d_model=dim, # Model dimension d_model
9 | d_state=16, # SSM state expansion factor # 64
10 | d_conv=4, # Local convolution width
11 | expand=2, # Block expansion factor
12 | use_fast_path=False,
13 | ).to("cuda")
14 | y = model(x)
15 | assert y.shape == x.shape
16 |
--------------------------------------------------------------------------------
/mamba/tests/ops/triton/test_selective_state_update.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2023, Tri Dao.
2 |
3 | import math
4 |
5 | import torch
6 | import torch.nn.functional as F
7 | import pytest
8 |
9 | from einops import rearrange
10 |
11 | from mamba_ssm.ops.triton.selective_state_update import selective_state_update, selective_state_update_ref
12 |
13 |
14 | @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
15 | # @pytest.mark.parametrize('itype', [torch.float16])
16 | @pytest.mark.parametrize("has_z", [False, True])
17 | # @pytest.mark.parametrize('has_z', [True])
18 | @pytest.mark.parametrize("dstate", [16, 32, 64])
19 | # @pytest.mark.parametrize("dstate", [16])
20 | @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
21 | # @pytest.mark.parametrize("dim", [2048])
22 | def test_selective_state_update(dim, dstate, has_z, itype):
23 | device = "cuda"
24 | rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
25 | if itype == torch.bfloat16:
26 | rtol, atol = 1e-2, 5e-2
27 | # set seed
28 | torch.random.manual_seed(0)
29 | batch_size = 2
30 | state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device)
31 | x = torch.randn(batch_size, dim, device=device, dtype=itype)
32 | dt = torch.randn(batch_size, dim, device=device, dtype=itype)
33 | dt_bias = torch.rand(dim, device=device) - 4.0
34 | A = -torch.rand(dim, dstate, device=device) - 1.0
35 | B = torch.randn(batch_size, dstate, device=device)
36 | C = torch.randn(batch_size, dstate, device=device)
37 | D = torch.randn(dim, device=device)
38 | if has_z:
39 | z = torch.randn_like(x)
40 | else:
41 | z = None
42 | state_ref = state.detach().clone()
43 | out = selective_state_update(state, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True)
44 | out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True)
45 |
46 | print(f"Output max diff: {(out - out_ref).abs().max().item()}")
47 | print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
48 | assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
49 | assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
50 |
--------------------------------------------------------------------------------