├── LICENSE
├── Model.py
└── README.md

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2024 NewCodeAlg

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/Model.py:
--------------------------------------------------------------------------------
import math

import torch
from torch import nn
from torch.nn import init
from torch.nn import functional as F


class Swish(nn.Module):
    """Swish activation: x * sigmoid(x)."""

    def forward(self, x):
        return x * torch.sigmoid(x)


class TimeEmbedding(nn.Module):
    """Sinusoidal timestep embedding followed by a two-layer MLP."""

    def __init__(self, T, d_model, dim):
        assert d_model % 2 == 0
        super().__init__()
        # Precompute the sinusoidal embedding table for timesteps 0..T-1.
        emb = torch.arange(0, d_model, step=2) / d_model * math.log(10000)
        emb = torch.exp(-emb)
        pos = torch.arange(T).float()
        emb = pos[:, None] * emb[None, :]
        assert list(emb.shape) == [T, d_model // 2]
        emb = torch.stack([torch.sin(emb), torch.cos(emb)], dim=-1)
        assert list(emb.shape) == [T, d_model // 2, 2]
        emb = emb.view(T, d_model)

        self.timembedding = nn.Sequential(
            nn.Embedding.from_pretrained(emb),
            nn.Linear(d_model, dim),
            Swish(),
            nn.Linear(dim, dim),
        )
        self.initialize()

    def initialize(self):
        for module in self.modules():
            if isinstance(module, nn.Linear):
                init.xavier_uniform_(module.weight)
                init.zeros_(module.bias)

    def forward(self, t):
        emb = self.timembedding(t)
        return emb


class DownSample(nn.Module):
    """Halve the spatial resolution with a stride-2 convolution."""

    def __init__(self, in_ch):
        super().__init__()
        self.main = nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1)
        self.initialize()

    def initialize(self):
        init.xavier_uniform_(self.main.weight)
        init.zeros_(self.main.bias)

    def forward(self, x, temb):
        x = self.main(x)
        return x


class UpSample(nn.Module):
    """Double the spatial resolution with nearest-neighbor upsampling plus a 3x3 convolution."""

    def __init__(self, in_ch):
        super().__init__()
        self.main = nn.Conv2d(in_ch, in_ch, 3, stride=1, padding=1)
        self.initialize()

    def initialize(self):
        init.xavier_uniform_(self.main.weight)
        init.zeros_(self.main.bias)

    def forward(self, x, temb):
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        x = self.main(x)
        return x


class AttnBlock(nn.Module):
    """Self-attention block over all spatial positions, as in the DDPM UNet."""

    def __init__(self, in_ch):
        super().__init__()
        self.group_norm = nn.GroupNorm(32, in_ch)
        self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.initialize()

    def initialize(self):
        for module in [self.proj_q, self.proj_k, self.proj_v, self.proj]:
            init.xavier_uniform_(module.weight)
            init.zeros_(module.bias)
        # Near-zero init of the output projection so the block starts close to identity.
        init.xavier_uniform_(self.proj.weight, gain=1e-5)

    def forward(self, x):
        B, C, H, W = x.shape
        h = self.group_norm(x)
        q = self.proj_q(h)
        k = self.proj_k(h)
        v = self.proj_v(h)

        # Scaled dot-product attention across the H*W spatial positions.
        q = q.permute(0, 2, 3, 1).view(B, H * W, C)
        k = k.view(B, C, H * W)
        w = torch.bmm(q, k) * (int(C) ** (-0.5))
        assert list(w.shape) == [B, H * W, H * W]
        w = F.softmax(w, dim=-1)

        v = v.permute(0, 2, 3, 1).view(B, H * W, C)
        h = torch.bmm(w, v)
        assert list(h.shape) == [B, H * W, C]
        h = h.view(B, H, W, C).permute(0, 3, 1, 2)
        h = self.proj(h)

        return x + h


class ResBlock(nn.Module):
    """Residual block with timestep conditioning and optional self-attention."""

    def __init__(self, in_ch, out_ch, tdim, dropout, attn=False):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.GroupNorm(32, in_ch),
            Swish(),
            nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
        )
        self.temb_proj = nn.Sequential(
            Swish(),
            nn.Linear(tdim, out_ch),
        )
        self.block2 = nn.Sequential(
            nn.GroupNorm(32, out_ch),
            Swish(),
            nn.Dropout(dropout),
            nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),
        )
        if in_ch != out_ch:
            self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)
        else:
            self.shortcut = nn.Identity()
        if attn:
            self.attn = AttnBlock(out_ch)
        else:
            self.attn = nn.Identity()
        self.initialize()

    def initialize(self):
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                init.xavier_uniform_(module.weight)
                init.zeros_(module.bias)
        init.xavier_uniform_(self.block2[-1].weight, gain=1e-5)

    def forward(self, x, temb):
        h = self.block1(x)
        # Inject the timestep embedding as a per-channel bias.
        h += self.temb_proj(temb)[:, :, None, None]
        h = self.block2(h)

        h = h + self.shortcut(x)
        h = self.attn(h)
        return h


class UNet(nn.Module):
    """Noise-prediction UNet for single-channel inputs."""

    def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout):
        super().__init__()
        assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound'
        tdim = ch * 4
        self.time_embedding = TimeEmbedding(T, ch, tdim)

        self.head = nn.Conv2d(1, ch, kernel_size=3, stride=1, padding=1)
        self.downblocks = nn.ModuleList()
        chs = [ch]  # channel counts recorded for the skip connections
        now_ch = ch
        for i, mult in enumerate(ch_mult):
            out_ch = ch * mult
            for _ in range(num_res_blocks):
                self.downblocks.append(ResBlock(
                    in_ch=now_ch, out_ch=out_ch, tdim=tdim,
                    dropout=dropout, attn=(i in attn)))
                now_ch = out_ch
                chs.append(now_ch)
            if i != len(ch_mult) - 1:
                self.downblocks.append(DownSample(now_ch))
                chs.append(now_ch)

        self.middleblocks = nn.ModuleList([
            ResBlock(now_ch, now_ch, tdim, dropout, attn=True),
            ResBlock(now_ch, now_ch, tdim, dropout, attn=False),
        ])

        self.upblocks = nn.ModuleList()
        for i, mult in reversed(list(enumerate(ch_mult))):
            out_ch = ch * mult
            for _ in range(num_res_blocks + 1):
                self.upblocks.append(ResBlock(
                    in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim,
                    dropout=dropout, attn=(i in attn)))
                now_ch = out_ch
            if i != 0:
                self.upblocks.append(UpSample(now_ch))
        assert len(chs) == 0

        self.tail = nn.Sequential(
            nn.GroupNorm(32, now_ch),
            Swish(),
            nn.Conv2d(now_ch, 1, 3, stride=1, padding=1)
        )
        self.initialize()

    def initialize(self):
        init.xavier_uniform_(self.head.weight)
        init.zeros_(self.head.bias)
        init.xavier_uniform_(self.tail[-1].weight, gain=1e-5)
        init.zeros_(self.tail[-1].bias)

    def forward(self, x, t):
        # Timestep embedding.
        temb = self.time_embedding(t)
        # Downsampling path; save activations for the skip connections.
        h = self.head(x)
        hs = [h]
        for layer in self.downblocks:
            h = layer(h, temb)
            hs.append(h)
        # Bottleneck.
        for layer in self.middleblocks:
            h = layer(h, temb)
        # Upsampling path; concatenate the saved skip activations.
        for layer in self.upblocks:
            if isinstance(layer, ResBlock):
                h = torch.cat([h, hs.pop()], dim=1)
            h = layer(h, temb)
        h = self.tail(h)

        assert len(hs) == 0
        return h


if __name__ == '__main__':
    batch_size = 8
    model = UNet(
        T=1000, ch=128, ch_mult=[1, 2, 2, 2], attn=[1],
        num_res_blocks=2, dropout=0.1)
    x = torch.randn(batch_size, 1, 64, 64)
    t = torch.randint(1000, (batch_size, ))
    y = model(x, t)
    print(y.shape)

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Source Code of "A Novel Data Augmentation Method Based on Denoising Diffusion Probabilistic Model for Fault Diagnosis Under Imbalanced Data" (TII 2024)

Our model code has been uploaded (`Model.py`). If you have any questions about the paper, please feel free to contact me.

## Citation

Please cite our paper as follows:

[1] X. Yang, T. Ye, X. Yuan, W. Zhu, X. Mei and F. Zhou, "A Novel Data Augmentation Method Based on Denoising Diffusion Probabilistic Model for Fault Diagnosis Under Imbalanced Data," IEEE Transactions on Industrial Informatics, vol. 20, no. 5, pp. 7820-7831, May 2024, doi: 10.1109/TII.2024.3366991.
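
## Training sketch (assumptions, not the paper's settings)

`Model.py` only defines the noise-prediction UNet. As a minimal sketch of how such a network is typically driven during DDPM training, the snippet below samples `x_t` from `q(x_t | x_0)` and regresses the added noise with an MSE loss. The linear beta schedule (1e-4 to 0.02 over T=1000 steps), the Adam learning rate, and the `train_step` helper are assumptions based on standard DDPM practice; the paper's actual hyperparameters may differ.

```python
import torch
from torch import nn

from Model import UNet

T = 1000
model = UNet(T=T, ch=128, ch_mult=[1, 2, 2, 2], attn=[1],
             num_res_blocks=2, dropout=0.1)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)  # assumed setting

# Linear beta schedule and the derived cumulative alpha terms (standard DDPM; assumed here).
betas = torch.linspace(1e-4, 0.02, T)
alphas_bar = torch.cumprod(1.0 - betas, dim=0)

def train_step(x0):
    """One DDPM training step: sample x_t from q(x_t | x_0), then predict the noise."""
    t = torch.randint(T, (x0.shape[0],))       # a random timestep per sample
    noise = torch.randn_like(x0)               # epsilon ~ N(0, I)
    a_bar = alphas_bar[t].view(-1, 1, 1, 1)
    x_t = torch.sqrt(a_bar) * x0 + torch.sqrt(1.0 - a_bar) * noise
    loss = nn.functional.mse_loss(model(x_t, t), noise)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

# Example: one step on a dummy batch of single-channel 64x64 samples.
print(train_step(torch.randn(8, 1, 64, 64)))
```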
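
## Sampling sketch (assumptions, not the paper's settings)

For generating synthetic samples with a trained model, a plain ancestral DDPM sampling loop can be used: start from Gaussian noise and denoise step by step with the UNet's noise prediction. The `sample` helper below is a sketch under the same assumed beta schedule as above, with the posterior variance set to `beta_t`; it is not taken from the paper.

```python
import torch

from Model import UNet

T = 1000
betas = torch.linspace(1e-4, 0.02, T)          # assumed schedule, matching the training sketch
alphas = 1.0 - betas
alphas_bar = torch.cumprod(alphas, dim=0)

@torch.no_grad()
def sample(model, n, size=(1, 64, 64)):
    """Ancestral DDPM sampling: start from pure noise and denoise for t = T-1 .. 0."""
    x = torch.randn(n, *size)
    for i in reversed(range(T)):
        t = torch.full((n,), i, dtype=torch.long)
        eps = model(x, t)                       # predicted noise at step i
        coef = (1.0 - alphas[i]) / torch.sqrt(1.0 - alphas_bar[i])
        mean = (x - coef * eps) / torch.sqrt(alphas[i])
        if i > 0:
            x = mean + torch.sqrt(betas[i]) * torch.randn_like(x)
        else:
            x = mean                            # no noise is added at the final step
    return x

model = UNet(T=T, ch=128, ch_mult=[1, 2, 2, 2], attn=[1],
             num_res_blocks=2, dropout=0.1)
model.eval()
print(sample(model, n=4).shape)                 # torch.Size([4, 1, 64, 64])
```
--------------------------------------------------------------------------------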