├── nam.py
├── EWSNet.py
├── README.md
├── weights.py
└── thresholds.py
/nam.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from Elaplacenet5 import Net
4 | from gdatasave import train_loader
5 | import pandas as pd
6 | import seaborn as sns
7 | from sklearn.svm import SVC
8 | import numpy as np
9 | from tqdm import tqdm
10 | import matplotlib.pyplot as plt
11 | torch.cuda.manual_seed(1)
12 | torch.manual_seed(2)
13 |
14 | sampling_rate = 50000
15 | window_size = 1024
16 | n_class = 4
17 | channels = [64]
18 |
19 |
20 | def normalize_freq(fam):
21 | max_fam, _ = torch.max(fam, dim=-1, keepdim=True)
22 | return torch.div(fam, max_fam)
23 | # Power spectrum computed following P. Welch's method; see [3]
24 | def power_spectrum(t_freq):
25 | result = torch.abs(torch.fft.rfft(t_freq))**2
26 | return result / torch.mean(result, dim=2, keepdim=True)
27 |
28 |
29 | def freq_activation_map(model, input, width, channels, target_label):
30 | '''
31 | Param:
32 | model : neural network object
33 | input : time-series data
34 | width : length of power_spectrum(input)
35 | channels : number of channels in the model's last conv layer; target_label : class index to test
36 | '''
37 | fam = torch.zeros(input.shape[0], 1, width)
38 | if torch.cuda.is_available():
39 | fam = fam.cuda()
40 |
41 | with torch.no_grad():
42 | freq, labels = model(torch.reshape(input, (-1, 1, input.shape[-1])))
43 | freq = freq.unsqueeze(2)
44 | labels = labels.unsqueeze(2)
45 | labels = torch.argmax(F.softmax(labels, dim=1), dim=1)
46 | labels = torch.where(labels == target_label, 1., 0.)
47 | labels = torch.unsqueeze(torch.reshape(labels, [-1, 1]).repeat(1, width), 1)
48 | for c in range(channels):
49 | sp = freq[:, c, :, :]
50 | if torch.cuda.is_available():
51 | sp = sp.cuda()
52 |
53 | sp = power_spectrum(sp)
54 | sp = sp * labels
55 |
56 | if model.p4[0].weight[target_label, c] > 0:
57 | fam += model.p4[0].weight[target_label, c] * sp
58 |
59 | return torch.squeeze(torch.sum(fam, dim=0)), torch.sum(labels[:, 0, 0], dim=0)
60 |
61 |
62 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
63 | # summary(Net1, input_size=(1, 1024))  # print the model's parameter summary
64 | model = Net().to(device)
65 | model.load_state_dict(torch.load('./model1001.pt'))
66 | torch.set_grad_enabled(False)  # note: a bare torch.no_grad() call is a no-op
67 | # for img, label in test_loader:
68 | '''
69 | Test set
70 | '''
71 | # Test.Data = Test.Data.type(torch.FloatTensor).to(device)
72 | # label = torch.from_numpy(Test.Label)
73 | # outputs = model(Test.Data)
74 | # outputs = torch.squeeze(outputs).float()
75 | # encoded_data = outputs.cpu().detach().numpy()
76 | # clf = SVC(C=8.94, gamma=3.26)
77 | '''
78 | Training set
79 | '''
80 | # Train.Data = Train.Data.type(torch.FloatTensor).to(device)
81 | # labels = torch.from_numpy(Train.Label)
82 | # outputs = model(Train.Data)
83 | # outputs = torch.squeeze(outputs).float()
84 | # encoded_data = outputs.cpu().detach().numpy()
85 |
86 | # clf = SVC(C=10.82, kernel='rbf', gamma=1.19, decision_function_shape='ovr')  # RBF (Gaussian) kernel
87 | # clf.fit(encoded_data, label.cpu().numpy())
88 | # fit_score = clf.score(encoded_data, label.cpu().numpy())
89 | # print(fit_score)
90 |
91 | freq_intervals = np.fft.rfftfreq(window_size, d=1/sampling_rate)
92 | total_fam = torch.zeros(n_class, len(freq_intervals))
93 | total_len = torch.zeros(n_class, 1)
94 | if torch.cuda.is_available():
95 | total_fam = total_fam.cuda()
96 | total_len = total_len.cuda()
97 |
98 | for batch in tqdm(train_loader):
99 | x, y = batch
100 | if torch.cuda.is_available():
101 | x = x.cuda()
102 | for n in range(n_class):
103 | tmp_fam, cnt = freq_activation_map(model, x.float(), len(freq_intervals), channels[-1], n)
104 | total_fam[n, :] += tmp_fam
105 | total_len[n] += cnt
106 |
107 | total_fam /= total_len
108 | print(total_fam.size())
109 | # _, predicted = outputs.max(1)
110 | # print(predicted, Test.Label)
111 | # num_correct = (predicted.cpu() == Test.Label).sum().item()
112 | # acc = num_correct / outputs.cpu().shape[0]
113 | # eval_acc = 0
114 | # eval_acc += fit_score
115 | # print(eval_acc)
116 | rtotal_fam = normalize_freq(total_fam).cpu().detach()
117 | columns_ = np.floor(freq_intervals).astype(int)
118 | plt.plot(columns_, rtotal_fam[0, :])
119 |
120 | result = pd.DataFrame(rtotal_fam.numpy(),
121 | index = ['0', '1', '2', '3'],
122 | columns = columns_,
123 | )
124 |
125 | print(result.values.shape)
126 |
127 | # reindex keeps the original class order; pass ['3', '2', '1', '0'] instead to flip the heatmap rows
128 | result = result.reindex(['0', '1', '2', '3'])
129 |
130 |
131 | sns.heatmap(result, cmap='viridis')
132 | plt.show()
133 |
134 |
--------------------------------------------------------------------------------
/EWSNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | # from wmodelsii8 import Sin_fast as fast
4 | # from wmodelsii3 import Laplace_fast as fast
5 | from wmodelsii8 import Laplace_fastv2 as fast
6 | # from wsinc import SincConv_fast as fast
7 | from Shrinkage import Shrinkagev3ppp2 as sage
8 |
9 | class Mish1(nn.Module):  # note: despite the name, this wraps nn.ReLU, not Mish
10 | def __init__(self):
11 | super(Mish1, self).__init__()
12 | self.mish = nn.ReLU(inplace=True)
13 |
14 | def forward(self, x):
15 |
16 | return self.mish(x)
17 |
18 |
19 | class Net(nn.Module):
20 | def __init__(self):
21 | super(Net, self).__init__() #85,42,70 #63,31,75
22 | self.p1_0 = nn.Sequential( # nn.Conv1d(1, 50, kernel_size=18, stride=2),
23 | # fast(out_channels=64, kernel_size=250, stride=1),
24 | # fast1(out_channels=70, kernel_size=84, stride=1),
25 | nn.Conv1d(1, 64, kernel_size=250, stride=1, bias=True),
26 | nn.BatchNorm1d(64),
27 | Mish1()
28 | )
29 | self.p1_1 = nn.Sequential(nn.Conv1d(64, 16, kernel_size=18, stride=2, bias=True),
30 | # fast(out_channels=50, kernel_size=18, stride=2),
31 | nn.BatchNorm1d(16),
32 | Mish1()
33 | )
34 | self.p1_2 = nn.Sequential(nn.Conv1d(16, 10, kernel_size=10, stride=2, bias=True),
35 | nn.BatchNorm1d(10),
36 | Mish1()
37 | )
38 | self.p1_3 = nn.MaxPool1d(kernel_size=2)
39 | self.p2_1 = nn.Sequential(nn.Conv1d(64, 32, kernel_size=6, stride=1, bias=True),
40 | # fast(out_channels=50, kernel_size=6, stride=1),
41 | nn.BatchNorm1d(32),
42 | Mish1()
43 | )
44 | self.p2_2 = nn.Sequential(nn.Conv1d(32, 16, kernel_size=6, stride=1, bias=True),
45 | nn.BatchNorm1d(16),
46 | Mish1()
47 | )
48 | self.p2_3 = nn.MaxPool1d(kernel_size=2)
49 | self.p2_4 = nn.Sequential(nn.Conv1d(16, 10, kernel_size=6, stride=1, bias=True),
50 | nn.BatchNorm1d(10),
51 | Mish1()
52 | )
53 | self.p2_5 = nn.Sequential(nn.Conv1d(10, 10, kernel_size=8, stride=2, bias=True),
54 | # nn.Conv1d(10, 10, kernel_size=6, stride=2),
55 | nn.BatchNorm1d(10),
56 | Mish1()
57 | ) # PRelu
58 | self.p2_6 = nn.MaxPool1d(kernel_size=2)
59 | self.p3_0 = sage(channel=64, gap_size=1)
60 | self.p3_1 = nn.Sequential(nn.Conv1d(64, 10, kernel_size=43, stride=4, bias=True),
61 | nn.BatchNorm1d(10),
62 | Mish1()
63 | )
64 | self.p3_2 = nn.MaxPool1d(kernel_size=2)
65 | self.p3_3 = nn.Sequential(nn.AdaptiveAvgPool1d(1))
66 | self.p4 = nn.Sequential(nn.Linear(10, 4))
67 | self._initialize_weights()
68 |
69 | def _initialize_weights(self):
70 | for m in self.modules():
71 | if isinstance(m, nn.Conv1d):
72 | if m.kernel_size == (250,):  # first conv layer: physics-informed wavelet weight initialization
73 | m.weight.data = fast(out_channels=64, kernel_size=250).forward()
74 | nn.init.constant_(m.bias.data, 0.0)
75 | else:
76 | nn.init.kaiming_normal_(m.weight.data)
77 | nn.init.constant_(m.bias.data, 0.0)
78 | elif isinstance(m, nn.BatchNorm1d):
79 | m.weight.data.fill_(1)
80 | m.bias.data.fill_(0)
81 | elif isinstance(m, nn.Linear):
82 | m.weight.data.normal_(0, 0.01)
83 | if m.bias is not None:
84 | m.bias.data.fill_(1)
85 |
86 |
87 |
88 | def forward(self, x):
89 | x = self.p1_0(x)  # shared wavelet-initialized stem
90 | p1 = self.p1_3(self.p1_2(self.p1_1(x)))  # stream 1
91 | p2 = self.p2_6(self.p2_5(self.p2_4(self.p2_3(self.p2_2(self.p2_1(x))))))  # stream 2
92 | x = self.p3_2(self.p3_1(x + self.p3_0(x)))  # stream 3: adaptive-threshold (shrinkage) residual
93 | x = torch.add(x, torch.add(p1, p2))  # fuse the three streams
94 | x = self.p3_3(x).squeeze(-1)  # squeeze only the length dim so batch size 1 survives
95 | x = self.p4(x)
96 | return x
97 |
98 | if __name__ == '__main__':
99 | import numpy as np
100 | import matplotlib.pyplot as plt
101 | import math
102 |
103 | # input = torch.randn(2, 1, 1024).cuda()
104 | # model = Net().cuda()
105 | # # for param in model.parameters():
106 | # # print(type(param.data), param.size())
107 | # print("# parameters:", sum(param.numel() for param in model.parameters()))
108 | # output = model(input)
109 | # print(model)
110 | ##################################
111 | # model = Net().cuda()
112 | # for name, parameters in model.named_parameters():
113 | # print(name, ':', parameters.size())
114 | # model.load_state_dict(torch.load('H:\EWSNet\pre\99.53.pt'), strict=False)
115 | # weight_t = model.state_dict()['p3_0.a'].cpu().detach().numpy()
116 | # print(weight_t)
117 | ##################################
118 | # model = Net().cuda()
119 | # weight_t = model.state_dict()['p1_0.0.weight'].cpu().detach().numpy()
120 | # y1 = weight_t[60, :, :].squeeze()
121 | # # model.load_state_dict(torch.load('H:\EWSNet\pre\model.pt'), strict=False)
122 | # # weight_t = model.state_dict()['p1_0.0.weight'].cpu().detach().numpy()
123 | # # y2 = weight_t[20, :, :].squeeze()
124 | # x = np.linspace(0, 250, 250)
125 | # # plt.plot(x, y1, label='before')
126 | # plt.plot(x, y1, label='Xavier_normal')
127 | # # plt.plot(x, y2, label='after')
128 | # plt.legend()
129 | # plt.savefig('wahaha.tiff', format='tiff', dpi=600)
130 | # plt.show()
131 |
132 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 🔥 Physics-informed Interpretable Wavelet Weight Initialization and Balanced Dynamic Adaptive Threshold for Intelligent Fault Diagnosis of Rolling Bearings
2 |
3 | The PyTorch implementation of the paper [Physics-informed interpretable wavelet weight initialization and balanced dynamic adaptive threshold for intelligent fault diagnosis of rolling bearings](https://doi.org/10.1016/j.jmsy.2023.08.014)
4 |
5 |
6 | # Updating!
7 |
8 |
9 | [NEWS!] This paper has been accepted by **[Journal of Manufacturing Systems](https://www.sciencedirect.com/journal/journal-of-manufacturing-systems/vol/70/suppl/C)**!
10 |
11 |
12 | ## Brief introduction
13 | Intelligent fault diagnosis of rolling bearings using deep learning-based methods has made unprecedented progress. However, there is still little research on weight initialization and on threshold setting for noise reduction. An innovative deep triple-stream network called EWSNet is proposed, which presents a wavelet weight initialization method and a balanced dynamic adaptive threshold algorithm. Initially, an enhanced wavelet basis function is designed, in which a scale smoothing factor is defined to acquire more rational wavelet scales. Next, a plug-and-play wavelet weight initialization for deep neural networks is proposed, which utilizes physics-informed wavelet prior knowledge and showcases stronger applicability. Furthermore, a balanced dynamic adaptive threshold is established to enhance the noise-resistant robustness of the model. Finally, normalization activation mapping is devised to reveal the effectiveness of Z-score from a visual perspective rather than from experimental results alone. The validity and reliability of EWSNet are demonstrated on four data sets under constant and fluctuating speed conditions.
14 |
15 | ## Highlights
16 |
17 | - **A novel deep triple-stream network called EWSNet is proposed for fault diagnosis of rolling bearings under the condition of constant or sharp speed variation.**
18 | - **An enhanced wavelet convolution kernel is designed to improve the trainability, in which a scale smoothing factor is employed to acquire rational wavelet scales.**
19 | - **A plug-and-play and physics-informed wavelet weight initialization is proposed to construct an interpretable convolution kernel, which makes the diagnosis interpretable.**
20 | - **Balanced dynamic adaptive threshold is specially devised to improve the antinoise robustness of the model.**
21 | - **Normalization activation mapping is designed to visually reveal that Z-score can enhance the frequency-domain information of raw signals.**
22 |
23 |
24 | ## Paper
25 | **Physics-informed Interpretable Wavelet Weight Initialization and Balanced Dynamic Adaptive Threshold for Intelligent Fault Diagnosis of Rolling Bearings**
26 |
27 | Chao He<sup>a,b</sup>, **Hongmei Shi**<sup>a,b,*</sup>, Jin Si<sup>c</sup> and Jianbo Li<sup>a,b</sup>
28 |
29 | <sup>a</sup> School of Mechanical, Electronic and Control Engineering, Beijing Jiaotong University, Beijing 100044, China
30 |
31 | <sup>b</sup> Collaborative Innovation Center of Railway Traffic Safety, Beijing 100044, China
32 |
33 | <sup>c</sup> Key Laboratory of Information System and Technology, Beijing Institute of Control and Electronic Technology, Beijing 100038, China
34 |
35 | **[Journal of Manufacturing Systems](https://www.sciencedirect.com/journal/journal-of-manufacturing-systems/vol/70/suppl/C)**
36 |
37 | ## EWSNet
38 | 
39 |
40 | ## Wavelet initialization
41 |
42 | 
43 |
44 | ## Balanced Dynamic Adaptive Thresholding
45 |
46 | 
47 |
48 | 
49 |
50 | **where $\alpha$ and $\eta$ are learnable and differentiable ($\alpha \in \left( 0,1 \right)$). When $\alpha = 0$ or $\alpha = 1$, the equation degenerates into the hard and soft thresholds, respectively, so we can adjust $\alpha$ appropriately to make $y$ closer to the genuine wavelet coefficient.**
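
A minimal functional sketch of this balanced (semi-soft) threshold, mirroring `Shrinkagev3ppp2` in `thresholds.py` (`tau` here stands for the channel-wise threshold produced by the attention branch):

```python
import torch

def balanced_threshold(x, tau, alpha):
    # alpha = 1 recovers the soft threshold, alpha = 0 the hard threshold
    sub = torch.clamp(x.abs() - tau, min=0)  # soft-thresholded magnitude
    mask = (sub > 0).float()                 # hard-threshold keep set
    y = (sub + (1 - alpha) * tau) * mask     # add back part of the threshold on the keep set
    return torch.sign(x) * y
```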
51 |
52 | ## Normalization Activation Mapping
53 |
54 | **Data normalization can accelerate convergence, and Z-score normalization also improves CNN accuracy. Rather than relying on experiments alone, we observe that Z-score enhances the frequency-domain information of the signals, so a CNN can learn these features better.**
55 |
56 | **FAM illustrates the frequency-domain information by utilizing the weights of the classification layer and the extracted features, but it cannot reveal the influence of normalization methods. Therefore, in NAM the weight of the correct label is set to $1.0$ and the features are the signals processed by the normalization methods, which visualizes which normalization method preserves more frequency-domain knowledge.**
57 |
58 | 
59 |
60 | where ${l_{real}}$ is the real label and ${l_{target}}$ is the tested label.
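
As a minimal sketch of the NAM computation (assuming batched 1-D signals of shape `(N, L)` and a `normalize` callable such as per-segment Z-score; the Welch-style spectrum normalization follows `power_spectrum` in `nam.py`):

```python
import torch

def nam(signals, labels, target_label, normalize):
    z = normalize(signals)                                 # e.g. Z-score each segment
    ps = torch.abs(torch.fft.rfft(z, dim=-1)) ** 2         # power spectrum
    ps = ps / ps.mean(dim=-1, keepdim=True)                # Welch-style normalization
    mask = (labels == target_label).float().unsqueeze(-1)  # weight 1.0 for the tested label
    return (ps * mask).sum(dim=0) / mask.sum().clamp(min=1)
```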
61 |
62 | ## Example:
63 |
64 |
65 |
66 | ```python
67 | # assumes: torch.nn as nn, torch.nn.functional as F, and `from wmodelsii8 import Laplace_fastv2 as fast` (as in EWSNet.py)
68 |
69 | class CNNNet(nn.Module):
70 | def __init__(self, init_weights=False):
71 | super(CNNNet, self).__init__()
72 | self.conv1 = nn.Conv1d(1, 64, 60, padding=2)
73 | self.conv2 = nn.Conv1d(64, 32, 3, padding=1)
74 | self.conv3 = nn.Conv1d(32, 48, 3, padding=1)
75 | #self.sage = sage(channel=16, gap_size=1)
76 | self.conv4 = nn.Conv1d(48, 64, 3, padding=1)
77 | self.pool = nn.MaxPool1d(2)
78 | self.fc1 = nn.Linear(64*60, 512)  # 64 channels x length 60 for 1024-point inputs
79 | self.fc2 = nn.Linear(512, 4)
80 | if init_weights:
81 | self._initialize_weights()
82 |
83 | def _initialize_weights(self):
84 | for m in self.modules():
85 | if isinstance(m, nn.Conv1d):
86 | if m.kernel_size == (60,):
87 | m.weight.data = fast(out_channels=64, kernel_size=60, eps=0.2, mode='sigmoid').forward()
88 | nn.init.constant_(m.bias.data, 0.0)
89 |
90 | def forward(self, x):
91 | x = self.pool(F.relu(self.conv1(x)))
92 | x = self.pool(F.relu(self.conv2(x)))
93 | #x = x + self.sage(x)
94 | x = self.pool(F.relu(self.conv3(x)))
95 | x = self.pool(F.relu(self.conv4(x)))
96 | x = x.view(x.shape[0], -1)
97 | x = F.relu(self.fc1(x))
98 | x = self.fc2(x)
99 | return x
100 | ```
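
With 1024-point segments (the length assumed by `fc1` above), a usage sketch:

```python
model = CNNNet(init_weights=True)     # wavelet-initialized first conv layer
out = model(torch.randn(8, 1, 1024))  # -> (8, 4)
```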
101 |
102 |
103 |
104 | ## Citation
105 |
106 | ```bibtex
107 | @article{HE,
108 | title = {Physics-informed interpretable wavelet weight initialization and balanced dynamic adaptive threshold for intelligent fault diagnosis of rolling bearings},
109 | journal = {Journal of Manufacturing Systems},
110 | volume = {70},
111 | pages = {579-592},
112 | year = {2023},
113 | issn = {1878-6642},
114 | doi = {https://doi.org/10.1016/j.jmsy.2023.08.014},
115 | author = {Chao He and Hongmei Shi and Jin Si and Jianbo Li}
116 | }
117 | ```
117 |
118 | C. He, H. Shi, J. Si, J. Li, Physics-informed interpretable wavelet weight initialization and balanced dynamic adaptive threshold for intelligent fault diagnosis of rolling bearings, Journal of Manufacturing Systems 70 (2023) 579-592, https://doi.org/10.1016/j.jmsy.2023.08.014.
119 |
120 |
121 |
122 | ## Acknowledgements
123 | The authors are grateful for the support of the Fundamental Research Funds for the Central Universities (Science and Technology Leading Talent Team Project) (2022JBXT005) and the National Natural Science Foundation of China (No. 52272429).
124 |
125 | ## References
126 |
127 | - von Rueden L, Mayer S, Beckh K, Georgiev B, Giesselbach S, Heese R, et al. Informed Machine Learning – A Taxonomy and Survey of Integrating Prior Knowledge into Learning Systems. IEEE Trans Knowl Data Eng 2023;35(1):614–633. https://doi.org/10.1109/TKDE.2021.3079836
128 |
129 | - Vollert S, Atzmueller M, Theissler A. Interpretable Machine Learning: A brief survey from the predictive maintenance perspective. In: 26th IEEE International Conference on Emerging Technologies and Factory Automation (ETFA 2021), Vasteras, Sweden, September 7-10, 2021. IEEE; 2021. p. 1–8. https://doi.org/10.1109/ETFA45728.2021.9613467
130 |
131 | - Li T, Zhao Z, Sun C, Cheng L, Chen X, Yan R, et al. WaveletKernelNet: An Interpretable Deep Neural Network for Industrial Intelligent Diagnosis. IEEE Trans Syst Man Cybern Syst 2022;52(4):2302–2312. https://doi.org/10.1109/TSMC.2020.3048950
132 |
133 | - Zhao M, Zhong S, Fu X, Tang B, Pecht M. Deep Residual Shrinkage Networks for Fault Diagnosis. IEEE Trans Industr Inform 2020;16(7):4681–4690. https://doi.org/10.1109/TII.2019.2943898
134 |
135 | - Kim MS, Yun JP, Park P. An Explainable Neural Network for Fault Diagnosis With a Frequency Activation Map. IEEE Access 2021;9:98962–98972. https://doi.org/10.1109/ACCESS.2021.3095565
184 |
185 | ## Contact
186 |
187 | - **Chao He**
188 | - **chaohe#bjtu.edu.cn (please replace # by @)**
189 |
190 |
191 |
--------------------------------------------------------------------------------
/weights.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from math import pi
4 | import torch.nn.functional as F
5 |
6 |
7 | '''
8 | Laplace wavelet kernel initialization
9 | '''
10 | def Laplace(p):
11 |
12 | w = 2 * pi * 80
13 | q = torch.tensor(1 - pow(0.03, 2))
14 | return 0.08 * torch.exp(((-0.03 / (torch.sqrt(q))) * (w * (p - 0.1)))) * (torch.sin(w * (p - 0.1)))
15 |
16 | class Laplace_fast(nn.Module):
17 | def __init__(self, out_channels, kernel_size, stride=2):
18 | super(Laplace_fast, self).__init__()
19 | self.out_channels = out_channels
20 | self.kernel_size = kernel_size
21 | self.stride = stride
22 | self.a_ = nn.Parameter(torch.linspace(1, 100, out_channels).view(-1, 1))
23 | self.b_ = nn.Parameter(torch.linspace(0, 100, out_channels).view(-1, 1))
24 |
25 | def forward(self, waveforms):
26 | time_disc = torch.linspace(0, 1, steps=int(self.kernel_size))
27 | p1 = (time_disc.cuda() - self.b_.cuda()) / (self.a_.cuda())
28 | laplace_filter = Laplace(p1)
29 | filters = laplace_filter.view(self.out_channels, 1, self.kernel_size).cuda()
30 | return F.conv1d(waveforms, filters, stride=self.stride, padding=0)
31 |
32 |
33 | class Laplace_fastv2:  # plain class (not nn.Module); only generates the initialization bank
34 |
35 | def __init__(self, out_channels, kernel_size, eps=-0.3):
36 | super(Laplace_fastv2, self).__init__()
37 | self.out_channels = out_channels
38 | self.kernel_size = kernel_size
39 | self.eps = eps
40 | self.a_ = torch.linspace(0, out_channels, out_channels).view(-1, 1)  # out_channels-1: the channel scales are integers
41 | self.b_ = torch.linspace(0, out_channels, out_channels).view(-1, 1)
42 | self.time_disc = torch.linspace(0, self.kernel_size - 1, steps=int(self.kernel_size))
43 |
44 | def forward(self):
45 | p1 = (self.time_disc - self.b_) / (self.a_ - self.eps)
46 | filter = Laplace(p1).view(self.out_channels, 1, self.kernel_size) # (70,1,85)
47 | return filter
48 |
49 | class Laplace_fastv21(nn.Module):
50 |
51 | def __init__(self, out_channels, kernel_size, eps=0.3):
52 | super(Laplace_fastv21, self).__init__()
53 | self.out_channels = out_channels
54 | self.kernel_size = kernel_size
55 | self.eps = eps
56 | self.a_ = torch.linspace(0, self.out_channels, self.out_channels).view(-1, 1)
57 | self.b_ = torch.linspace(0, self.out_channels, self.out_channels).view(-1, 1)
58 | self.time_disc = torch.linspace(0, self.kernel_size-1, steps=int(self.kernel_size))
59 |
60 | def forward(self):
61 | p1 = (self.time_disc - self.b_) / (self.a_ + self.eps)
62 | filter = Laplace(p1).view(self.out_channels, 1, self.kernel_size) # (70,1,85)
63 | return filter
64 | # def Laplace1(p):
65 | # # m = 1000
66 | # # ep = 0.03
67 | # # # tal = 0.1
68 | # # f = 80
69 | # w = 2 * pi * 80
70 | # # A = 0.08
71 | # q = torch.tensor(1 - pow(0.03, 2))
72 | # #return (0.08 * torch.exp(((-0.03 / (torch.sqrt(q))) * (w * (p - 0.1)))) * (torch.sin(w * (p - 0.1))))
73 | # # y = 0.08 * torch.exp(((-0.03 / (torch.sqrt(q))) * (w * (p - 0.1))).sigmoid()) * (torch.sin(w * (p - 0.1))) # 99.82%
74 | # return 0.08 * torch.exp(((-0.03 / (torch.sqrt(q))) * (w * (p - 0.1)))) * (torch.sin(w * (p - 0.1)))
75 |
76 | class Laplace_fastv3(nn.Module):
77 |
78 | def __init__(self, out_channels, kernel_size):
79 | super(Laplace_fastv3, self).__init__()
80 | self.out_channels = out_channels
81 | self.kernel_size = kernel_size
82 | self.a_ = torch.linspace(1, out_channels, out_channels).view(-1, 1)
83 | self.b_ = torch.linspace(0, out_channels, out_channels).view(-1, 1)
84 | self.time_disc = torch.linspace(0, 1, steps=int(self.kernel_size))
85 |
86 | def forward(self):
87 | p1 = (self.time_disc - self.b_) / self.a_
88 | filter = Laplace(p1).view(self.out_channels, 1, self.kernel_size) # (70,1,85)
89 | return filter
90 |
91 |
92 | '''
93 | Morlet wavelet kernel initialization
94 | '''
95 |
96 |
97 | def Morlet(p, c):
98 |
99 | return c * torch.exp((-torch.pow(p, 2) / 2)) * torch.cos(5 * p)
100 |
101 |
102 |
103 | class Morlet_fast(nn.Module):
104 |
105 | def __init__(self, out_channels, kernel_size, eps=0.3):
106 | super(Morlet_fast, self).__init__()
107 | if kernel_size % 2 != 0:
108 | kernel_size = kernel_size - 1
109 | self.out_channels = out_channels
110 | self.kernel_size = kernel_size
111 | self.eps = eps
112 | self.a_ = torch.linspace(0, out_channels, out_channels).view(-1, 1)
113 | self.b_ = torch.linspace(0, out_channels, out_channels).view(-1, 1)
114 | self.time_disc_right = torch.linspace(0, (self.kernel_size / 2) - 1, steps=int((self.kernel_size / 2)))
115 | self.time_disc_left = torch.linspace(-(self.kernel_size / 2) + 1, -1, steps=int((self.kernel_size / 2)))
116 |
117 | def forward(self):
118 | p1 = (self.time_disc_right - self.b_) / (self.a_ + self.eps) #
119 | p2 = (self.time_disc_left - self.b_) / (self.a_ + self.eps) #
120 | C = (pow(pi, 0.25)) / torch.sqrt(self.a_ + 0.01)  # a point worth exploring: 1e-3  # D = C / self.a_.cuda()
121 | Morlet_right = Morlet(p1, C)
122 | Morlet_left = Morlet(p2, C)
123 | filter = torch.cat([Morlet_left, Morlet_right], dim=1).view(self.out_channels, 1, self.kernel_size)
124 | return filter
125 |
126 |
127 | '''
128 | Mexh小波卷积核
129 | '''
130 |
131 |
132 | def Mexh(p):
133 | # p = 0.04 * p  # map time into the interval [-5, 5]
134 | # y = (2 / pow(3, 0.5) * (pow(pi, -0.25))) * (1 - torch.pow(p, 2)) * torch.exp((-torch.pow(p, 2) / 2))
135 |
136 | return (2/pi)*((2 / pow(3, 0.5) * (pow(pi, -0.25))) * (1 - torch.pow(p, 2)) * torch.exp((-torch.pow(p, 2) / 2))).atan()
137 |
138 | class Mexh_fast(nn.Module):
139 |
140 | def __init__(self, out_channels, kernel_size, eps=0.3):
141 | super(Mexh_fast, self).__init__()
142 | if kernel_size % 2 != 0:
143 | kernel_size = kernel_size - 1
144 | self.out_channels = out_channels
145 | self.kernel_size = kernel_size
146 | self.eps = eps
147 | self.a_ = torch.linspace(0, out_channels, out_channels).view(-1, 1)
148 | self.b_ = torch.linspace(0, out_channels, out_channels).view(-1, 1)
149 | self.time_disc_right = torch.linspace(0, (self.kernel_size / 2) - 1, steps=int((self.kernel_size / 2)))
150 | self.time_disc_left = torch.linspace(-(self.kernel_size / 2) + 1, -1, steps=int((self.kernel_size / 2)))
151 |
152 | def forward(self):
153 | p1 = (self.time_disc_right - self.b_) / (self.a_ + self.eps) #
154 | p2 = (self.time_disc_left - self.b_) / (self.a_ + self.eps) #
155 | Mexh_right = Mexh(p1)
156 | Mexh_left = Mexh(p2)
157 | filter = torch.cat([Mexh_left, Mexh_right], dim=1).view(self.out_channels, 1, self.kernel_size) # 40x1x250
158 | return filter
159 |
160 |
161 | '''
162 | Gaussian wavelet kernel
163 | '''
164 |
165 |
166 | def Gaussian(p):
167 | # y = D * torch.exp(-torch.pow(p, 2))
168 | # F0 = (2./pi)**(1./4.) * torch.exp(-torch.pow(p, 2))
169 | # y = -2 / (3 ** (1 / 2)) * (-1 + 2 * p ** 2) * F0
170 | # y = (2./pi)**(1./4.) * torch.exp(-torch.pow(p, 2))
171 | # y = -2 / (3 ** (1 / 2)) * (-1 + 2 * p ** 2) * y
172 | # y = -((1 / (pow(2 * pi, 0.5))) * p * torch.exp((-torch.pow(p, 2)) / 2))
173 | return -((1 / (pow(2 * pi, 0.5))) * p * (torch.exp(((-torch.pow(p, 2)) / 2))))
174 |
175 |
176 |
177 |
178 | class Gaussian_fast(nn.Module):
179 |
180 | def __init__(self, out_channels, kernel_size, eps=0.3):
181 | super(Gaussian_fast, self).__init__()
182 | if kernel_size % 2 != 0:
183 | kernel_size = kernel_size - 1
184 | self.out_channels = out_channels
185 | self.kernel_size = kernel_size
186 | self.eps = eps
187 | self.a_ = torch.linspace(0, out_channels, out_channels).view(-1, 1)
188 | self.b_ = torch.linspace(0, out_channels, out_channels).view(-1, 1)
189 | self.time_disc_right = torch.linspace(0, (self.kernel_size / 2) - 1, steps=int((self.kernel_size / 2)))
190 | self.time_disc_left = torch.linspace(-(self.kernel_size / 2) + 1, -1, steps=int((self.kernel_size / 2)))
191 |
192 | def forward(self):
193 | p1 = (self.time_disc_right - self.b_) / (self.a_ + self.eps) #
194 | p2 = (self.time_disc_left - self.b_) / (self.a_ + self.eps) #
195 | Gaussian_right = Gaussian(p1)
196 | Gaussian_left = Gaussian(p2)
197 | filter = torch.cat([Gaussian_left, Gaussian_right], dim=1).view(self.out_channels, 1,
198 | self.kernel_size) # 40x1x250
199 | return filter
200 |
201 |
202 | def Shannon(p):
203 | # y = (torch.sin(2 * pi * (p - 0.5)) - torch.sin(pi * (p - 0.5))) / (pi * (p - 0.5))
204 | return (torch.sin(2 * pi * (p - 0.5)) - torch.sin(pi * (p - 0.5))) / (pi * (p - 0.5))
205 |
206 |
207 | class Shannon_fast(nn.Module):
208 |
209 | def __init__(self, out_channels, kernel_size, eps=0.3):
210 | super(Shannon_fast, self).__init__()
211 | if kernel_size % 2 != 0:
212 | kernel_size = kernel_size - 1
213 | self.out_channels = out_channels
214 | self.kernel_size = kernel_size
215 | self.eps = eps
216 | self.a_ = torch.linspace(0, out_channels, out_channels).view(-1, 1)
217 | self.b_ = torch.linspace(0, out_channels, out_channels).view(-1, 1)
218 | self.time_disc_right = torch.linspace(0, (self.kernel_size / 2) - 1, steps=int((self.kernel_size / 2)))
219 | self.time_disc_left = torch.linspace(-(self.kernel_size / 2) + 1, -1, steps=int((self.kernel_size / 2)))
220 | # self.time_disc_right = torch.linspace(0, 1, steps=int((self.kernel_size / 2)))
221 | # self.time_disc_left = torch.linspace(-1, 0, steps=int((self.kernel_size / 2)))
222 |
223 | def forward(self):
224 | p1 = (self.time_disc_right - self.b_) / (self.a_ + self.eps) #
225 | p2 = (self.time_disc_left - self.b_) / (self.a_ + self.eps) #
226 | Shannon_right = Shannon(p1)
227 | Shannon_left = Shannon(p2)
228 | filter = torch.cat([Shannon_left, Shannon_right], dim=1).view(self.out_channels, 1,
229 | self.kernel_size) # 40x1x250
230 | return filter
231 |
232 | def Sin(p):
233 | # y = (torch.sin(2 * pi * (p - 0.5)) - torch.sin(pi * (p - 0.5))) / (pi * (p - 0.5))
234 | return torch.sin(p) / p
235 |
236 |
237 | class Sin_fast(nn.Module):
238 |
239 | def __init__(self, out_channels, kernel_size, eps=0.3):
240 | super(Sin_fast, self).__init__()
241 | if kernel_size % 2 != 0:
242 | kernel_size = kernel_size - 1
243 | self.out_channels = out_channels
244 | self.kernel_size = kernel_size
245 | self.eps = eps
246 | self.a_ = torch.linspace(0, out_channels, out_channels).view(-1, 1)
247 | self.b_ = torch.linspace(0, out_channels, out_channels).view(-1, 1)
248 | self.time_disc_right = torch.linspace(0, (self.kernel_size / 2) - 1, steps=int((self.kernel_size / 2)))
249 | self.time_disc_left = torch.linspace(-(self.kernel_size / 2) + 1, -1, steps=int((self.kernel_size / 2)))
250 | # self.time_disc_right = torch.linspace(0, 1, steps=int((self.kernel_size / 2)))
251 | # self.time_disc_left = torch.linspace(-1, 0, steps=int((self.kernel_size / 2)))
252 |
253 | def forward(self):
254 | p1 = (self.time_disc_right - self.b_) / (self.a_ + self.eps) #
255 | p2 = (self.time_disc_left - self.b_) / (self.a_ + self.eps) #
256 | Shannon_right = Sin(p1)
257 | Shannon_left = Sin(p2)
258 | filter = torch.cat([Shannon_left, Shannon_right], dim=1).view(self.out_channels, 1,
259 | self.kernel_size) # 40x1x250
260 | return filter
261 |
262 |
263 | if __name__ == '__main__':
264 | import numpy as np
265 | import matplotlib.pyplot as plt
266 |
267 | # plot one kernel from the Laplace initialization bank (runs on CPU)
268 | weight = Laplace_fastv2(out_channels=64, kernel_size=250).forward()
269 | weight_t = weight.detach().numpy()
270 | y = weight_t[20, :, :].squeeze()
271 | x = np.linspace(0, 250, 250)
272 | plt.plot(x, y)
273 | plt.savefig('wahaha.tiff', format='tiff', dpi=600)
274 | plt.show()
275 |
--------------------------------------------------------------------------------
/thresholds.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class Shrinkagev2(nn.Module):
5 | def __init__(self, channel, gap_size):
6 | super(Shrinkagev2, self).__init__()
7 | self.gap = nn.AdaptiveAvgPool1d(gap_size)
8 | self.fc = nn.Sequential(
9 | nn.Linear(channel, channel),
10 | nn.BatchNorm1d(channel),
11 | nn.ReLU(inplace=True),
12 | nn.Linear(channel, channel),
13 | nn.Sigmoid(),
14 | )
15 |
16 | def forward(self, x):
17 | x_raw = x
18 | # x = torch.abs(x)
19 | x_abs = x.abs()
20 | x = self.gap(x)
21 | x = torch.flatten(x, 1)
22 | # average = torch.mean(x, dim=1, keepdim=True) #CS
23 | average = x #CW
24 | x = self.fc(x)
25 | x = torch.mul(average, x).unsqueeze(2)
26 | # soft thresholding
27 | x = x_abs - x
28 | # zeros = sub - sub
29 | # n_sub = torch.max(sub, torch.zeros_like(sub))
30 | x = torch.mul(torch.sign(x_raw), torch.max(x, torch.zeros_like(x)))
31 | return x
32 |
33 | class Shrinkage(nn.Module):
34 | def __init__(self, channel, gap_size):
35 | super(Shrinkage, self).__init__()
36 | self.gap = nn.AdaptiveAvgPool1d(gap_size)
37 | self.fc = nn.Sequential(
38 | nn.Linear(channel, channel),
39 | nn.BatchNorm1d(channel),
40 | nn.ReLU(inplace=True),
41 | nn.Linear(channel, channel),
42 | nn.Sigmoid(),
43 | )
44 |
45 | def forward(self, x):
46 | x_raw = x
47 | x = torch.abs(x)
48 | x_abs = x
49 | x = self.gap(x)
50 | x = torch.flatten(x, 1)
51 | # average = torch.mean(x, dim=1, keepdim=True) #CS
52 | average = x #CW
53 | x = self.fc(x)
54 | x = torch.mul(average, x)
55 | x = x.unsqueeze(2)
56 | # soft thresholding
57 | sub = x_abs - x
58 | zeros = sub - sub
59 | n_sub = torch.max(sub, zeros)
60 | x = torch.mul(torch.sign(x_raw), n_sub)
61 | return x
62 |
63 | class Shrinkagev3(nn.Module):
64 | def __init__(self, gap_size, inp, oup, reduction=4):
65 | super(Shrinkagev3, self).__init__()
66 | mip = int(max(8, inp // reduction))
67 | self.gap = nn.AdaptiveAvgPool1d(gap_size)
68 | self.fc = nn.Sequential(
69 | nn.Conv1d(inp, mip, kernel_size=1, stride=1, padding=0),
70 | nn.BatchNorm1d(mip),
71 | nn.ReLU(inplace=True),
72 | nn.Conv1d(mip, oup, kernel_size=1, stride=1, padding=0),
73 | nn.Sigmoid()
74 | )
75 |
76 | # def forward(self, x):
77 | # x_raw = x
78 | # x = torch.abs(x)
79 | # x_abs = x
80 | # x = self.gap(x)
81 | # # x = torch.flatten(x, 1)
82 | # # average = torch.mean(x, dim=1, keepdim=True) #CS
83 | # average = x #CW
84 | # x = self.fc(x)
85 | # x = torch.mul(average, x)
86 | # # x = x.unsqueeze(2)
87 | # # soft thresholding
88 | # sub = x_abs - x
89 | # zeros = sub - sub
90 | # n_sub = torch.max(sub, zeros)
91 | # x = torch.mul(torch.sign(x_raw), n_sub)
92 | # return x
93 | def forward(self, x):
94 | x_raw = x
95 | # x = torch.abs(x)
96 | x_abs = x.abs()
97 | x = self.gap(x)
98 | # average = torch.mean(x, dim=1, keepdim=True) #CS
99 | average = x #CW
100 | x = self.fc(x)
101 | x = torch.mul(average, x)
102 | # soft thresholding
103 | x = x_abs - x
104 | # zeros = sub - sub
105 | # n_sub = torch.max(sub, torch.zeros_like(sub))
106 | x = torch.mul(torch.sign(x_raw), torch.max(x, torch.zeros_like(x)))
107 | return x
108 |
109 | class Shrinkagev3p(nn.Module):
110 | def __init__(self, gap_size, inp, oup, reduction=4):
111 | super(Shrinkagev3p, self).__init__()
112 | mip = int(max(8, inp // reduction))
113 | self.a = nn.Parameter(torch.tensor([0.4]))
114 | self.gap = nn.AdaptiveAvgPool1d(gap_size)
115 | # self.gap = nn.AdaptiveMaxPool1d(gap_size)
116 | self.fc = nn.Sequential(
117 | nn.Conv1d(inp, mip, kernel_size=1, stride=1),
118 | nn.BatchNorm1d(mip),
119 | nn.ReLU(inplace=True),
120 | nn.Conv1d(mip, oup, kernel_size=1, stride=1),
121 | nn.Hardsigmoid()
122 | # nn.Sigmoid()
123 | )
124 |
125 | # def forward(self, x):
126 | # x_raw = x
127 | # x = torch.abs(x)
128 | # x_abs = x
129 | # x = self.gap(x)
130 | # # x = torch.flatten(x, 1)
131 | # # average = torch.mean(x, dim=1, keepdim=True) #CS
132 | # average = x #CW
133 | # x = self.fc(x)
134 | # x = torch.mul(average, x)
135 | # # x = x.unsqueeze(2)
136 | # # soft thresholding
137 | # sub = x_abs - x
138 | # zeros = sub - sub
139 | # n_sub = torch.max(sub, zeros)
140 | # x = torch.mul(torch.sign(x_raw), n_sub)
141 | # return x
142 | def forward(self, x):
143 | x_raw = x
144 | # x = torch.abs(x)
145 | x_abs = x.abs()
146 | x = self.gap(x)
147 | # average = torch.mean(x, dim=1, keepdim=True) #CS
148 | average = x #CW
149 | x = self.fc(x)
150 | x = torch.mul(average, x)
151 | # soft thresholding
152 | # x = x_abs - x
153 | x = x_abs - self.a * x
154 | # zeros = sub - sub
155 | # n_sub = torch.max(sub, torch.zeros_like(sub))
156 | x = torch.mul(torch.sign(x_raw), torch.max(x, torch.zeros_like(x)))
157 | return x
158 |
159 | class Shrinkagev3p1(nn.Module):
160 | def __init__(self, gap_size, inp, oup, reduction=32):
161 | super(Shrinkagev3p1, self).__init__()
162 | mip = int(max(8, inp // reduction))
163 | #self.a = nn.Parameter(torch.tensor([0.5]))
164 | self.a = nn.Parameter(0.5 * torch.ones(1, 10, 1))
165 | self.gap = nn.AdaptiveAvgPool1d(gap_size)
166 | self.fc = nn.Sequential(
167 | nn.Conv1d(inp, mip, kernel_size=1, stride=1, padding=0),
168 | nn.BatchNorm1d(mip),
169 | nn.ReLU(inplace=True),
170 | nn.Conv1d(mip, oup, kernel_size=1, stride=1, padding=0),
171 | nn.Hardsigmoid()
172 | # nn.Sigmoid()
173 | )
174 |
175 | # def forward(self, x):
176 | # x_raw = x
177 | # x = torch.abs(x)
178 | # x_abs = x
179 | # x = self.gap(x)
180 | # # x = torch.flatten(x, 1)
181 | # # average = torch.mean(x, dim=1, keepdim=True) #CS
182 | # average = x #CW
183 | # x = self.fc(x)
184 | # x = torch.mul(average, x)
185 | # # x = x.unsqueeze(2)
186 | # # soft thresholding
187 | # sub = x_abs - x
188 | # zeros = sub - sub
189 | # n_sub = torch.max(sub, zeros)
190 | # x = torch.mul(torch.sign(x_raw), n_sub)
191 | # return x
192 | def forward(self, x):
193 | x_raw = x
194 | # x = torch.abs(x)
195 | x_abs = x.abs()
196 | x = self.gap(x)
197 | # average = torch.mean(x, dim=1, keepdim=True) #CS
198 | average = x #CW
199 | x = self.fc(x)
200 | x = torch.mul(average, x)
201 | # soft thresholding
202 | x = x_abs - self.a * x
203 | # zeros = sub - sub
204 | # n_sub = torch.max(sub, torch.zeros_like(sub))
205 | x = torch.mul(torch.sign(x_raw), torch.max(x, torch.zeros_like(x)))
206 | return x
207 | class Shrinkagev3p11(nn.Module):
208 | def __init__(self, gap_size, channel):
209 | super(Shrinkagev3p11, self).__init__()
210 | #self.a = nn.Parameter(torch.tensor([0.5]))
211 | self.a = nn.Parameter(torch.tensor([0.48]))
212 | self.gap = nn.AdaptiveAvgPool1d(gap_size)
213 | self.fc = nn.Sequential(
214 | nn.Linear(channel, channel),
215 | nn.BatchNorm1d(channel),
216 | nn.ReLU(inplace=True),
217 | nn.Linear(channel, channel),
218 | nn.Hardsigmoid()
219 | )
220 |
221 | def forward(self, x):
222 | x_raw = x
223 | # x = torch.abs(x)
224 | x_abs = x.abs()
225 | x = self.gap(x)
226 | x = torch.flatten(x, 1)
227 | # average = torch.mean(x, dim=1, keepdim=True) #CS
228 | average = x #CW
229 | x = self.fc(x)
230 | x = torch.mul(average, x).unsqueeze(2)
231 | # soft thresholding
232 | x = x_abs - self.a * x
233 | # zeros = sub - sub
234 | # n_sub = torch.max(sub, torch.zeros_like(sub))
235 | x = torch.mul(torch.sign(x_raw), torch.max(x, torch.zeros_like(x)))
236 | return x
237 | class Shrinkagev3pp(nn.Module):  # soft-threshold denoising from the conference paper
238 | def __init__(self, gap_size, inp, oup, reduction=4):
239 | super(Shrinkagev3pp, self).__init__()
240 | mip = int(max(8, inp // reduction))
241 | self.a = nn.Parameter(torch.tensor([0.48]))
242 | # self.a = nn.Parameter(torch.clamp(torch.tensor([0.4]), min=0, max=1))
243 | self.gap = nn.AdaptiveAvgPool1d(gap_size)
244 | # self.gap = nn.AdaptiveMaxPool1d(gap_size)
245 | self.fc = nn.Sequential(
246 | nn.Conv1d(inp, mip, kernel_size=1, stride=1),
247 | nn.BatchNorm1d(mip),
248 | nn.ReLU(inplace=True),
249 | nn.Conv1d(mip, oup, kernel_size=1, stride=1),
250 | nn.Hardsigmoid()
251 | # nn.Sigmoid()
252 | )
253 |
254 | # def forward(self, x):
255 | # x_raw = x
256 | # x = torch.abs(x)
257 | # x_abs = x
258 | # x = self.gap(x)
259 | # # x = torch.flatten(x, 1)
260 | # # average = torch.mean(x, dim=1, keepdim=True) #CS
261 | # average = x #CW
262 | # x = self.fc(x)
263 | # x = torch.mul(average, x)
264 | # # x = x.unsqueeze(2)
265 | # # soft thresholding
266 | # sub = x_abs - x
267 | # zeros = sub - sub
268 | # n_sub = torch.max(sub, zeros)
269 | # x = torch.mul(torch.sign(x_raw), n_sub)
270 | # return x
271 | def forward(self, x):
272 | x_raw = x
273 | # x = torch.abs(x)
274 | x_abs = x.abs()
275 | x = self.gap(x)
276 |
277 | # average = torch.mean(x, dim=1, keepdim=True) #CS
278 | average = x #CW
279 | x = self.fc(x)
280 | x = torch.mul(average, x)
281 | # soft thresholding
282 | x = x_abs - x
283 | # x = x_abs - self.a * x
284 | # sub = torch.max(x1, torch.zeros_like(x1))
285 | # sub = self.a * sub
286 | a = torch.clamp(self.a, min=0, max=1)
287 | x = torch.mul(torch.sign(x_raw), a*(torch.max(x, torch.zeros_like(x))))
288 | return x
289 | class Shrinkagev3pp1(nn.Module):  # soft-threshold denoising from the conference paper
290 | def __init__(self, gap_size, channel):
291 | super(Shrinkagev3pp1, self).__init__()
292 | self.a = nn.Parameter(torch.tensor([0.48]))
293 | self.gap = nn.AdaptiveAvgPool1d(gap_size)
294 | self.fc = nn.Sequential(
295 | nn.Linear(channel, channel),
296 | nn.BatchNorm1d(channel),
297 | nn.ReLU(inplace=True),
298 | nn.Linear(channel, channel),
299 | nn.Sigmoid()
300 | )
301 |
302 | def forward(self, x):
303 | x_raw = x
304 | x_abs = x.abs()
305 | x = self.gap(x)
306 | x = torch.flatten(x, 1)
307 | # average = torch.mean(x, dim=1, keepdim=True) #CS
308 | average = x #CW
309 | x = self.fc(x)
310 | x = torch.mul(average, x).unsqueeze(2)
311 | # soft thresholding
312 | x = x_abs - x
313 | a = torch.clamp(self.a, min=0, max=1)
314 | x = torch.mul(torch.sign(x_raw), a*(torch.max(x, torch.zeros_like(x))))
315 | return x
316 | class Shrinkagev3ppp(nn.Module):  # semi-soft-threshold denoising
317 | def __init__(self, gap_size, inp, oup, reduction=4):
318 | super(Shrinkagev3ppp, self).__init__()
319 | mip = int(max(8, inp // reduction))
320 | self.a = nn.Parameter(torch.tensor([0.48]))
321 | self.gap = nn.AdaptiveAvgPool1d(gap_size)
322 | # self.gap = nn.AdaptiveMaxPool1d(gap_size)
323 | self.fc = nn.Sequential(
324 | nn.Conv1d(inp, mip, kernel_size=1, stride=1),
325 | nn.BatchNorm1d(mip),
326 | nn.ReLU(inplace=True),
327 | nn.Conv1d(mip, oup, kernel_size=1, stride=1),
328 | # nn.Hardsigmoid()
329 | nn.Sigmoid()
330 | )
331 |
332 | def forward(self, x):
333 | x_raw = x
334 | x_abs = x.abs()
335 | x = self.gap(x)
336 | average = x
337 | x = self.fc(x)
338 | x = torch.mul(average, x)
339 | # semi-soft thresholding
340 | # x1 = x_abs - x
341 | # sub = torch.max(x1, torch.zeros_like(x1))
342 | sub = torch.max(x_abs - x, torch.zeros_like(x_abs - x))
343 | mask = sub.clone()
344 | mask[mask > 0] = 1
345 | # a = torch.clamp(self.a, min=0, max=1)
346 | x = sub + (1-self.a) * x
347 | x = torch.mul(x, mask)
348 | x = torch.mul(torch.sign(x_raw), x)
349 | return x
350 |
351 | class Shrinkagev3ppp1(nn.Module):
352 | def __init__(self, gap_size, inp, oup, reduction=4):
353 | super(Shrinkagev3ppp1, self).__init__()
354 | mip = int(max(8, inp // reduction))
355 | self.a = nn.Parameter(torch.tensor([0.4]))
356 | self.gap = nn.AdaptiveAvgPool1d(gap_size)
357 | # self.gap = nn.AdaptiveMaxPool1d(gap_size)
358 | self.fc = nn.Sequential(
359 | nn.Conv1d(inp, mip, kernel_size=1, stride=1),
360 | nn.BatchNorm1d(mip),
361 | nn.ReLU(inplace=True),
362 | nn.Conv1d(mip, oup, kernel_size=1, stride=1),
363 | nn.Hardsigmoid()
364 | # nn.Sigmoid()
365 | )
366 |
367 | def forward(self, x):
368 | x_raw = x
369 | x_abs = x.abs()
370 | x = self.gap(x)
371 | average = x
372 | x = self.fc(x)
373 | x = torch.mul(average, x)
374 | # soft thresholding
375 | sub = torch.max(x_abs - x, torch.zeros_like(x_abs - x))
376 | x = sub + (1-self.a) * x
377 | x = x.index_put((sub == 0).nonzero(as_tuple=True), torch.tensor(0.))
378 | x = torch.mul(torch.sign(x_raw), x)
379 | return x
380 |
381 | class Shrinkagev3ppp2(nn.Module):  # semi-soft-threshold denoising (the balanced dynamic adaptive threshold used by EWSNet)
382 | def __init__(self, gap_size, channel):
383 | super(Shrinkagev3ppp2, self).__init__()
384 | self.a = nn.Parameter(torch.tensor([0.48]))
385 | self.gap = nn.AdaptiveAvgPool1d(gap_size)
386 | self.fc = nn.Sequential(
387 | nn.Linear(channel, channel),
388 | nn.BatchNorm1d(channel),
389 | nn.ReLU(inplace=True),
390 | nn.Linear(channel, channel),
391 | nn.Sigmoid(),
392 | )
393 |
394 | def forward(self, x):
395 |
396 | x_raw = x
397 | x_abs = x.abs()
398 | x = self.gap(x)
399 | x = torch.flatten(x, 1)
400 | average = x
401 | # average = torch.mean(x, dim=1, keepdim=True)
402 | x = self.fc(x)
403 | x = torch.mul(average, x).unsqueeze(2)
404 | # semi-soft (balanced) thresholding
405 | sub = torch.max(x_abs - x, torch.zeros_like(x_abs - x))
406 | mask = sub.clone()
407 | mask[mask > 0] = 1
408 | # a = torch.clamp(self.a, min=0, max=1)
409 | x = sub + (1 - self.a) * x
410 | x = torch.mul(torch.sign(x_raw), torch.mul(x, mask))
411 | return x
412 |
413 | # class Shrinkagev3ppp3(nn.Module):  # semi-soft-threshold denoising
414 | # def __init__(self, gap_size, channel):
415 | # super(Shrinkagev3ppp3, self).__init__()
416 | # self.a = nn.Parameter(0.48 * torch.ones(16, channel, 1))
417 | # self.gap = nn.AdaptiveAvgPool1d(gap_size)
418 | # self.fc = nn.Sequential(
419 | # nn.Linear(channel, channel),
420 | # nn.BatchNorm1d(channel),
421 | # nn.ReLU(inplace=True),
422 | # nn.Linear(channel, channel),
423 | # nn.Sigmoid(),
424 | # )
425 | #
426 | # def forward(self, x):
427 | #
428 | # x_raw = x
429 | # x_abs = x.abs()
430 | # x = self.gap(x)
431 | # x = torch.flatten(x, 1)
432 | # average = x
433 | # # average = torch.mean(x, dim=1, keepdim=True)
434 | # x = self.fc(x)
435 | # x = torch.mul(average, x).unsqueeze(2)
436 | # # soft thresholding
437 | # sub = torch.max(x_abs - x, torch.zeros_like(x_abs - x))
438 | # mask = sub.clone()
439 | # mask[mask > 0] = 1
440 | # # a = torch.clamp(self.a, min=0, max=1)
441 | # x = sub + (1 - self.a) * x
442 | # x = torch.mul(torch.sign(x_raw), torch.mul(x, mask))
443 | # return x
444 |
445 | class HShrinkage(nn.Module):  # hard-threshold denoising
446 | def __init__(self, gap_size, channel):
447 | super(HShrinkage, self).__init__()
448 | self.gap = nn.AdaptiveAvgPool1d(gap_size)
449 | self.fc = nn.Sequential(
450 | nn.Linear(channel, channel),
451 | nn.BatchNorm1d(channel),
452 | nn.ReLU(inplace=True),
453 | nn.Linear(channel, channel),
454 | nn.Sigmoid(),
455 | )
456 |
457 | def forward(self, x):
458 |
459 | x_raw = x
460 | x_abs = x.abs()
461 | x = self.gap(x)
462 | x = torch.flatten(x, 1)
463 | average = x
464 | # average = torch.mean(x, dim=1, keepdim=True)
465 | x = self.fc(x)
466 | x = torch.mul(average, x).unsqueeze(2)
467 | # hard thresholding: zero out entries below the threshold, keep the rest unchanged
468 | sub = torch.max(x_abs - x, torch.zeros_like(x_abs - x))
469 | mask = sub.clone()
470 | mask[mask > 0] = 1
471 | x = torch.mul(x_raw, mask)
472 | return x
473 | class Shrinkagev3ppp31(nn.Module):  # semi-soft-threshold denoising
474 | def __init__(self, gap_size, channel):
475 | super(Shrinkagev3ppp31, self).__init__()
476 | self.a = nn.Parameter(torch.tensor([0.48]))
477 | self.gap = nn.AdaptiveAvgPool1d(gap_size)
478 | self.fc = nn.Sequential(
479 | nn.Linear(channel, channel),
480 | nn.BatchNorm1d(channel),
481 | nn.ReLU(inplace=True),
482 | nn.Linear(channel, channel),
483 | nn.Sigmoid(),
484 | )
485 |
486 | def forward(self, x):
487 |
488 | x_raw = x
489 | x_abs = x.abs()
490 | x = self.gap(x)
491 | x = torch.flatten(x, 1)
492 | average = x
493 | # average = torch.mean(x, dim=1, keepdim=True)
494 | x = self.fc(x)
495 | x = torch.mul(average, x).unsqueeze(2)
496 | # semi-soft thresholding
497 | sub = torch.max(x_abs - x, torch.zeros_like(x_abs - x))
498 | # mask = sub.clone()
499 | # mask[mask > 0] = 1
500 | # a = torch.clamp(self.a, min=0, max=1)
501 | x = torch.mul(sub, (1 - self.a))
502 | x = torch.add(torch.mul(torch.sign(x_raw), x), torch.mul(self.a, x_raw))
503 | return x
504 | if __name__ == '__main__':
505 | input = torch.randn(2, 3, 1024).cuda()
506 | model = Shrinkagev3ppp(1, 3, 3).cuda()
507 | for param in model.parameters():
508 | print(type(param.data), param.size())
509 | output = model(input)
510 | print(output.size())
511 |
--------------------------------------------------------------------------------