├── nam.py ├── EWSNet.py ├── README.md ├── weights.py └── thresholds.py /nam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from Elaplacenet5 import Net 4 | from gdatasave import train_loader 5 | import pandas as pd 6 | import seaborn as sns 7 | from sklearn.svm import SVC 8 | import numpy as np 9 | from tqdm import tqdm 10 | import matplotlib.pyplot as plt 11 | torch.cuda.manual_seed(1) 12 | torch.manual_seed(2) 13 | 14 | sampling_rate = 50000 15 | window_size = 1024 16 | n_class = 4 17 | channels = [64] 18 | 19 | 20 | def normalize_freq(fam): 21 | max_fam, _ = torch.max(fam, dim=-1, keepdim=True) 22 | return torch.div(fam, max_fam) 23 | # This is from P. Welch's method to compute power spectrum, check [3] 24 | def power_spectrum(t_freq): 25 | result = torch.abs(torch.fft.rfft(t_freq))**2 26 | return result / torch.mean(result, dim=2, keepdim=True) 27 | 28 | 29 | def freq_activation_map(model, input, width, channels, target_label): 30 | ''' 31 | Param: 32 | model : Neural Network Object 33 | input : timeseries data 34 | width : length of power_spectrum(input) 35 | channels : # last channel of the model 36 | ''' 37 | fam = torch.zeros(input.shape[0], 1, width) 38 | if torch.cuda.is_available(): 39 | fam = fam.cuda() 40 | 41 | with torch.no_grad(): 42 | freq, labels = model(torch.reshape(input, (-1, 1, input.shape[-1]))) 43 | freq = freq.unsqueeze(2) 44 | labels = labels.unsqueeze(2) 45 | labels = torch.argmax(F.softmax(labels, dim=1), dim=1) 46 | labels = torch.where(labels == target_label, 1., 0.) 
47 | labels = torch.unsqueeze(torch.reshape(labels, [-1, 1]).repeat(1, width), 1) 48 | for c in range(channels): 49 | sp = freq[:, c, :, :] 50 | if torch.cuda.is_available(): 51 | sp = sp.cuda() 52 | 53 | sp = power_spectrum(sp) 54 | sp = sp * labels 55 | 56 | if model.p4[0].weight[target_label, c] > 0: 57 | fam += model.p4[0].weight[target_label, c] * sp 58 | 59 | return torch.squeeze(torch.sum(fam, dim=0)), torch.sum(labels[:, 0, 0], dim=0) 60 | 61 | 62 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 63 | #summary(Net1, input_size=(1, 1024)) # 输出模型具有的参数 64 | model = Net().to(device) 65 | model.load_state_dict(torch.load('./model1001.pt')) 66 | torch.no_grad() 67 | # for img, label in test_loader: 68 | ''' 69 | 测试集 70 | ''' 71 | # Test.Data = Test.Data.type(torch.FloatTensor).to(device) 72 | # label = torch.from_numpy(Test.Label) 73 | # outputs = model(Test.Data) 74 | # outputs = torch.squeeze(outputs).float() 75 | # encoded_data = outputs.cpu().detach().numpy() 76 | # clf = SVC(C=8.94, gamma=3.26) 77 | ''' 78 | 训练集 79 | ''' 80 | # Train.Data = Train.Data.type(torch.FloatTensor).to(device) 81 | # labels = torch.from_numpy(Train.Label) 82 | # outputs = model(Train.Data) 83 | # outputs = torch.squeeze(outputs).float() 84 | # encoded_data = outputs.cpu().detach().numpy() 85 | 86 | # clf = SVC(C=10.82, kernel='rbf', gamma=1.19, decision_function_shape='ovr') # rbf高斯基核函数 87 | # clf.fit(encoded_data, label.cpu().numpy()) 88 | # fit_score = clf.score(encoded_data, label.cpu().numpy()) 89 | # print(fit_score) 90 | 91 | freq_intervals = np.fft.rfftfreq(window_size, d=1/sampling_rate) 92 | total_fam = torch.zeros(n_class, len(freq_intervals)) 93 | total_len = torch.zeros(n_class, 1) 94 | if torch.cuda.is_available(): 95 | total_fam = total_fam.cuda() 96 | total_len = total_len.cuda() 97 | 98 | for batch in tqdm(train_loader): 99 | x, y = batch 100 | if torch.cuda.is_available(): 101 | x = x.cuda() 102 | for n in range(n_class): 103 | tmp_fam, cnt = 
freq_activation_map(model, x.float(), len(freq_intervals), channels[-1], n) 104 | total_fam[n, :] += tmp_fam 105 | total_len[n] += cnt 106 | 107 | total_fam /= total_len 108 | print(total_fam.size()) 109 | # _, predicted = outputs.max(1) 110 | # print(predicted, Test.Label) 111 | # num_correct = (predicted.cpu() == Test.Label).sum().item() 112 | # acc = num_correct / outputs.cpu().shape[0] 113 | # eval_acc = 0 114 | # eval_acc += fit_score 115 | # print(eval_acc) 116 | rtotal_fam = normalize_freq(total_fam).cpu().detach() 117 | columns_ = np.floor(freq_intervals).astype(int) 118 | plt.plot(columns_, rtotal_fam[0, :]) 119 | 120 | result = pd.DataFrame(rtotal_fam.numpy(), 121 | index = ['0', '1', '2', '3'], 122 | columns = columns_, 123 | ) 124 | 125 | new_index = ['3', '2', '1', '0'] 126 | print(result.values.shape) 127 | new_index.reverse() 128 | 129 | result = result.reindex(new_index) 130 | 131 | sns.heatmap(result, cmap='viridis') 132 | plt.show() 133 | 134 | -------------------------------------------------------------------------------- /EWSNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | # from wmodelsii8 import Sin_fast as fast 4 | # from wmodelsii3 import Laplace_fast as fast 5 | from wmodelsii8 import Laplace_fastv2 as fast 6 | # from wsinc import SincConv_fast as fast 7 | from Shrinkage import Shrinkagev3ppp2 as sage 8 | 9 | class Mish1(nn.Module): 10 | def __init__(self): 11 | super(Mish1, self).__init__() 12 | self.mish = nn.ReLU(inplace=True) 13 | 14 | def forward(self, x): 15 | 16 | return self.mish(x) 17 | 18 | 19 | class Net(nn.Module): 20 | def __init__(self): 21 | super(Net, self).__init__() #85,42,70 #63,31,75 22 | self.p1_0 = nn.Sequential( # nn.Conv1d(1, 50, kernel_size=18, stride=2), 23 | # fast(out_channels=64, kernel_size=250, stride=1), 24 | # fast1(out_channels=70, kernel_size=84, stride=1), 25 | nn.Conv1d(1, 64, kernel_size=250, stride=1, bias=True), 26 
| nn.BatchNorm1d(64), 27 | Mish1() 28 | ) 29 | self.p1_1 = nn.Sequential(nn.Conv1d(64, 16, kernel_size=18, stride=2, bias=True), 30 | # fast(out_channels=50, kernel_size=18, stride=2), 31 | nn.BatchNorm1d(16), 32 | Mish1() 33 | ) 34 | self.p1_2 = nn.Sequential(nn.Conv1d(16, 10, kernel_size=10, stride=2, bias=True), 35 | nn.BatchNorm1d(10), 36 | Mish1() 37 | ) 38 | self.p1_3 = nn.MaxPool1d(kernel_size=2) 39 | self.p2_1 = nn.Sequential(nn.Conv1d(64, 32, kernel_size=6, stride=1, bias=True), 40 | # fast(out_channels=50, kernel_size=6, stride=1), 41 | nn.BatchNorm1d(32), 42 | Mish1() 43 | ) 44 | self.p2_2 = nn.Sequential(nn.Conv1d(32, 16, kernel_size=6, stride=1, bias=True), 45 | nn.BatchNorm1d(16), 46 | Mish1() 47 | ) 48 | self.p2_3 = nn.MaxPool1d(kernel_size=2) 49 | self.p2_4 = nn.Sequential(nn.Conv1d(16, 10, kernel_size=6, stride=1, bias=True), 50 | nn.BatchNorm1d(10), 51 | Mish1() 52 | ) 53 | self.p2_5 = nn.Sequential(nn.Conv1d(10, 10, kernel_size=8, stride=2, bias=True), 54 | # nn.Conv1d(10, 10, kernel_size=6, stride=2), 55 | nn.BatchNorm1d(10), 56 | Mish1() 57 | ) # PRelu 58 | self.p2_6 = nn.MaxPool1d(kernel_size=2) 59 | self.p3_0 = sage(channel=64, gap_size=1) 60 | self.p3_1 = nn.Sequential(nn.Conv1d(64, 10, kernel_size=43, stride=4, bias=True), 61 | nn.BatchNorm1d(10), 62 | Mish1() 63 | ) 64 | self.p3_2 = nn.MaxPool1d(kernel_size=2) 65 | self.p3_3 = nn.Sequential(nn.AdaptiveAvgPool1d(1)) 66 | self.p4 = nn.Sequential(nn.Linear(10, 4)) 67 | self._initialize_weights() 68 | 69 | def _initialize_weights(self): 70 | for m in self.modules(): 71 | if isinstance(m, nn.Conv1d): 72 | if m.kernel_size == (250,): 73 | m.weight.data = fast(out_channels=64, kernel_size=250).forward() 74 | nn.init.constant_(m.bias.data, 0.0) 75 | else: 76 | nn.init.kaiming_normal_(m.weight.data) 77 | nn.init.constant_(m.bias.data, 0.0) 78 | elif isinstance(m, nn.BatchNorm1d): 79 | m.weight.data.fill_(1) 80 | m.bias.data.fill_(0) 81 | elif isinstance(m, nn.Linear): 82 | m.weight.data.normal_(0, 
0.01) 83 | if m.bias is not None: 84 | m.bias.data.fill_(1) 85 | 86 | 87 | 88 | def forward(self, x): 89 | x = self.p1_0(x) 90 | p1 = self.p1_3(self.p1_2(self.p1_1(x))) 91 | p2 = self.p2_6(self.p2_5(self.p2_4(self.p2_3(self.p2_2(self.p2_1(x)))))) 92 | x = self.p3_2(self.p3_1(x + self.p3_0(x))) 93 | x = torch.add(x, torch.add(p1, p2)) 94 | x = self.p3_3(x).squeeze() 95 | x = self.p4(x) 96 | return x 97 | 98 | if __name__ == '__main__': 99 | import numpy as np 100 | import matplotlib.pyplot as plt 101 | import math 102 | 103 | # input = torch.randn(2, 1, 1024).cuda() 104 | # model = Net().cuda() 105 | # # for param in model.parameters(): 106 | # # print(type(param.data), param.size()) 107 | # print("# parameters:", sum(param.numel() for param in model.parameters())) 108 | # output = model(input) 109 | # print(model) 110 | ################################## 111 | # model = Net().cuda() 112 | # for name, parameters in model.named_parameters(): 113 | # print(name, ':', parameters.size()) 114 | # model.load_state_dict(torch.load('H:\EWSNet\pre\99.53.pt'), strict=False) 115 | # weight_t = model.state_dict()['p3_0.a'].cpu().detach().numpy() 116 | # print(weight_t) 117 | ################################## 118 | # model = Net().cuda() 119 | # weight_t = model.state_dict()['p1_0.0.weight'].cpu().detach().numpy() 120 | # y1 = weight_t[60, :, :].squeeze() 121 | # # model.load_state_dict(torch.load('H:\EWSNet\pre\model.pt'), strict=False) 122 | # # weight_t = model.state_dict()['p1_0.0.weight'].cpu().detach().numpy() 123 | # # y2 = weight_t[20, :, :].squeeze() 124 | # x = np.linspace(0, 250, 250) 125 | # # plt.plot(x, y1, label='before') 126 | # plt.plot(x, y1, label='Xavier_normal') 127 | # # plt.plot(x, y2, label='after') 128 | # plt.legend() 129 | # plt.savefig('wahaha.tiff', format='tiff', dpi=600) 130 | # plt.show() 131 | 132 | -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | # 🔥 Physics-informed Interpretable Wavelet Weight Initialization and Balanced Dynamic Adaptive Threshold for Intelligent Fault Diagnosis of Rolling Bearings 2 | 3 | The PyTorch implementation of the paper [Physics-informed interpretable wavelet weight initialization and balanced dynamic adaptive threshold for intelligent fault diagnosis of rolling bearings](https://doi.org/10.1016/j.jmsy.2023.08.014) 4 | 5 | 6 | # Updating! 7 | 8 | 9 | [NEWS!] This paper has been accepted by **[Journal of Manufacturing Systems](https://www.sciencedirect.com/journal/journal-of-manufacturing-systems/vol/70/suppl/C)**! 10 | 11 | 12 | ## Brief introduction 13 | Intelligent fault diagnosis of rolling bearings using deep learning-based methods has made unprecedented progress. However, there is still little research on weight initialization and the threshold setting for noise reduction. An innovative deep triple-stream network called EWSNet is proposed, which presents a wavelet weight initialization method and a balanced dynamic adaptive threshold algorithm. Initially, an enhanced wavelet basis function is designed, in which a scale smoothing factor is defined to acquire more rational wavelet scales. Next, a plug-and-play wavelet weight initialization for deep neural networks is proposed, which utilizes physics-informed wavelet prior knowledge and showcases stronger applicability. Furthermore, a balanced dynamic adaptive threshold is established to enhance the noise-resistant robustness of the model. Finally, normalization activation mapping is devised to reveal the effectiveness of Z-score from a visual perspective rather than experimental results. The validity and reliability of EWSNet are demonstrated through four data sets under the conditions of constant and fluctuating speeds.
14 | 15 | ## Highlights 16 | 17 | - **A novel deep triple-stream network called EWSNet is proposed for fault diagnosis of rolling bearings under the condition of constant or sharp speed variation.** 18 | - **An enhanced wavelet convolution kernel is designed to improve the trainability, in which a scale smoothing factor is employed to acquire rational wavelet scales.** 19 | - **A plug-and-play and physics-informed wavelet weight initialization is proposed to construct an interpretable convolution kernel, which makes the diagnosis interpretable.** 20 | - **A balanced dynamic adaptive threshold is specially devised to improve the anti-noise robustness of the model.** 21 | - **Normalization activation mapping is designed to visually reveal that Z-score can enhance the frequency-domain information of raw signals.** 22 | 23 | 24 | ## Paper 25 | **Physics-informed Interpretable Wavelet Weight Initialization and Balanced Dynamic Adaptive Threshold for Intelligent Fault Diagnosis of Rolling Bearings** 26 | 27 | Chao He<sup>a,b</sup>, **Hongmei Shi<sup>a,b,\*</sup>**, Jin Si<sup>c</sup> and Jianbo Li<sup>a,b</sup> 28 | 29 | <sup>a</sup>School of Mechanical, Electronic and Control Engineering, Beijing Jiaotong University, Beijing 100044, China 30 | 31 | <sup>b</sup>Collaborative Innovation Center of Railway Traffic Safety, Beijing 100044, China 32 | 33 | <sup>c</sup>Key Laboratory of Information System and Technology, Beijing Institute of Control and Electronic Technology, Beijing 100038, China 34 | 35 | **[Journal of Manufacturing Systems](https://www.sciencedirect.com/journal/journal-of-manufacturing-systems/vol/70/suppl/C)** 36 | 37 | ## EWSNet 38 | ![image](https://github.com/liguge/EWSNet_new/assets/19371493/8296a3a9-ff68-4857-8e59-f7f828245101) 39 | 40 | ## Wavelet initialization 41 | 42 | ![image](https://user-images.githubusercontent.com/19371493/180359513-b6fd1fb4-4c63-47ad-8d98-b8030d2ca529.png) 43 | 44 | ## Balanced Dynamic Adaptive Thresholding 45 | 46 |
![image](https://github.com/liguge/EWSNet_new/assets/19371493/6df0c396-841c-4c3c-b098-69765de18bf5) 47 | 48 | ![image](https://github.com/liguge/EWSNet_new/assets/19371493/c94d06ef-7604-43d8-886c-b592acb003ab) 49 | 50 | **where $\alpha$ and $\eta$ are differentiable parameters with $\alpha \in (0,1)$. When $\alpha = 0$ or $\alpha = 1$, the thresholding function degenerates into the hard threshold or the soft threshold, respectively; thus $\alpha$ can be adjusted appropriately to bring $y$ closer to the genuine wavelet coefficient.** 51 | 52 | ## Normalization Activation Mapping 53 | 54 | **Data normalization can accelerate convergence, and Z-score normalization also improves the accuracy of CNNs. Rather than demonstrating this only through experimental results, we observe that Z-score enhances the frequency-domain information of signals, so that CNNs can learn these features better.** 55 | 56 | **FAM illustrates the frequency-domain information by utilizing the weights of the classification layer and the extracted features, but it cannot reveal the influence of normalization methods. Therefore, in NAM the weight of the correct label is set to $1.0$ and the features are the signals processed by each normalization method, so NAM can visualize which normalization method retains more frequency-domain information.** 57 | 58 | ![image](https://github.com/liguge/EWSNet_new/assets/19371493/ddbc692d-74c6-4764-a775-260c44837473) 59 | 60 | where $l_{real}$ is the real label and $l_{target}$ is the tested label.
61 | 62 | ## Example 63 | 64 | 65 | 66 | ```python 67 | class CNNNet(nn.Module): 68 | 69 | def __init__(self, init_weights=False): 70 | super(CNNNet, self).__init__() 71 | self.conv1 = nn.Conv1d(1, 64, 60, padding=2) 72 | self.conv2 = nn.Conv1d(32, 32, 3, padding=1) 73 | self.conv3 = nn.Conv1d(16, 48, 3, padding=1) 74 | #self.sage = sage(channel=16, gap_size=1) 75 | self.conv4 = nn.Conv1d(24, 64, 3, padding=1) 76 | self.pool = nn.MaxPool2d(2) 77 | self.fc1 = nn.Linear(32*60, 512) 78 | self.fc2 = nn.Linear(512, 4) 79 | if init_weights: 80 | self._initialize_weights() 81 | 82 | 83 | def _initialize_weights(self): 84 | for m in self.modules(): 85 | if isinstance(m, nn.Conv1d): 86 | if m.kernel_size == (60,): 87 | m.weight.data = fast(out_channels=64, kernel_size=60, eps=0.2, mode='sigmoid').forward() 88 | nn.init.constant_(m.bias.data, 0.0) 89 | 90 | def forward(self, x): 91 | x = self.pool(F.relu(self.conv1(x))) 92 | x = self.pool(F.relu(self.conv2(x))) 93 | #x = x + self.sage(x) 94 | x = self.pool(F.relu(self.conv3(x))) 95 | x = self.pool(F.relu(self.conv4(x))) 96 | x = x.view(x.shape[0], -1) 97 | x = F.relu(self.fc1(x)) 98 | x = self.fc2(x) 99 | return x 100 | ``` 101 | 102 | 103 | 104 | ## Citation 105 | 106 | ```bibtex 107 | @article{HE, 108 | title = {Physics-informed interpretable wavelet weight initialization and balanced dynamic adaptive threshold for intelligent fault diagnosis of rolling bearings}, 109 | journal = {Journal of Manufacturing Systems}, 110 | volume = {70}, 111 | pages = {579--592}, 112 | year = {2023}, 113 | issn = {1878-6642}, 114 | doi = {10.1016/j.jmsy.2023.08.014}, 115 | author = {Chao He and Hongmei Shi and Jin Si and Jianbo Li} } 116 | ``` 117 | 118 | C. He, H. Shi, J. Si, J. Li, Physics-informed interpretable wavelet weight initialization and balanced dynamic adaptive threshold for intelligent fault diagnosis of rolling bearings, Journal of Manufacturing Systems 70 (2023) 579-592, https://doi.org/10.1016/j.jmsy.2023.08.014.
119 | 120 | 121 | 122 | ## Acknowledgements 123 | The authors are grateful for the support of the Fundamental Research Funds for the Central Universities (Science and Technology Leading Talent Team Project) (2022JBXT005), and the National Natural Science Foundation of China (No. 52272429). 124 | 125 | ## References 126 | 127 | - von Rueden L, Mayer S, Beckh K, Georgiev B, Giesselbach S, Heese R, et al. Informed Machine Learning – A Taxonomy and Survey of Integrating Prior Knowledge into Learning Systems. IEEE Trans Knowl Data Eng 2023;35(1):614–633. https://doi.org/10.1109/TKDE.2021.3079836 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | - Vollert S, Atzmueller M, Theissler A. Interpretable Machine Learning: A brief survey from the predictive maintenance perspective. In: 26th IEEE International Conference on Emerging Technologies and Factory Automation, ETFA 2021, Vasteras, Sweden, September 7-10, 2021 IEEE; 2021. p. 1–8. https://doi.org/10.1109/ETFA45728.2021.9613467 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | - Li T, Zhao Z, Sun C, Cheng L, Chen X, Yan R, et al. WaveletKernelNet: An Interpretable Deep Neural Network for Industrial Intelligent Diagnosis. IEEE Trans Syst Man Cybern Syst 2022;52(4):2302–2312. https://doi.org/10.1109/TSMC.2020.3048950 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | - Zhao M, Zhong S, Fu X, Tang B, Pecht M. Deep Residual Shrinkage Networks for Fault Diagnosis. IEEE Trans Industr Inform 2020;16(7):4681–4690. https://doi.org/10.1109/TII.2019.2943898 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | - Kim MS, Yun JP, Park P. An Explainable Neural Network for Fault Diagnosis With a Frequency Activation Map. IEEE Access 2021;9:98962–98972.
https://doi.org/10.1109/ACCESS.2021.3095565 179 | 180 | 181 | 182 | 183 | 184 | 185 | ## Contact 186 | 187 | - **Chao He** 188 | - **chaohe#bjtu.edu.cn (please replace # by @)** 189 | 190 | ​ 191 | -------------------------------------------------------------------------------- /weights.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from math import pi 4 | import torch.nn.functional as F 5 | 6 | 7 | ''' 8 | Laplace小波卷积核初始化 9 | ''' 10 | def Laplace(p): 11 | 12 | w = 2 * pi * 80 13 | q = torch.tensor(1 - pow(0.03, 2)) 14 | return 0.08 * torch.exp(((-0.03 / (torch.sqrt(q))) * (w * (p - 0.1)))) * (torch.sin(w * (p - 0.1))) 15 | 16 | class Laplace_fast(nn.Module): 17 | def __init__(self, out_channels, kernel_size, stride=2): 18 | super(Laplace_fast, self).__init__() 19 | self.out_channels = out_channels 20 | self.kernel_size = kernel_size 21 | self.stride = stride 22 | self.a_ = nn.Parameter(torch.linspace(1, 100, out_channels).view(-1, 1)) 23 | self.b_ = nn.Parameter(torch.linspace(0, 100, out_channels).view(-1, 1)) 24 | 25 | def forward(self, waveforms): 26 | time_disc = torch.linspace(0, 1, steps=int(self.kernel_size)) 27 | p1 = (time_disc.cuda() - self.b_.cuda()) / (self.a_.cuda()) 28 | laplace_filter = Laplace(p1) 29 | filters = laplace_filter.view(self.out_channels, 1, self.kernel_size).cuda() 30 | return F.conv1d(waveforms, filters, stride=self.stride, padding=0) 31 | 32 | 33 | class Laplace_fastv2: 34 | 35 | def __init__(self, out_channels, kernel_size, eps=-0.3): 36 | super(Laplace_fastv2, self).__init__() 37 | self.out_channels = out_channels 38 | self.kernel_size = kernel_size 39 | self.eps = eps 40 | self.a_ = torch.linspace(0, out_channels, out_channels).view(-1, 1) #out_channels-1 通道是整数 41 | self.b_ = torch.linspace(0, out_channels, out_channels).view(-1, 1) 42 | self.time_disc = torch.linspace(0, self.kernel_size - 1, steps=int(self.kernel_size)) 43 | 44 | def 
forward(self): 45 | p1 = (self.time_disc - self.b_) / (self.a_ - self.eps) 46 | filter = Laplace(p1).view(self.out_channels, 1, self.kernel_size) # (70,1,85) 47 | return filter 48 | 49 | class Laplace_fastv21(nn.Module): 50 | 51 | def __init__(self, out_channels, kernel_size, eps=0.3): 52 | super(Laplace_fastv21, self).__init__() 53 | self.out_channels = out_channels 54 | self.kernel_size = kernel_size 55 | self.eps = eps 56 | self.a_ = torch.linspace(0, self.out_channels, self.out_channels).view(-1, 1) 57 | self.b_ = torch.linspace(0, self.out_channels, self.out_channels).view(-1, 1) 58 | self.time_disc = torch.linspace(0, self.kernel_size-1, steps=int(self.kernel_size)) 59 | 60 | def forward(self): 61 | p1 = (self.time_disc - self.b_) / (self.a_ + self.eps) 62 | filter = Laplace(p1).view(self.out_channels, 1, self.kernel_size) # (70,1,85) 63 | return filter 64 | # def Laplace1(p): 65 | # # m = 1000 66 | # # ep = 0.03 67 | # # # tal = 0.1 68 | # # f = 80 69 | # w = 2 * pi * 80 70 | # # A = 0.08 71 | # q = torch.tensor(1 - pow(0.03, 2)) 72 | # #return (0.08 * torch.exp(((-0.03 / (torch.sqrt(q))) * (w * (p - 0.1)))) * (torch.sin(w * (p - 0.1)))) 73 | # # y = 0.08 * torch.exp(((-0.03 / (torch.sqrt(q))) * (w * (p - 0.1))).sigmoid()) * (torch.sin(w * (p - 0.1))) # 99.82% 74 | # return 0.08 * torch.exp(((-0.03 / (torch.sqrt(q))) * (w * (p - 0.1)))) * (torch.sin(w * (p - 0.1))) 75 | 76 | class Laplace_fastv3(nn.Module): 77 | 78 | def __init__(self, out_channels, kernel_size): 79 | super(Laplace_fastv3, self).__init__() 80 | self.out_channels = out_channels 81 | self.kernel_size = kernel_size 82 | self.a_ = torch.linspace(1, out_channels, out_channels).view(-1, 1) 83 | self.b_ = torch.linspace(0, out_channels, out_channels).view(-1, 1) 84 | self.time_disc = torch.linspace(0, 1, steps=int(self.kernel_size)) 85 | 86 | def forward(self): 87 | p1 = (self.time_disc - self.b_) / self.a_ 88 | filter = Laplace(p1).view(self.out_channels, 1, self.kernel_size) # (70,1,85) 89 | 
return filter 90 | 91 | 92 | ''' 93 | 小波核初始化 94 | ''' 95 | 96 | 97 | def Morlet(p, c): 98 | 99 | return c * torch.exp((-torch.pow(p, 2) / 2)) * torch.cos(5 * p) 100 | 101 | 102 | 103 | class Morlet_fast(nn.Module): 104 | 105 | def __init__(self, out_channels, kernel_size, eps=0.3): 106 | super(Morlet_fast, self).__init__() 107 | if kernel_size % 2 != 0: 108 | kernel_size = kernel_size - 1 109 | self.out_channels = out_channels 110 | self.kernel_size = kernel_size 111 | self.eps = eps 112 | self.a_ = torch.linspace(0, out_channels, out_channels).view(-1, 1) 113 | self.b_ = torch.linspace(0, out_channels, out_channels).view(-1, 1) 114 | self.time_disc_right = torch.linspace(0, (self.kernel_size / 2) - 1, steps=int((self.kernel_size / 2))) 115 | self.time_disc_left = torch.linspace(-(self.kernel_size / 2) + 1, -1, steps=int((self.kernel_size / 2))) 116 | 117 | def forward(self): 118 | p1 = (self.time_disc_right - self.b_) / (self.a_ + self.eps) # 119 | p2 = (self.time_disc_left - self.b_) / (self.a_ + self.eps) # 120 | C = (pow(pi, 0.25)) / torch.sqrt(self.a_ + 0.01) ##一个值得探讨的点 1e-3 # D = C / self.a_.cuda() 121 | Morlet_right = Morlet(p1, C) 122 | Morlet_left = Morlet(p2, C) 123 | filter = torch.cat([Morlet_left, Morlet_right], dim=1).view(self.out_channels, 1, self.kernel_size) 124 | return filter 125 | 126 | 127 | ''' 128 | Mexh小波卷积核 129 | ''' 130 | 131 | 132 | def Mexh(p): 133 | # p = 0.04 * p # 将时间转化为在[-5,5]这个区间内 134 | # y = (2 / pow(3, 0.5) * (pow(pi, -0.25))) * (1 - torch.pow(p, 2)) * torch.exp((-torch.pow(p, 2) / 2)) 135 | 136 | return (2/pi)*((2 / pow(3, 0.5) * (pow(pi, -0.25))) * (1 - torch.pow(p, 2)) * torch.exp((-torch.pow(p, 2) / 2))).atan() 137 | 138 | class Mexh_fast(nn.Module): 139 | 140 | def __init__(self, out_channels, kernel_size, eps=0.3): 141 | super(Mexh_fast, self).__init__() 142 | if kernel_size % 2 != 0: 143 | kernel_size = kernel_size - 1 144 | self.out_channels = out_channels 145 | self.kernel_size = kernel_size 146 | self.eps = eps 147 | 
self.a_ = torch.linspace(0, out_channels, out_channels).view(-1, 1) 148 | self.b_ = torch.linspace(0, out_channels, out_channels).view(-1, 1) 149 | self.time_disc_right = torch.linspace(0, (self.kernel_size / 2) - 1, steps=int((self.kernel_size / 2))) 150 | self.time_disc_left = torch.linspace(-(self.kernel_size / 2) + 1, -1, steps=int((self.kernel_size / 2))) 151 | 152 | def forward(self): 153 | p1 = (self.time_disc_right - self.b_) / (self.a_ + self.eps) # 154 | p2 = (self.time_disc_left - self.b_) / (self.a_ + self.eps) # 155 | Mexh_right = Mexh(p1) 156 | Mexh_left = Mexh(p2) 157 | filter = torch.cat([Mexh_left, Mexh_right], dim=1).view(self.out_channels, 1, self.kernel_size) # 40x1x250 158 | return filter 159 | 160 | 161 | ''' 162 | Gaussian小波卷积核 163 | ''' 164 | 165 | 166 | def Gaussian(p): 167 | # y = D * torch.exp(-torch.pow(p, 2)) 168 | # F0 = (2./pi)**(1./4.) * torch.exp(-torch.pow(p, 2)) 169 | # y = -2 / (3 ** (1 / 2)) * (-1 + 2 * p ** 2) * F0 170 | # y = (2./pi)**(1./4.) * torch.exp(-torch.pow(p, 2)) 171 | # y = -2 / (3 ** (1 / 2)) * (-1 + 2 * p ** 2) * y 172 | # y = -((1 / (pow(2 * pi, 0.5))) * p * torch.exp((-torch.pow(p, 2)) / 2)) 173 | return -((1 / (pow(2 * pi, 0.5))) * p * (torch.exp(((-torch.pow(p, 2)) / 2)))) 174 | 175 | 176 | 177 | 178 | class Gaussian_fast(nn.Module): 179 | 180 | def __init__(self, out_channels, kernel_size, eps=0.3): 181 | super(Gaussian_fast, self).__init__() 182 | if kernel_size % 2 != 0: 183 | kernel_size = kernel_size - 1 184 | self.out_channels = out_channels 185 | self.kernel_size = kernel_size 186 | self.eps = eps 187 | self.a_ = torch.linspace(0, out_channels, out_channels).view(-1, 1) 188 | self.b_ = torch.linspace(0, out_channels, out_channels).view(-1, 1) 189 | self.time_disc_right = torch.linspace(0, (self.kernel_size / 2) - 1, steps=int((self.kernel_size / 2))) 190 | self.time_disc_left = torch.linspace(-(self.kernel_size / 2) + 1, -1, steps=int((self.kernel_size / 2))) 191 | 192 | def forward(self): 193 | p1 = 
(self.time_disc_right - self.b_) / (self.a_ + self.eps) # 194 | p2 = (self.time_disc_left - self.b_) / (self.a_ + self.eps) # 195 | Gaussian_right = Gaussian(p1) 196 | Gaussian_left = Gaussian(p2) 197 | filter = torch.cat([Gaussian_left, Gaussian_right], dim=1).view(self.out_channels, 1, 198 | self.kernel_size) # 40x1x250 199 | return filter 200 | 201 | 202 | def Shannon(p): 203 | # y = (torch.sin(2 * pi * (p - 0.5)) - torch.sin(pi * (p - 0.5))) / (pi * (p - 0.5)) 204 | return (torch.sin(2 * pi * (p - 0.5)) - torch.sin(pi * (p - 0.5))) / (pi * (p - 0.5)) 205 | 206 | 207 | class Shannon_fast(nn.Module): 208 | 209 | def __init__(self, out_channels, kernel_size, eps=0.3): 210 | super(Shannon_fast, self).__init__() 211 | if kernel_size % 2 != 0: 212 | kernel_size = kernel_size - 1 213 | self.out_channels = out_channels 214 | self.kernel_size = kernel_size 215 | self.eps = eps 216 | self.a_ = torch.linspace(0, out_channels, out_channels).view(-1, 1) 217 | self.b_ = torch.linspace(0, out_channels, out_channels).view(-1, 1) 218 | self.time_disc_right = torch.linspace(0, (self.kernel_size / 2) - 1, steps=int((self.kernel_size / 2))) 219 | self.time_disc_left = torch.linspace(-(self.kernel_size / 2) + 1, -1, steps=int((self.kernel_size / 2))) 220 | # self.time_disc_right = torch.linspace(0, 1, steps=int((self.kernel_size / 2))) 221 | # self.time_disc_left = torch.linspace(-1, 0, steps=int((self.kernel_size / 2))) 222 | 223 | def forward(self): 224 | p1 = (self.time_disc_right - self.b_) / (self.a_ + self.eps) # 225 | p2 = (self.time_disc_left - self.b_) / (self.a_ + self.eps) # 226 | Shannon_right = Shannon(p1) 227 | Shannon_left = Shannon(p2) 228 | filter = torch.cat([Shannon_left, Shannon_right], dim=1).view(self.out_channels, 1, 229 | self.kernel_size) # 40x1x250 230 | return filter 231 | 232 | def Sin(p): 233 | # y = (torch.sin(2 * pi * (p - 0.5)) - torch.sin(pi * (p - 0.5))) / (pi * (p - 0.5)) 234 | return torch.sin(p) / p 235 | 236 | 237 | class Sin_fast(nn.Module): 
238 | 239 | def __init__(self, out_channels, kernel_size, eps=0.3): 240 | super(Sin_fast, self).__init__() 241 | if kernel_size % 2 != 0: 242 | kernel_size = kernel_size - 1 243 | self.out_channels = out_channels 244 | self.kernel_size = kernel_size 245 | self.eps = eps 246 | self.a_ = torch.linspace(0, out_channels, out_channels).view(-1, 1) 247 | self.b_ = torch.linspace(0, out_channels, out_channels).view(-1, 1) 248 | self.time_disc_right = torch.linspace(0, (self.kernel_size / 2) - 1, steps=int((self.kernel_size / 2))) 249 | self.time_disc_left = torch.linspace(-(self.kernel_size / 2) + 1, -1, steps=int((self.kernel_size / 2))) 250 | # self.time_disc_right = torch.linspace(0, 1, steps=int((self.kernel_size / 2))) 251 | # self.time_disc_left = torch.linspace(-1, 0, steps=int((self.kernel_size / 2))) 252 | 253 | def forward(self): 254 | p1 = (self.time_disc_right - self.b_) / (self.a_ + self.eps) # 255 | p2 = (self.time_disc_left - self.b_) / (self.a_ + self.eps) # 256 | Shannon_right = Sin(p1) 257 | Shannon_left = Sin(p2) 258 | filter = torch.cat([Shannon_left, Shannon_right], dim=1).view(self.out_channels, 1, 259 | self.kernel_size) # 40x1x250 260 | return filter 261 | 262 | 263 | if __name__ == '__main__': 264 | import numpy as np 265 | import matplotlib.pyplot as plt 266 | 267 | input = torch.randn(2, 1, 1024).cuda() 268 | weight = Laplace_fastv2(out_channels=64, kernel_size=250).forward().cuda() 269 | weight_t = weight.cpu().detach().numpy() 270 | y = weight_t[20, :, :].squeeze() 271 | x = np.linspace(0, 250, 250) 272 | plt.plot(x, y) 273 | plt.savefig('wahaha.tiff', format='tiff', dpi=600) 274 | plt.show() 275 | -------------------------------------------------------------------------------- /thresholds.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class Shrinkagev2(nn.Module): 5 | def __init__(self, channel, gap_size): 6 | super(Shrinkagev2, self).__init__() 7 | 
self.gap = nn.AdaptiveAvgPool1d(gap_size) 8 | self.fc = nn.Sequential( 9 | nn.Linear(channel, channel), 10 | nn.BatchNorm1d(channel), 11 | nn.ReLU(inplace=True), 12 | nn.Linear(channel, channel), 13 | nn.Sigmoid(), 14 | ) 15 | 16 | def forward(self, x): 17 | x_raw = x 18 | # x = torch.abs(x) 19 | x_abs = x.abs() 20 | x = self.gap(x) 21 | x = torch.flatten(x, 1) 22 | # average = torch.mean(x, dim=1, keepdim=True) #CS 23 | average = x #CW 24 | x = self.fc(x) 25 | x = torch.mul(average, x).unsqueeze(2) 26 | # soft thresholding 27 | x = x_abs - x 28 | # zeros = sub - sub 29 | # n_sub = torch.max(sub, torch.zeros_like(sub)) 30 | x = torch.mul(torch.sign(x_raw), torch.max(x, torch.zeros_like(x))) 31 | return x 32 | 33 | class Shrinkage(nn.Module): 34 | def __init__(self, channel, gap_size): 35 | super(Shrinkage, self).__init__() 36 | self.gap = nn.AdaptiveAvgPool1d(gap_size) 37 | self.fc = nn.Sequential( 38 | nn.Linear(channel, channel), 39 | nn.BatchNorm1d(channel), 40 | nn.ReLU(inplace=True), 41 | nn.Linear(channel, channel), 42 | nn.Sigmoid(), 43 | ) 44 | 45 | def forward(self, x): 46 | x_raw = x 47 | x = torch.abs(x) 48 | x_abs = x 49 | x = self.gap(x) 50 | x = torch.flatten(x, 1) 51 | # average = torch.mean(x, dim=1, keepdim=True) #CS 52 | average = x #CW 53 | x = self.fc(x) 54 | x = torch.mul(average, x) 55 | x = x.unsqueeze(2) 56 | # soft thresholding 57 | sub = x_abs - x 58 | zeros = sub - sub 59 | n_sub = torch.max(sub, zeros) 60 | x = torch.mul(torch.sign(x_raw), n_sub) 61 | return x 62 | 63 | class Shrinkagev3(nn.Module): 64 | def __init__(self, gap_size, inp, oup, reduction=4): 65 | super(Shrinkagev3, self).__init__() 66 | mip = int(max(8, inp // reduction)) 67 | self.gap = nn.AdaptiveAvgPool1d(gap_size) 68 | self.fc = nn.Sequential( 69 | nn.Conv1d(inp, mip, kernel_size=1, stride=1, padding=0), 70 | nn.BatchNorm1d(mip), 71 | nn.ReLU(inplace=True), 72 | nn.Conv1d(mip, oup, kernel_size=1, stride=1, padding=0), 73 | nn.Sigmoid() 74 | ) 75 | 76 | # def 
forward(self, x): 77 | # x_raw = x 78 | # x = torch.abs(x) 79 | # x_abs = x 80 | # x = self.gap(x) 81 | # # x = torch.flatten(x, 1) 82 | # # average = torch.mean(x, dim=1, keepdim=True) #CS 83 | # average = x #CW 84 | # x = self.fc(x) 85 | # x = torch.mul(average, x) 86 | # # x = x.unsqueeze(2) 87 | # # soft thresholding 88 | # sub = x_abs - x 89 | # zeros = sub - sub 90 | # n_sub = torch.max(sub, zeros) 91 | # x = torch.mul(torch.sign(x_raw), n_sub) 92 | # return x 93 | def forward(self, x): 94 | x_raw = x 95 | # x = torch.abs(x) 96 | x_abs = x.abs() 97 | x = self.gap(x) 98 | # average = torch.mean(x, dim=1, keepdim=True) #CS 99 | average = x #CW 100 | x = self.fc(x) 101 | x = torch.mul(average, x) 102 | # soft thresholding 103 | x = x_abs - x 104 | # zeros = sub - sub 105 | # n_sub = torch.max(sub, torch.zeros_like(sub)) 106 | x = torch.mul(torch.sign(x_raw), torch.max(x, torch.zeros_like(x))) 107 | return x 108 | 109 | class Shrinkagev3p(nn.Module): 110 | def __init__(self, gap_size, inp, oup, reduction=4): 111 | super(Shrinkagev3p, self).__init__() 112 | mip = int(max(8, inp // reduction)) 113 | self.a = nn.Parameter(torch.tensor([0.4])) 114 | self.gap = nn.AdaptiveAvgPool1d(gap_size) 115 | # self.gap = nn.AdaptiveMaxPool1d(gap_size) 116 | self.fc = nn.Sequential( 117 | nn.Conv1d(inp, mip, kernel_size=1, stride=1), 118 | nn.BatchNorm1d(mip), 119 | nn.ReLU(inplace=True), 120 | nn.Conv1d(mip, oup, kernel_size=1, stride=1), 121 | nn.Hardsigmoid() 122 | # nn.Sigmoid() 123 | ) 124 | 125 | # def forward(self, x): 126 | # x_raw = x 127 | # x = torch.abs(x) 128 | # x_abs = x 129 | # x = self.gap(x) 130 | # # x = torch.flatten(x, 1) 131 | # # average = torch.mean(x, dim=1, keepdim=True) #CS 132 | # average = x #CW 133 | # x = self.fc(x) 134 | # x = torch.mul(average, x) 135 | # # x = x.unsqueeze(2) 136 | # # soft thresholding 137 | # sub = x_abs - x 138 | # zeros = sub - sub 139 | # n_sub = torch.max(sub, zeros) 140 | # x = torch.mul(torch.sign(x_raw), n_sub) 141 | # 
return x 142 | def forward(self, x): 143 | x_raw = x 144 | # x = torch.abs(x) 145 | x_abs = x.abs() 146 | x = self.gap(x) 147 | # average = torch.mean(x, dim=1, keepdim=True) #CS 148 | average = x #CW 149 | x = self.fc(x) 150 | x = torch.mul(average, x) 151 | # soft thresholding 152 | # x = x_abs - x 153 | x = x_abs - self.a * x 154 | # zeros = sub - sub 155 | # n_sub = torch.max(sub, torch.zeros_like(sub)) 156 | x = torch.mul(torch.sign(x_raw), torch.max(x, torch.zeros_like(x))) 157 | return x 158 | 159 | class Shrinkagev3p1(nn.Module): 160 | def __init__(self, gap_size, inp, oup, reduction=32): 161 | super(Shrinkagev3p1, self).__init__() 162 | mip = int(max(8, inp // reduction)) 163 | #self.a = nn.Parameter(torch.tensor([0.5])) 164 | self.a = nn.Parameter(0.5 * torch.ones(1, 10, 1)) 165 | self.gap = nn.AdaptiveAvgPool1d(gap_size) 166 | self.fc = nn.Sequential( 167 | nn.Conv1d(inp, mip, kernel_size=1, stride=1, padding=0), 168 | nn.BatchNorm1d(mip), 169 | nn.ReLU(inplace=True), 170 | nn.Conv1d(mip, oup, kernel_size=1, stride=1, padding=0), 171 | nn.Hardsigmoid() 172 | # nn.Sigmoid() 173 | ) 174 | 175 | # def forward(self, x): 176 | # x_raw = x 177 | # x = torch.abs(x) 178 | # x_abs = x 179 | # x = self.gap(x) 180 | # # x = torch.flatten(x, 1) 181 | # # average = torch.mean(x, dim=1, keepdim=True) #CS 182 | # average = x #CW 183 | # x = self.fc(x) 184 | # x = torch.mul(average, x) 185 | # # x = x.unsqueeze(2) 186 | # # soft thresholding 187 | # sub = x_abs - x 188 | # zeros = sub - sub 189 | # n_sub = torch.max(sub, zeros) 190 | # x = torch.mul(torch.sign(x_raw), n_sub) 191 | # return x 192 | def forward(self, x): 193 | x_raw = x 194 | # x = torch.abs(x) 195 | x_abs = x.abs() 196 | x = self.gap(x) 197 | # average = torch.mean(x, dim=1, keepdim=True) #CS 198 | average = x #CW 199 | x = self.fc(x) 200 | x = torch.mul(average, x) 201 | # soft thresholding 202 | x = x_abs - self.a * x 203 | # zeros = sub - sub 204 | # n_sub = torch.max(sub, torch.zeros_like(sub)) 205 | 
x = torch.mul(torch.sign(x_raw), torch.max(x, torch.zeros_like(x))) 206 | return x 207 | class Shrinkagev3p11(nn.Module): 208 | def __init__(self, gap_size, channel): 209 | super(Shrinkagev3p11, self).__init__() 210 | #self.a = nn.Parameter(torch.tensor([0.5])) 211 | self.a = nn.Parameter(torch.tensor([0.48])) 212 | self.gap = nn.AdaptiveAvgPool1d(gap_size) 213 | self.fc = nn.Sequential( 214 | nn.Linear(channel, channel), 215 | nn.BatchNorm1d(channel), 216 | nn.ReLU(inplace=True), 217 | nn.Linear(channel, channel), 218 | nn.Hardsigmoid() 219 | ) 220 | 221 | def forward(self, x): 222 | x_raw = x 223 | # x = torch.abs(x) 224 | x_abs = x.abs() 225 | x = self.gap(x) 226 | x = torch.flatten(x, 1) 227 | # average = torch.mean(x, dim=1, keepdim=True) #CS 228 | average = x #CW 229 | x = self.fc(x) 230 | x = torch.mul(average, x).unsqueeze(2) 231 | # soft thresholding 232 | x = x_abs - self.a * x 233 | # zeros = sub - sub 234 | # n_sub = torch.max(sub, torch.zeros_like(sub)) 235 | x = torch.mul(torch.sign(x_raw), torch.max(x, torch.zeros_like(x))) 236 | return x 237 | class Shrinkagev3pp(nn.Module): #会议文献的软阈值降噪 238 | def __init__(self, gap_size, inp, oup, reduction=4): 239 | super(Shrinkagev3pp, self).__init__() 240 | mip = int(max(8, inp // reduction)) 241 | self.a = nn.Parameter(torch.tensor([0.48])) 242 | # self.a = nn.Parameter(torch.clamp(torch.tensor([0.4]), min=0, max=1)) 243 | self.gap = nn.AdaptiveAvgPool1d(gap_size) 244 | # self.gap = nn.AdaptiveMaxPool1d(gap_size) 245 | self.fc = nn.Sequential( 246 | nn.Conv1d(inp, mip, kernel_size=1, stride=1), 247 | nn.BatchNorm1d(mip), 248 | nn.ReLU(inplace=True), 249 | nn.Conv1d(mip, oup, kernel_size=1, stride=1), 250 | nn.Hardsigmoid() 251 | # nn.Sigmoid() 252 | ) 253 | 254 | # def forward(self, x): 255 | # x_raw = x 256 | # x = torch.abs(x) 257 | # x_abs = x 258 | # x = self.gap(x) 259 | # # x = torch.flatten(x, 1) 260 | # # average = torch.mean(x, dim=1, keepdim=True) #CS 261 | # average = x #CW 262 | # x = self.fc(x) 263 
| # x = torch.mul(average, x) 264 | # # x = x.unsqueeze(2) 265 | # # soft thresholding 266 | # sub = x_abs - x 267 | # zeros = sub - sub 268 | # n_sub = torch.max(sub, zeros) 269 | # x = torch.mul(torch.sign(x_raw), n_sub) 270 | # return x 271 | def forward(self, x): 272 | x_raw = x 273 | # x = torch.abs(x) 274 | x_abs = x.abs() 275 | x = self.gap(x) 276 | 277 | # average = torch.mean(x, dim=1, keepdim=True) #CS 278 | average = x #CW 279 | x = self.fc(x) 280 | x = torch.mul(average, x) 281 | # soft thresholding 282 | x = x_abs - x 283 | # x = x_abs - self.a * x 284 | # sub = torch.max(x1, torch.zeros_like(x1)) 285 | # sub = self.a * sub 286 | a = torch.clamp(self.a, min=0, max=1) 287 | x = torch.mul(torch.sign(x_raw), a*(torch.max(x, torch.zeros_like(x)))) 288 | return x 289 | class Shrinkagev3pp1(nn.Module): #会议文献的软阈值降噪 290 | def __init__(self, gap_size, channel): 291 | super(Shrinkagev3pp1, self).__init__() 292 | self.a = nn.Parameter(torch.tensor([0.48])) 293 | self.gap = nn.AdaptiveAvgPool1d(gap_size) 294 | self.fc = nn.Sequential( 295 | nn.Linear(channel, channel), 296 | nn.BatchNorm1d(channel), 297 | nn.ReLU(inplace=True), 298 | nn.Linear(channel, channel), 299 | nn.Sigmoid() 300 | ) 301 | 302 | def forward(self, x): 303 | x_raw = x 304 | x_abs = x.abs() 305 | x = self.gap(x) 306 | x = torch.flatten(x, 1) 307 | # average = torch.mean(x, dim=1, keepdim=True) #CS 308 | average = x #CW 309 | x = self.fc(x) 310 | x = torch.mul(average, x).unsqueeze(2) 311 | # soft thresholding 312 | x = x_abs - x 313 | a = torch.clamp(self.a, min=0, max=1) 314 | x = torch.mul(torch.sign(x_raw), a*(torch.max(x, torch.zeros_like(x)))) 315 | return x 316 | class Shrinkagev3ppp(nn.Module): #半软阈值降噪 317 | def __init__(self, gap_size, inp, oup, reduction=4): 318 | super(Shrinkagev3ppp, self).__init__() 319 | mip = int(max(8, inp // reduction)) 320 | self.a = nn.Parameter(torch.tensor([0.48])) 321 | self.gap = nn.AdaptiveAvgPool1d(gap_size) 322 | # self.gap = 
nn.AdaptiveMaxPool1d(gap_size) 323 | self.fc = nn.Sequential( 324 | nn.Conv1d(inp, mip, kernel_size=1, stride=1), 325 | nn.BatchNorm1d(mip), 326 | nn.ReLU(inplace=True), 327 | nn.Conv1d(mip, oup, kernel_size=1, stride=1), 328 | # nn.Hardsigmoid() 329 | nn.Sigmoid() 330 | ) 331 | 332 | def forward(self, x): 333 | x_raw = x 334 | x_abs = x.abs() 335 | x = self.gap(x) 336 | average = x 337 | x = self.fc(x) 338 | x = torch.mul(average, x) 339 | # soft thresholding 340 | # x1 = x_abs - x 341 | # sub = torch.max(x1, torch.zeros_like(x1)) 342 | sub = torch.max(x_abs - x, torch.zeros_like(x_abs - x)) 343 | mask = sub.clone() 344 | mask[mask > 0] = 1 345 | # a = torch.clamp(self.a, min=0, max=1) 346 | x = sub + (1-self.a) * x 347 | x = torch.mul(x, mask) 348 | x = torch.mul(torch.sign(x_raw), x) 349 | return x 350 | 351 | class Shrinkagev3ppp1(nn.Module): 352 | def __init__(self, gap_size, inp, oup, reduction=4): 353 | super(Shrinkagev3ppp1, self).__init__() 354 | mip = int(max(8, inp // reduction)) 355 | self.a = nn.Parameter(torch.tensor([0.4])) 356 | self.gap = nn.AdaptiveAvgPool1d(gap_size) 357 | # self.gap = nn.AdaptiveMaxPool1d(gap_size) 358 | self.fc = nn.Sequential( 359 | nn.Conv1d(inp, mip, kernel_size=1, stride=1), 360 | nn.BatchNorm1d(mip), 361 | nn.ReLU(inplace=True), 362 | nn.Conv1d(mip, oup, kernel_size=1, stride=1), 363 | nn.Hardsigmoid() 364 | # nn.Sigmoid() 365 | ) 366 | 367 | def forward(self, x): 368 | x_raw = x 369 | x_abs = x.abs() 370 | x = self.gap(x) 371 | average = x 372 | x = self.fc(x) 373 | x = torch.mul(average, x) 374 | # soft thresholding 375 | sub = torch.max(x_abs - x, torch.zeros_like(x_abs - x)) 376 | x = sub + (1-self.a) * x 377 | x = x.index_put((sub == 0).nonzero(as_tuple=True), torch.tensor(0.)) 378 | x = torch.mul(torch.sign(x_raw), x) 379 | return x 380 | 381 | class Shrinkagev3ppp2(nn.Module): #半软阈值降噪 382 | def __init__(self, gap_size, channel): 383 | super(Shrinkagev3ppp2, self).__init__() 384 | self.a = 
nn.Parameter(torch.tensor([0.48])) 385 | self.gap = nn.AdaptiveAvgPool1d(gap_size) 386 | self.fc = nn.Sequential( 387 | nn.Linear(channel, channel), 388 | nn.BatchNorm1d(channel), 389 | nn.ReLU(inplace=True), 390 | nn.Linear(channel, channel), 391 | nn.Sigmoid(), 392 | ) 393 | 394 | def forward(self, x): 395 | 396 | x_raw = x 397 | x_abs = x.abs() 398 | x = self.gap(x) 399 | x = torch.flatten(x, 1) 400 | average = x 401 | # average = torch.mean(x, dim=1, keepdim=True) 402 | x = self.fc(x) 403 | x = torch.mul(average, x).unsqueeze(2) 404 | # soft thresholding 405 | sub = torch.max(x_abs - x, torch.zeros_like(x_abs - x)) 406 | mask = sub.clone() 407 | mask[mask > 0] = 1 408 | # a = torch.clamp(self.a, min=0, max=1) 409 | x = sub + (1 - self.a) * x 410 | x = torch.mul(torch.sign(x_raw), torch.mul(x, mask)) 411 | return x 412 | 413 | # class Shrinkagev3ppp3(nn.Module): #半软阈值降噪 414 | # def __init__(self, gap_size, channel): 415 | # super(Shrinkagev3ppp3, self).__init__() 416 | # self.a = nn.Parameter(0.48 * torch.ones(16, channel, 1)) 417 | # self.gap = nn.AdaptiveAvgPool1d(gap_size) 418 | # self.fc = nn.Sequential( 419 | # nn.Linear(channel, channel), 420 | # nn.BatchNorm1d(channel), 421 | # nn.ReLU(inplace=True), 422 | # nn.Linear(channel, channel), 423 | # nn.Sigmoid(), 424 | # ) 425 | # 426 | # def forward(self, x): 427 | # 428 | # x_raw = x 429 | # x_abs = x.abs() 430 | # x = self.gap(x) 431 | # x = torch.flatten(x, 1) 432 | # average = x 433 | # # average = torch.mean(x, dim=1, keepdim=True) 434 | # x = self.fc(x) 435 | # x = torch.mul(average, x).unsqueeze(2) 436 | # # soft thresholding 437 | # sub = torch.max(x_abs - x, torch.zeros_like(x_abs - x)) 438 | # mask = sub.clone() 439 | # mask[mask > 0] = 1 440 | # # a = torch.clamp(self.a, min=0, max=1) 441 | # x = sub + (1 - self.a) * x 442 | # x = torch.mul(torch.sign(x_raw), torch.mul(x, mask)) 443 | # return x 444 | 445 | class HShrinkage(nn.Module): #硬阈值降噪 446 | def __init__(self, gap_size, channel): 447 | 
super(HShrinkage, self).__init__() 448 | self.gap = nn.AdaptiveAvgPool1d(gap_size) 449 | self.fc = nn.Sequential( 450 | nn.Linear(channel, channel), 451 | nn.BatchNorm1d(channel), 452 | nn.ReLU(inplace=True), 453 | nn.Linear(channel, channel), 454 | nn.Sigmoid(), 455 | ) 456 | 457 | def forward(self, x): 458 | 459 | x_raw = x 460 | x_abs = x.abs() 461 | x = self.gap(x) 462 | x = torch.flatten(x, 1) 463 | average = x 464 | # average = torch.mean(x, dim=1, keepdim=True) 465 | x = self.fc(x) 466 | x = torch.mul(average, x).unsqueeze(2) 467 | # soft thresholding 468 | sub = torch.max(x_abs - x, torch.zeros_like(x_abs - x)) 469 | mask = sub.clone() 470 | mask[mask > 0] = 1 471 | x = torch.mul(x_raw, mask) 472 | return x 473 | class Shrinkagev3ppp31(nn.Module): #半软阈值降噪 474 | def __init__(self, gap_size, channel): 475 | super(Shrinkagev3ppp31, self).__init__() 476 | self.a = nn.Parameter(torch.tensor([0.48])) 477 | self.gap = nn.AdaptiveAvgPool1d(gap_size) 478 | self.fc = nn.Sequential( 479 | nn.Linear(channel, channel), 480 | nn.BatchNorm1d(channel), 481 | nn.ReLU(inplace=True), 482 | nn.Linear(channel, channel), 483 | nn.Sigmoid(), 484 | ) 485 | 486 | def forward(self, x): 487 | 488 | x_raw = x 489 | x_abs = x.abs() 490 | x = self.gap(x) 491 | x = torch.flatten(x, 1) 492 | average = x 493 | # average = torch.mean(x, dim=1, keepdim=True) 494 | x = self.fc(x) 495 | x = torch.mul(average, x).unsqueeze(2) 496 | # soft thresholding 497 | sub = torch.max(x_abs - x, torch.zeros_like(x_abs - x)) 498 | # mask = sub.clone() 499 | # mask[mask > 0] = 1 500 | # a = torch.clamp(self.a, min=0, max=1) 501 | x = torch.mul(sub, (1 - self.a)) 502 | x = torch.add(torch.mul(torch.sign(x_raw), x), torch.mul(self.a, x_raw)) 503 | return x 504 | if __name__ == '__main__': 505 | input = torch.randn(2, 3, 1024).cuda() 506 | model = Shrinkagev3ppp(1, 3, 3).cuda() 507 | for param in model.parameters(): 508 | print(type(param.data), param.size()) 509 | output = model(input) 510 | 
print(output.size()) 511 | --------------------------------------------------------------------------------