├── README.md
├── SegDiffusion.py
└── requirements.txt

/README.md:
--------------------------------------------------------------------------------
# Cold SegDiffusion
This repository is the official implementation of the paper "Cold SegDiffusion: A Novel
Diffusion Model for Medical Image Segmentation."

## Dataset
The medical images utilized in the experiments are collected from three public datasets: ISIC [1], TN3K [2], and REFUGE [3].
The references for the experimental datasets are given below:

[1] D. Gutman, N. C. Codella, E. Celebi, B. Helba, M. Marchetti, N. Mishra, A. Halpern, Skin lesion analysis toward melanoma detection: A challenge at the International Symposium
on Biomedical Imaging (ISBI) 2016, hosted by the International Skin Imaging Collaboration (ISIC), arXiv preprint arXiv:1605.01397 (2016).

[2] H. Gong, J. Chen, G. Chen, H. Li, G. Li, F. Chen, Thyroid region prior guided attention for ultrasound segmentation of thyroid nodules,
Computers in Biology and Medicine 155 (2023) 106389.

[3] J. I. Orlando, H. Fu, J. B. Breda, K. Van Keer, D. R. Bathula, A. Diaz-Pinto, R. Fang, P.-A. Heng, J. Kim, J. Lee, et al., REFUGE challenge: A unified
framework for evaluating automated methods for glaucoma assessment from fundus photographs, Medical Image Analysis 59 (2020) 101570.

## Code Usage

### Installation

#### Requirements

* Linux, CUDA>=11.3, GCC>=7.5.0

* Python>=3.8

* PyTorch>=1.11.0, torchvision>=0.12.0 (following the instructions [here](https://pytorch.org/))

* Other requirements:
```bash
pip install -r requirements.txt
```

### Dataset preparation

Please organize the dataset as follows:

```
ISIC_Med/
└── ISBI2016_ISIC_Dataset/
    ├── ISIC_0000000.jpg
    ├── ISIC_0000000_Segmentation.png
    ├── ISIC_0000001.jpg
    ├── ISIC_0000001_Segmentation.png
    ├── ...
    ├── train.txt
    ├── valid.txt
    └── test.txt
```
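The split files themselves are not documented further. As a minimal sketch (assuming each of `train.txt`, `valid.txt`, and `test.txt` lists one image stem per line, so that `<stem>.jpg` and `<stem>_Segmentation.png` form a pair — the function name and layout here are illustrative, not part of this repository), the splits could be read like this:

```python
from pathlib import Path

def load_split(root, split="train"):
    """Collect (image, mask) path pairs for one split (illustrative only)."""
    base = Path(root) / "ISBI2016_ISIC_Dataset"
    stems = (base / f"{split}.txt").read_text().split()
    return [(base / f"{stem}.jpg", base / f"{stem}_Segmentation.png") for stem in stems]

pairs = load_split("ISIC_Med", "train")
```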
### Training

For example, the command for training Cold SegDiffusion is as follows:

```bash
python driver.py
```
The training configurations in model_train.py and the other files can be adjusted as needed.

### Evaluation

After obtaining the trained Cold SegDiffusion, run the following command to evaluate it on the validation set:

```bash
python sample.py
```

## Notes
The code in this repository builds on
https://github.com/TimesXY/Cold-SegDiffusion.
--------------------------------------------------------------------------------
/SegDiffusion.py:
--------------------------------------------------------------------------------
import math
import copy
import torch
import torch.nn.functional as F

from random import random
from tqdm.auto import tqdm
from einops import rearrange
from torch import nn, einsum
from beartype import beartype
from functools import partial
from torch.fft import fft2, ifft2
from collections import namedtuple
from einops.layers.torch import Rearrange

# constants
ModelPrediction = namedtuple('ModelPrediction', ['predict_noise', 'predict_x_start'])


# Check whether a variable exists (is not None)
def exists(x):
    return x is not None


# Default-value helper: return `val` if it exists, otherwise fall back to `d`
def default(val, d):
    if exists(val):
        return val

    # If the fallback is callable, return its result; otherwise return it directly
    if callable(d):
        return d()
    else:
        return d


# Identity function (used where a residual path needs a no-op)
def identity(t):
    return t


# Normalization: rescale from [0, 1] to [-1, 1]
def normalize_to_neg_one_to_one(img):
    return img * 2 - 1


# Inverse normalization: rescale from [-1, 1] back to [0, 1]
def un_normalize_to_zero_to_one(t):
    return (t + 1) * 0.5


# Create the learning-rate schedule
def create_lr_scheduler(optimizer, num_step: int, epochs: int, warmup=True, warmup_epochs=1, warmup_factor=1e-3):
    assert num_step > 0 and epochs > 0
    if warmup is False:
        warmup_epochs = 0

    def func(x):
        """
        Learning-rate adjustment function.
        Returns a multiplicative LR factor for the given step; note that PyTorch
        calls lr_scheduler.step() once before training starts.
        """
        if warmup is True and x <= (warmup_epochs * num_step):
            alpha = float(x) / (warmup_epochs * num_step)
            # During warmup the LR factor ramps from warmup_factor -> 1
            return warmup_factor * (1 - alpha) + alpha
        else:
            # After warmup the LR factor decays from 1 -> 0
            # See the DeepLab v2 learning-rate policy
            return (1 - (x - warmup_epochs * num_step) / ((epochs - warmup_epochs) * num_step)) ** 0.9

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=func)
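
# Illustrative usage of the schedule above (a sketch, not part of the original
# training script): warm up over one epoch, then decay polynomially toward zero.
def _demo_lr_schedule():
    param = torch.zeros(1, requires_grad=True)
    optimizer = torch.optim.SGD([param], lr=0.1)
    scheduler = create_lr_scheduler(optimizer, num_step=100, epochs=10, warmup=True)
    lrs = []
    for _ in range(100 * 10):
        optimizer.step()
        scheduler.step()
        lrs.append(optimizer.param_groups[0]['lr'])
    return lrs  # rises toward 0.1 during warmup, then decays to 0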

# Residual connection module
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x


# Up-sampling module
def up_sample(dim, dim_out=None):
    up_s = nn.Sequential(
        nn.Upsample(scale_factor=2, mode='nearest'),
        nn.Conv2d(dim, default(dim_out, dim), (3, 3), padding=1))
    return up_s


# Down-sampling module (space-to-depth rearrange followed by a 1x1 convolution)
def down_sample(dim, dim_out=None):
    down_s = nn.Sequential(
        Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1=2, p2=2),
        nn.Conv2d(dim * 4, default(dim_out, dim), (1, 1)))
    return down_s


# Layer normalization over the channel dimension
class LayerNorm(nn.Module):
    def __init__(self, dim, bias=False):
        super().__init__()
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1)) if bias else None

    def forward(self, x):
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
        var = torch.var(x, dim=1, unbiased=False, keepdim=True)
        mean = torch.mean(x, dim=1, keepdim=True)
        return (x - mean) * (var + eps).rsqrt() * self.g + default(self.b, 0)


# Sinusoidal position (timestep) embedding
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


# Convolution-normalization-activation block
class Block(nn.Module):
    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim_out, (3, 3), padding=1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift=None):
        x = self.proj(x)
        x = self.norm(x)

        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        x = self.act(x)
        return x


# Residual block
class ResnetBlock(nn.Module):
    def __init__(self, dim, dim_out, time_emb_dim=None, groups=8):
        super().__init__()

        # Convolution-normalization-activation blocks
        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)

        # 1x1 convolution for channel matching
        self.res_conv = nn.Conv2d(dim, dim_out, (1, 1)) if dim != dim_out else nn.Identity()

        # Time-embedding projection
        self.mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2)) if exists(time_emb_dim) else None

    def forward(self, x, time_emb=None):
        # Project the time embedding and split it into scale and shift
        scale_shift = None
        if exists(self.mlp) and exists(time_emb):
            time_emb = self.mlp(time_emb)
            time_emb = rearrange(time_emb, 'b c -> b c 1 1')
            scale_shift = time_emb.chunk(2, dim=1)

        # Pass through the two blocks
        h = self.block1(x, scale_shift=scale_shift)
        h = self.block2(h)

        # Residual connection after channel matching
        return h + self.res_conv(x)


# Feed-forward over the channel dimension
def feed_forward_att(dim, mult=4):
    inner_dim = int(dim * mult)
    feed_forward_linear = nn.Sequential(LayerNorm(dim),
                                        nn.Conv2d(dim, inner_dim, (1, 1)),
                                        nn.GELU(),
                                        nn.Conv2d(inner_dim, dim, (1, 1)))
    return feed_forward_linear


# Linear attention
class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()

        # Hyperparameters (number of heads, scaling factor, hidden dimension)
        self.heads = heads
        self.scale = dim_head ** -0.5
        hidden_dim = dim_head * heads

        # Normalization layer, QKV projection, and output projection
        self.pre_norm = LayerNorm(dim)
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, (1, 1), bias=False)
        self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, (1, 1)), LayerNorm(dim))

    def forward(self, x):
        # Get the input shape
        b, c, h, w = x.shape

        # Normalize, then split into Q, K, V
        x = self.pre_norm(x)
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h=self.heads), qkv)

        # Linear attention computation
        q = q.softmax(dim=-2)
        k = k.softmax(dim=-1)
        q = q * self.scale
        context = torch.einsum('b h d n, b h e n -> b h d e', k, v)

        # Aggregate the context with the queries
        out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
        out = rearrange(out, 'b h c (x y) -> b (h c) x y', h=self.heads, x=h, y=w)

        return self.to_out(out)


# Full (softmax) self-attention
class Attention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()

        # Hyperparameters (number of heads, scaling factor, hidden dimension)
        self.heads = heads
        self.scale = dim_head ** -0.5
        hidden_dim = dim_head * heads

        # Normalization layer, QKV projection, and output projection
        self.pre_norm = LayerNorm(dim)
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, (1, 1), bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, (1, 1))

    def forward(self, x):
        # Get the input shape
        b, c, h, w = x.shape

        # Normalize, then split into Q, K, V
        x = self.pre_norm(x)
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h=self.heads), qkv)

        # Scaled dot-product attention
        q = q * self.scale
        sim = einsum('b h d i, b h d j -> b h i j', q, k)
        attn = sim.softmax(dim=-1)
        out = einsum('b h i j, b h d j -> b h i d', attn, v)

        # Reshape back and project out
        out = rearrange(out, 'b h (x y) d -> b (h d) x y', x=h, y=w)
        return self.to_out(out)
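
# Shape sanity check (illustrative helper, not in the original file): both
# attention variants preserve the (batch, channels, height, width) layout, so
# they are interchangeable inside the U-Net; LinearAttention trades exact
# softmax attention for an approximation that is linear in the pixel count.
def _demo_attention_shapes():
    x = torch.randn(2, 32, 16, 16)
    assert LinearAttention(32)(x).shape == x.shape
    assert Attention(32)(x).shape == x.shape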

# Cross-attention between the diffusion branch and the condition branch
class MIDAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()

        # Hyperparameters (number of heads, scaling factor, hidden dimension)
        self.heads = heads
        self.scale = dim_head ** -0.5
        hidden_dim = dim_head * heads

        # Normalization layers, QKV projections, and output projection
        self.pre_norm_x = LayerNorm(dim)
        self.pre_norm_c = LayerNorm(dim)

        self.to_qkv_x = nn.Conv2d(dim, hidden_dim * 3, (1, 1), bias=False)
        self.to_qkv_c = nn.Conv2d(dim, hidden_dim * 3, (1, 1), bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, (1, 1))

    def forward(self, x, c_x):
        # Get the input shape
        b, c, h, w = x.shape

        # Normalize both inputs, then split each into Q, K, V
        x = self.pre_norm_x(x)
        c_x = self.pre_norm_c(c_x)

        qkv_x = self.to_qkv_x(x).chunk(3, dim=1)
        qkv_c = self.to_qkv_c(c_x).chunk(3, dim=1)

        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h=self.heads), qkv_x)
        q_c, k_c, v_c = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h=self.heads), qkv_c)

        # Attention with queries from the condition branch, keys/values from x
        q_c = q_c * self.scale
        sim = einsum('b h d i, b h d j -> b h i j', q_c, k)
        attn = sim.softmax(dim=-1)
        out = einsum('b h i j, b h d j -> b h i d', attn, v)

        # Reshape back and project out
        out = rearrange(out, 'b h (x y) d -> b (h d) x y', x=h, y=w)
        return self.to_out(out)


# Channel attention module (CAM)
class ChannelAttentionModule(nn.Module):
    def __init__(self, channel, ratio=16):
        super(ChannelAttentionModule, self).__init__()

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.shared_MLP = nn.Sequential(
            nn.Conv2d(channel, channel // ratio, (1, 1), bias=False),
            nn.ReLU(),
            nn.Conv2d(channel // ratio, channel, (1, 1), bias=False))

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.shared_MLP(self.avg_pool(x))
        max_out = self.shared_MLP(self.max_pool(x))
        return self.sigmoid(avg_out + max_out)


# Spatial attention module (SAM)
class SpatialAttentionModule(nn.Module):
    def __init__(self):
        super(SpatialAttentionModule, self).__init__()
        self.conv2d = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=(7, 7), stride=(1, 1), padding=3)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        out = torch.cat([avg_out, max_out], dim=1)
        out = self.sigmoid(self.conv2d(out))
        return out
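
# Illustrative gating (a sketch, not in the original file): CAM produces
# per-channel weights and SAM a per-pixel map; CEMLinearAttention below applies
# them to the query and key branches with a residual (gate * t + t).
def _demo_cbam_gates():
    x = torch.randn(2, 64, 8, 8)
    channel_gate = ChannelAttentionModule(64)(x)   # shape (2, 64, 1, 1)
    spatial_gate = SpatialAttentionModule()(x)     # shape (2, 1, 8, 8)
    assert (channel_gate * x + x).shape == x.shape
    assert (spatial_gate * x + x).shape == x.shape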

# Contrast Enhancement Module (linear attention with CAM/SAM-enhanced Q and K)
class CEMLinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()

        # Hyperparameters (number of heads, scaling factor, hidden dimension)
        self.heads = heads
        self.scale = dim_head ** -0.5
        hidden_dim = dim_head * heads

        # Normalization layer, QKV projection, and output projection
        self.pre_norm = LayerNorm(dim)
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, (1, 1), bias=False)
        self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, (1, 1)), LayerNorm(dim))

        # Spatial and channel attention
        self.channel_attention = ChannelAttentionModule(hidden_dim)
        self.spatial_attention = SpatialAttentionModule()

    def forward(self, x):
        # Get the input shape
        b, c, h, w = x.shape

        # Normalize, then split into Q, K, V
        x = self.pre_norm(x)
        qkv = self.to_qkv(x).chunk(3, dim=1)

        # Apply CAM to the queries and SAM to the keys, each with a residual
        qkv = list(qkv)
        qkv[0] = self.channel_attention(qkv[0]) * qkv[0] + qkv[0]
        qkv[1] = self.spatial_attention(qkv[1]) * qkv[1] + qkv[1]

        # Reshape for multi-head attention
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h=self.heads), qkv)

        # Linear attention computation
        q = q.softmax(dim=-2)
        k = k.softmax(dim=-1)
        q = q * self.scale
        context = torch.einsum('b h d n, b h e n -> b h d e', k, v)

        # Aggregate the context with the queries
        out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
        out = rearrange(out, 'b h c (x y) -> b (h c) x y', h=self.heads, x=h, y=w)

        return self.to_out(out)


# Transformer block
class Transformer(nn.Module):
    def __init__(self, dim, dim_head=32, heads=4, depth=1):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Residual(Attention(dim, dim_head=dim_head, heads=heads)),
                Residual(feed_forward_att(dim))]))

    def forward(self, x):
        for attn, ff_linear in self.layers:
            x = attn(x)
            x = ff_linear(x)
        return x


# Conditional Attention Transformer (bottleneck)
class MIDTransformer(nn.Module):
    def __init__(self, dim, dim_head=32, heads=4, depth=1):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Residual(MIDAttention(dim, dim_head=dim_head, heads=heads)),
                Residual(feed_forward_att(dim)),
                Residual(feed_forward_att(dim)),
                Residual(MIDAttention(dim, dim_head=dim_head, heads=heads))]))

    def forward(self, x, c):
        for attn_1, ff_linear_1, ff_linear_2, attn_2 in self.layers:
            x = attn_1(x, c)
            x1 = ff_linear_1(x)
            x2 = ff_linear_2(x)
            x = attn_2(x1, x2)
        return x


# FFT-based conditioning module of the diffusion model
class Conditioning(nn.Module):
    def __init__(self, fmap_size, dim):
        super().__init__()

        # Learnable attention maps that modulate the frequency components
        self.ff_theta = nn.Parameter(torch.ones(dim, 1, 1))
        self.ff_parser_attn_map_r = nn.Parameter(torch.ones(dim, fmap_size, fmap_size))
        self.ff_parser_attn_map_i = nn.Parameter(torch.ones(dim, fmap_size, fmap_size))

        # Input normalization
        self.norm_input = LayerNorm(dim, bias=True)

        # Residual block
        self.block = ResnetBlock(dim, dim)

        # Self-attention (contrast enhancement)
        self.attention_f = CEMLinearAttention(dim, heads=4, dim_head=32)

    def forward(self, x):
        x_type = x.dtype

        # 2-D Fourier transform
        z = fft2(x)

        # Real and imaginary parts of the spectrum
        z_real = z.real
        z_imag = z.imag

        # Frequency-domain filtering: keep the low frequencies and enhance the
        # high frequencies via learnable weighting of the real and imaginary parts
        z_real = z_real * self.ff_parser_attn_map_r
        z_imag = z_imag * self.ff_parser_attn_map_i

        # Recombine into complex form
        z = torch.complex(z_real * self.ff_theta, z_imag * self.ff_theta)

        # Inverse transform; only the real part is kept (the imaginary part is error)
        z = ifft2(z).real

        # Restore the original dtype
        z = z.type(x_type)

        # Fuse the condition with the input
        norm_z = self.norm_input(z)

        # Enhance the learned features with self-attention
        norm_z = self.attention_f(norm_z + x)

        # An extra block for further information integration; a down-sample
        # follows this conditioning block in the encoder
        return self.block(norm_z)
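
# Smoke test (illustrative helper, not in the original file): the conditioning
# module is shape-preserving, so it can be applied at every encoder resolution;
# `fmap_size` must match the feature map's spatial size because the frequency
# masks are learned per spatial position.
def _demo_fft_conditioning():
    module = Conditioning(fmap_size=16, dim=32)
    x = torch.randn(2, 32, 16, 16)
    assert module(x).shape == x.shape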

@beartype
class Unet(nn.Module):
    def __init__(self, dim, image_size, mask_channels=1, input_img_channels=3, init_dim=None,
                 dim_mult: tuple = (1, 2, 4, 8), full_self_attn: tuple = (False, False, False, True), attn_dim_head=32,
                 attn_heads=4, mid_transformer_depth=1, self_condition=False, resnet_block_groups=8,
                 conditioning_klass=Conditioning, skip_connect_condition_fmap=False):
        """
        :param dim: base channel dimension
        :param image_size: image size
        :param init_dim: initial channel dimension
        :param dim_mult: per-stage channel multipliers
        :param attn_dim_head: attention head dimension
        :param attn_heads: number of attention heads
        :param input_img_channels: number of channels of the raw input image
        :param mask_channels: number of mask channels (the output channels when self-conditioning is off)
        :param mid_transformer_depth: depth of the bottleneck Transformer
        :param full_self_attn: which stages use full self-attention
        :param self_condition: whether to use self-conditioning
        :param resnet_block_groups: group count of the residual blocks' GroupNorm
        :param conditioning_klass: conditioning module class
        :param skip_connect_condition_fmap: whether the decoder also receives the encoder's condition feature maps
        """
        super().__init__()

        # Hyperparameters
        self.image_size = image_size
        self.mask_channels = mask_channels
        self.self_condition = self_condition
        self.input_img_channels = input_img_channels

        # The output has mask channels; because the cold forward process mixes the
        # mask with the conditioning image, the diffused input carries the image
        # channel count
        output_channels = mask_channels
        mask_channels = input_img_channels

        # Initial projection dimension
        init_dim = default(init_dim, dim)

        # Initial convolutions for the input and the condition
        self.init_conv = nn.Conv2d(mask_channels, init_dim, (7, 7), padding=3)
        self.cond_init_conv = nn.Conv2d(input_img_channels, init_dim, (7, 7), padding=3)

        # Feature-map dimensions of each stage
        dims = [init_dim, *map(lambda m: dim * m, dim_mult)]
        in_out = list(zip(dims[:-1], dims[1:]))

        # Convolution-normalization-activation blocks with the chosen group count
        block_klass = partial(ResnetBlock, groups=resnet_block_groups)

        # Time-embedding dimension and module
        time_dim = dim * 4
        self.time_mlp = nn.Sequential(SinusoidalPosEmb(dim),
                                      nn.Linear(dim, time_dim),
                                      nn.GELU(),
                                      nn.Linear(time_dim, time_dim))

        # Attention hyperparameters
        attn_kwargs = dict(dim_head=attn_dim_head, heads=attn_heads)

        # Number of resolution stages
        num_resolutions = len(in_out)
        assert len(full_self_attn) == num_resolutions

        # Initialization
        curr_fmap_size = image_size
        self.downs = nn.ModuleList([])
        self.conditioners = nn.ModuleList([])
        self.skip_connect_condition_fmap = skip_connect_condition_fmap

        # Encoder (down-sampling) stages
        for ind, ((dim_in, dim_out), full_attn) in enumerate(zip(in_out, full_self_attn)):
            # Whether this is the last stage
            is_last = ind >= (num_resolutions - 1)

            # Attention type for this stage
            attn_klass = Attention if full_attn else LinearAttention

            # Conditioning module for this resolution
            self.conditioners.append(conditioning_klass(curr_fmap_size, dim_in))

            # Down-sampling stage
            self.downs.append(nn.ModuleList([
                block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                Residual(attn_klass(dim_in, **attn_kwargs)),
                down_sample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, (3, 3), padding=1)]))

            # The feature map is halved by the down-sampling
            if not is_last:
                curr_fmap_size //= 2

        # Bottleneck: a Transformer in place of plain residual connections
        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.mid_transformer = MIDTransformer(mid_dim, depth=mid_transformer_depth, **attn_kwargs)
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)

        # The condition encoder mirrors the main encoder
        self.ups = nn.ModuleList([])
        self.cond_downs = copy.deepcopy(self.downs)
        self.cond_mid_block1 = copy.deepcopy(self.mid_block1)

        # Decoder (up-sampling) stages
        for ind, ((dim_in, dim_out), full_attn) in enumerate(zip(reversed(in_out), reversed(full_self_attn))):
            # Whether this is the last stage
            is_last = ind == (len(in_out) - 1)

            # Attention type for this stage
            attn_klass = Attention if full_attn else LinearAttention

            # Skip-connection width doubles if the condition feature maps are also passed
            skip_connect_dim = dim_in * (2 if self.skip_connect_condition_fmap else 1)

            # Up-sampling stage
            self.ups.append(nn.ModuleList([
                block_klass(dim_out + skip_connect_dim, dim_out, time_emb_dim=time_dim),
                block_klass(dim_out + skip_connect_dim, dim_out, time_emb_dim=time_dim),
                Residual(attn_klass(dim_out, **attn_kwargs)),
                up_sample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, (3, 3), padding=1)]))

        # Output layers
        self.final_res_block = block_klass(dim * 2, dim, time_emb_dim=time_dim)
        self.final_conv = nn.Conv2d(dim, output_channels, (1, 1))

    def forward(self, x, time, cond, x_self_cond=None):
        # Whether the decoder also receives the encoder's condition feature maps
        skip_connect_c = self.skip_connect_condition_fmap

        # Concatenate the self-conditioning input if enabled
        if self.self_condition:
            x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x))
            x = torch.cat((x_self_cond, x), dim=1)

        # Initial convolutions for the input and the condition
        x = self.init_conv(x)
        c = self.cond_init_conv(cond)

        # Keep the initial features for the final concatenation
        r = x.clone()

        # Time embedding
        t = self.time_mlp(time)

        # Encoder
        h = []
        for (block1, block2, attn, d_sample), (cond_block1, cond_block2, cond_attn,
                                               cond_d_sample), conditioner in zip(self.downs, self.cond_downs,
                                                                                  self.conditioners):
            # Convolutional encoding of both branches
            x = block1(x, t)
            c = cond_block1(c, t)

            # Save the features for the skip connections
            h.append([x, c] if skip_connect_c else [x])

            # Second convolutional block of both branches
            x = block2(x, t)
            c = cond_block2(c, t)

            # Attention on both branches
            x = attn(x)
            c = cond_attn(c)

            # Fourier-domain conditioning
            x = conditioner(x)

            # Save the features for the skip connections
            h.append([x, c] if skip_connect_c else [x])

            # Down-sample both branches
            x = d_sample(x)
            c = cond_d_sample(c)

        # Bottleneck encoding of both branches
        x = self.mid_block1(x, t)
        c = self.cond_mid_block1(c, t)

        # Fuse the condition features into the main branch
        x = x + c

        # Bottleneck attention
        x = self.mid_transformer(x, c)
        x = self.mid_block2(x, t)

        # Decoder
        for block1, block2, attn, up_s in self.ups:
            # Concatenate the saved skip features
            x = torch.cat((x, *h.pop()), dim=1)
            x = block1(x, t)

            # Concatenate the saved skip features
            x = torch.cat((x, *h.pop()), dim=1)
            x = block2(x, t)

            # Attention
            x = attn(x)

            # Up-sample
            x = up_s(x)

        # Concatenate with the initial features
        x = torch.cat((x, r), dim=1)

        # Final layers
        x = self.final_res_block(x, t)
        return self.final_conv(x)
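
# End-to-end shape check for the U-Net (illustrative helper, not in the original
# file; toy sizes for speed). Note that the diffused input carries the image
# channel count (see the channel swap in __init__), while the output has
# `mask_channels` channels.
def _demo_unet_forward():
    net = Unet(dim=16, image_size=32, mask_channels=1, input_img_channels=3,
               dim_mult=(1, 2), full_self_attn=(False, True))
    x = torch.randn(2, 3, 32, 32)      # diffused input (image-like)
    cond = torch.randn(2, 3, 32, 32)   # conditioning medical image
    t = torch.randint(0, 1000, (2,))
    assert net(x, t, cond).shape == (2, 1, 32, 32)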

# Gaussian diffusion helpers
def extract(a, t, x_shape):
    """Gather per-timestep coefficients and reshape them for broadcasting."""
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))


# Linear beta schedule
def linear_beta_schedule(time_steps):
    scale = 1000 / time_steps
    beta_start = scale * 0.0001
    beta_end = scale * 0.02
    return torch.linspace(beta_start, beta_end, time_steps, dtype=torch.float64)


# Cosine beta schedule
def cosine_beta_schedule(time_steps, s=0.008):
    steps = time_steps + 1
    x = torch.linspace(0, time_steps, steps, dtype=torch.float64)
    alphas_cum_prod = torch.cos(((x / time_steps) + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cum_prod = alphas_cum_prod / alphas_cum_prod[0]
    betas = 1 - (alphas_cum_prod[1:] / alphas_cum_prod[:-1])
    return torch.clip(betas, 0, 0.999)
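
# Quick comparison of the two schedules (illustrative helper, not in the
# original file): both return `time_steps` betas; the cosine schedule keeps the
# early betas smaller, preserving more signal early in the forward process.
def _demo_beta_schedules():
    linear = linear_beta_schedule(1000)
    cosine = cosine_beta_schedule(1000)
    assert linear.shape == cosine.shape == (1000,)
    assert float(cosine[:100].mean()) < float(linear[:100].mean())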

# Medical image segmentation diffusion model
class MedSegDiff(nn.Module):
    def __init__(self, model, time_steps=1000, sampling_time_steps=None, objective='predict_x0',
                 beta_schedule='cosine', ddim_sampling_eta=1.):
        """
        :param model: the segmentation U-Net
        :param time_steps: number of diffusion steps
        :param sampling_time_steps: number of sampling steps
        :param objective: prediction target
        :param beta_schedule: beta schedule
        :param ddim_sampling_eta: DDIM sampling eta
        """
        super().__init__()

        # Parameter assignment
        self.model = model
        self.objective = objective
        self.image_size = model.image_size
        self.mask_channels = self.model.mask_channels
        self.self_condition = self.model.self_condition
        self.input_img_channels = self.model.input_img_channels

        # Choose the beta schedule
        if beta_schedule == 'linear':
            betas = linear_beta_schedule(time_steps)
        elif beta_schedule == 'cosine':
            betas = cosine_beta_schedule(time_steps)
        else:
            raise ValueError(f'unknown beta schedule {beta_schedule}')

        # Alphas and their cumulative products
        alphas = 1. - betas
        alphas_cum_prod = torch.cumprod(alphas, dim=0)
        alphas_cum_prod_prev = F.pad(alphas_cum_prod[:-1], (1, 0), value=1.)

        # Training and sampling step counts
        time_steps, = betas.shape
        self.num_time_steps = int(time_steps)
        self.sampling_time_steps = default(sampling_time_steps, time_steps)
        assert self.sampling_time_steps <= time_steps

        # The number of sampling steps defaults to the number of training steps
        self.is_ddim_sampling = self.sampling_time_steps < time_steps
        self.ddim_sampling_eta = ddim_sampling_eta

        # Helper that registers buffers as float32 instead of float64
        def register_buffer(name, val):
            self.register_buffer(name, val.to(torch.float32))

        # Betas and cumulative products
        register_buffer('betas', betas)
        register_buffer('alphas_cum_prod', alphas_cum_prod)
        register_buffer('alphas_cum_prod_prev', alphas_cum_prod_prev)

        # Coefficients of the forward process q(x_t | x_{t-1})
        register_buffer('sqrt_alphas_cum_prod', torch.sqrt(alphas_cum_prod))
        register_buffer('sqrt_one_minus_alphas_cum_prod', torch.sqrt(1. - alphas_cum_prod))
        register_buffer('log_one_minus_alphas_cum_prod', torch.log(1. - alphas_cum_prod))
        register_buffer('sqrt_recip_alphas_cum_prod', torch.sqrt(1. / alphas_cum_prod))
        register_buffer('sqrt_recip_m1_alphas_cum_prod', torch.sqrt(1. / alphas_cum_prod - 1))

        # Posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1. - alphas_cum_prod_prev) / (1. - alphas_cum_prod)

        # Above: equal to 1. / (1. / (1. - alpha_cum_prod_t_m1) + alpha_t / beta_t)
        register_buffer('posterior_variance', posterior_variance)

        # Below: log computation clipped because the posterior variance is 0
        # at the start of the diffusion chain
        register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min=1e-20)))

        # Coefficients of the posterior mean
        register_buffer('posterior_mean_cof_1', betas * torch.sqrt(alphas_cum_prod_prev) / (1. - alphas_cum_prod))
        register_buffer('posterior_mean_cof_2',
                        (1. - alphas_cum_prod_prev) * torch.sqrt(alphas) / (1. - alphas_cum_prod))
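
    # Note (added for clarity, not in the original file): although the buffers
    # above follow the standard DDPM formulas, this cold variant never draws
    # Gaussian noise. q_sample below is always called with noise=cond, so the
    # forward process deterministically blends the mask with the conditioning
    # image, with sqrt(alphas_cum_prod) and sqrt(1 - alphas_cum_prod) acting
    # as the mixing weights.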

    @property
    def device(self):
        return next(self.parameters()).device

    def predict_noise_from_start(self, x_t, t, x0):
        """ Recover the added noise from the true x_t and the predicted x_0 """
        return ((extract(self.sqrt_recip_alphas_cum_prod, t, x_t.shape) * x_t - x0) /
                extract(self.sqrt_recip_m1_alphas_cum_prod, t, x_t.shape))

    def model_predictions(self, x, t, c, x_self_cond=None, clip_x_start=False):
        """ Model prediction step """

        # U-Net output
        model_output = self.model(x, t, c, x_self_cond)

        # Optionally clamp the output
        maybe_clip = partial(torch.clamp, min=-1., max=1.) if clip_x_start else identity

        # Directly predict the clean segmentation x_0
        if self.objective == 'predict_x0':
            x_start = model_output
            x_start = maybe_clip(x_start)
            predict_noise = self.predict_noise_from_start(x, t, x_start)

        else:
            raise ValueError(f'unknown objective {self.objective}')

        return ModelPrediction(predict_noise, x_start)

    @torch.no_grad()
    def p_sample(self, x, t, c, x_self_cond=None, clip_de_noised=True):
        """ Use the network to predict x_0 (and the noise) from x_t """
        batched_times = torch.full((x.shape[0],), t, device=x.device, dtype=torch.long)

        predicts = self.model_predictions(x, batched_times, c, x_self_cond)

        # Predicted x_0
        x_start = predicts.predict_x_start
        if clip_de_noised:
            x_start.clamp_(-1., 1.)

        return x_start

    @torch.no_grad()
    def p_sample_loop(self, cond):
        """ Inference: sample x_{t-1} from x_t, recursing down to x_0 (sample restoration) """

        # The conditioning image doubles as the initial "noisy" input
        x_start = None
        img = cond

        # Sampling loop with a progress bar
        for t in tqdm(reversed(range(0, self.num_time_steps)), desc='sampling time step', total=self.num_time_steps):
            # Self-conditioning input, if enabled
            self_cond = x_start if self.self_condition else None
            # U-Net estimate of the clean mask
            img = self.p_sample(img, t, cond, self_cond)
            # Re-mix the estimate to time t
            batched_times = torch.full((img.shape[0],), t, device=img.device, dtype=torch.long)
            img_xt = self.q_sample(x_start=img, t=batched_times, noise=cond)
            # Re-mix the estimate to time t - 1
            img_xt_sub = img_xt
            if t - 1 != -1:
                batched_times = torch.full((img.shape[0],), t - 1, device=img.device, dtype=torch.long)
                img_xt_sub = self.q_sample(x_start=img_xt_sub, t=batched_times, noise=cond)

            # Cold-diffusion update: x <- x - D(x0_hat, t) + D(x0_hat, t - 1)
            img = img - img_xt + img_xt_sub

        # Undo the normalization
        img = un_normalize_to_zero_to_one(img)
        return img

    @torch.no_grad()
    def p_sample_loop_ones(self, cond):
        """ Inference shortcut: predict x_0 from the final timestep in a single step """

        # The conditioning image doubles as the initial "noisy" input
        x_start = None
        img = cond

        # Self-conditioning input, if enabled
        self_cond = x_start if self.self_condition else None

        # Single U-Net prediction at the final timestep
        t = self.num_time_steps
        img = self.p_sample(img, t - 1, cond, self_cond)

        # Undo the normalization
        img = un_normalize_to_zero_to_one(img)

        return img

    @torch.no_grad()
    def sample(self, cond_img):

        # Move the conditioning image to the model device
        cond_img = cond_img.to(self.device)

        # Full iterative sampling
        sample_fn = self.p_sample_loop

        # Return the predicted segmentation
        return sample_fn(cond_img)

    @torch.no_grad()
    def sample_ones(self, cond_img):

        # Move the conditioning image to the model device
        cond_img = cond_img.to(self.device)

        # Single-step sampling
        sample_fn = self.p_sample_loop_ones

        # Return the predicted segmentation
        return sample_fn(cond_img)

    def q_sample(self, x_start, t, noise):
        """ Forward process (reparameterized): sample x_t from q(x_t | x_0) """
        return (extract(self.sqrt_alphas_cum_prod, t, x_start.shape) * x_start +
                extract(self.sqrt_one_minus_alphas_cum_prod, t, x_start.shape) * noise)

    def p_losses(self, x_start, t, cond):
        """ Loss computation """

        # Mix the mask with the conditioning image to obtain x_t
        x = self.q_sample(x_start=x_start, t=t, noise=cond)

        # With self-conditioning, 50% of the time predict x_start from the current
        # timestep and condition on it; this slows training by 25% but appears to
        # lower FID significantly
        x_self_cond = None
        if self.self_condition and random() < 0.5:
            with torch.no_grad():
                x_self_cond = self.model_predictions(x, t, cond).predict_x_start
                x_self_cond.detach_()

        # Predict and take a gradient step
        model_out = self.model(x, t, cond, x_self_cond)

        # Choose the prediction target
        if self.objective == 'predict_x0':
            target = x_start
        else:
            raise ValueError(f'unknown objective {self.objective}')

        # MSE loss
        return F.mse_loss(model_out, target)

    def forward(self, img, cond_img, epoch, epochs):
        """ Forward pass: directly returns the loss """

        # Add a channel dimension if necessary
        if img.ndim == 3:
            img = rearrange(img, 'b h w -> b 1 h w')

        if cond_img.ndim == 3:
            cond_img = rearrange(cond_img, 'b h w -> b 1 h w')

        # Get the model device
        device = self.device

        # Move the input and conditioning images to the device
        img, cond_img = img.to(device), cond_img.to(device)

        # Validate the image size and channel counts
        b, c, h, w = img.shape
        img_size = self.image_size
        img_channels, mask_channels = self.input_img_channels, self.mask_channels

        assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
        assert cond_img.shape[1] == img_channels, f'your input medical image must have {img_channels} channels'
        assert img.shape[1] == mask_channels, f'the segmented image must have {mask_channels} channels'

        # Sample timesteps; the lower bound sc rises toward num_time_steps as
        # training progresses (assumes epoch < epochs). math.sqrt replaces the
        # original np.sqrt, since numpy is never imported in this file.
        sc = int(math.sqrt(epoch / epochs) * self.num_time_steps)
        times = torch.randint(sc, self.num_time_steps, (b,), device=device).long()

        # Normalize the mask to [-1, 1]
        img = normalize_to_neg_one_to_one(img)

        # Compute the loss
        return self.p_losses(img, times, cond_img)
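
# End-to-end sketch (illustrative helper, not in the original file): build a
# tiny model, compute one training loss, and run the one-step sampler. All
# sizes here are toy values chosen for speed.
def _demo_cold_segdiffusion():
    unet = Unet(dim=16, image_size=32, mask_channels=1, input_img_channels=3,
                dim_mult=(1, 2), full_self_attn=(False, True))
    diffusion = MedSegDiff(unet, time_steps=100)
    masks = torch.rand(2, 1, 32, 32)    # ground-truth masks in [0, 1]
    images = torch.rand(2, 3, 32, 32)   # conditioning medical images
    loss = diffusion(masks, images, epoch=0, epochs=100)   # scalar training loss
    pred = diffusion.sample_ones(images)                   # (2, 1, 32, 32) masks
    return loss, pred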
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TimesXY/Cold-SegDiffusion/8898bade59a9fe9261022cfa793354b0c1fe8157/requirements.txt
--------------------------------------------------------------------------------