├── LICENSE
├── README.md
├── offline
│   ├── dpcrn.py
│   ├── dpcrn_onnx.py
│   ├── dpcrn_trt.py
│   └── models
│       ├── dpcrn.engine
│       ├── dpcrn.onnx
│       ├── dpcrn_simple.onnx
│       └── trtexec.exe
├── online
│   ├── dpcrn.py
│   ├── dpcrn_stream.py
│   ├── dpcrn_stream_onnx.py
│   ├── dpcrn_stream_trt.py
│   ├── models
│   │   ├── dpcrn.engine
│   │   ├── dpcrn.onnx
│   │   ├── dpcrn_simple.onnx
│   │   └── trtexec.exe
│   └── modules
│       ├── __init__.py
│       ├── convert.py
│       └── convolution.py
├── readme
│   ├── README_zh.md
│   ├── TRTSETUP.md
│   └── TRTSETUP_zh.md
└── requirements.txt
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Rong Xiaobin
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Deploying Deep Learning Speech Enhancement Models with TensorRT
2 | The Chinese version of this document is available at: [使用 TensorRT 部署深度学习语音增强模型](./readme/README_zh.md)
3 |
4 | To install TensorRT, refer to: [TensorRT Installation Tutorial](./readme/TRTSETUP.md)
5 |
6 | The deployment of speech enhancement models falls into two categories: **offline inference** and **online inference**. Offline inference runs the model on data prepared in advance, usually a batch of samples or a long audio signal. It has no real-time requirement, so it can use throughput-oriented inference methods and resource allocation strategies.
7 |
8 | Online inference, on the other hand, runs the model on speech data as it is generated, such as audio continuously captured from a microphone. It requires low latency and sufficient throughput to keep up in real time.
9 |
10 | ## Deployment of Offline Models
11 | ### 1. Conversion to ONNX Model
12 | For offline models, exporting to ONNX is straightforward; the only design decision is the time dimension of the input shape. Although `torch.onnx.export` supports dynamic dimensions, they are rarely needed in practice, so we fix the time dimension to 563 frames, corresponding to 9 seconds of audio at a 16 kHz sampling rate with a hop size of 256 (`9 * 16000 // 256 + 1 = 563`). During offline processing, audio shorter than 9 seconds is zero-padded, and longer audio is split into 9-second segments that are processed one after another.
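The fixed time dimension follows from the framing arithmetic in the export scripts (`frame_num = time_len * 16000 // 256 + 1`). The pure-Python sketch below reproduces that frame count and shows one way to cut a long clip into zero-padded 9-second segments before the STFT. `split_into_segments` is a hypothetical helper, not a function from this repository.

```python
# Sketch: frame count of a 9 s input and zero-padded 9 s segmentation.
# Assumes fs = 16000 Hz and hop = 256, as in the export scripts.
FS = 16000
HOP = 256
SEGMENT_SECONDS = 9


def num_frames(num_samples, hop=HOP):
    """Frame count used by the scripts: num_samples // hop + 1."""
    return num_samples // hop + 1


def split_into_segments(samples, fs=FS, seconds=SEGMENT_SECONDS):
    """Split raw samples into equal-length segments, zero-padding the last one."""
    seg_len = fs * seconds
    segments = []
    for start in range(0, max(len(samples), 1), seg_len):
        seg = list(samples[start:start + seg_len])
        seg.extend([0.0] * (seg_len - len(seg)))  # pad a short (or final) segment
        segments.append(seg)
    return segments


if __name__ == "__main__":
    print(num_frames(FS * SEGMENT_SECONDS))  # 563, the fixed time dimension
    segs = split_into_segments([0.1] * (FS * 20))  # a 20 s dummy clip
    print(len(segs), len(segs[0]))  # 3 segments of 144000 samples
```

Each segment can then be transformed to a `(1, 257, 563, 2)` spectrogram and passed to the fixed-shape model; the enhanced outputs are concatenated and trimmed back to the original length.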
13 |
14 | `offline\dpcrn_onnx.py` exports the model to ONNX (together with an onnxsim-simplified copy), checks its numerical agreement with the PyTorch model, and measures the inference speed with ONNXRuntime.
15 |
16 | ### 2. Conversion to Engine Model
17 | We use the conversion tool `trtexec.exe` provided by TensorRT to convert the model from ONNX to the Engine format supported by TensorRT. This tool is located in the `bin` directory of the TensorRT installation package, and the usage is as follows:
18 | ```
19 | trtexec.exe --onnx=[onnx_path] --saveEngine=[save_path]
20 | ```
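The `onnx2engine` helper in `offline\dpcrn_trt.py` runs this command through `os.system`, which does not stop the script if the conversion fails. A possible alternative, sketched below, calls `trtexec` through `subprocess.run` with an argument list (so paths containing spaces need no shell quoting) and `check=True` (so a failed build raises `CalledProcessError`). Only the `--onnx` and `--saveEngine` flags shown above are used; `build_trtexec_cmd` is a hypothetical name.

```python
import subprocess


def build_trtexec_cmd(trtexec_path, onnx_path, save_path):
    """Argument list for trtexec; each item is passed verbatim, without a shell."""
    return [str(trtexec_path), f"--onnx={onnx_path}", f"--saveEngine={save_path}"]


def onnx2engine(trtexec_path, onnx_path, save_path):
    """Build a TensorRT engine; raises CalledProcessError if trtexec fails."""
    subprocess.run(build_trtexec_cmd(trtexec_path, onnx_path, save_path), check=True)
```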
21 |
22 | `offline\dpcrn_trt.py` converts the ONNX model to an Engine, runs inference on it, and measures the inference speed. The results are shown in the following table.
23 |
24 | | **Model Format** | **Inference Framework** | **Inference Platform** | **Average Inference Time (ms / s of audio)** | **Maximum Inference Time (ms / s of audio)** | **Minimum Inference Time (ms / s of audio)** |
25 | |:----------------:|:----------------------:|:---------------------:|:-------------------------------:|:-------------------------------:|:-------------------------------:|
26 | | ONNX | ONNXRuntime | CPU | 8.6 | 21.0 | 7.6 |
27 | | Engine | TensorRT | CUDA | 2.2 | 5.1 | 1.9 |
28 |
29 | The CPU used for inference is a 13th Gen Intel(R) Core(TM) i9-13900HX @ 2.20 GHz, and the GPU is an NVIDIA GeForce RTX 4080 Laptop GPU. Inference was repeated 1000 times to obtain the average, maximum, and minimum times. Here, inference time is measured as processing time per second of audio, i.e., the time to process one 9-second input divided by 9. On average, TensorRT is nearly 4 times faster than ONNXRuntime (2.2 ms vs. 8.6 ms).
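The timing loop in the scripts can be factored into a small helper. This pure-Python sketch (the name `benchmark` is hypothetical) times any zero-argument callable and reports milliseconds of processing per second of audio, the metric used in the table. Dividing that value by 1000 gives the real-time factor; for example, 8.6 ms per second of audio corresponds to an RTF of 0.0086.

```python
import statistics
import time


def benchmark(fn, audio_seconds, n_runs=1000):
    """Call fn() n_runs times; return mean/max/min in ms per second of audio."""
    per_second_ms = []
    for _ in range(n_runs):
        tic = time.perf_counter()
        fn()
        toc = time.perf_counter()
        per_second_ms.append(1e3 * (toc - tic) / audio_seconds)
    return {
        "mean": statistics.fmean(per_second_ms),
        "max": max(per_second_ms),
        "min": min(per_second_ms),
    }
```

With ONNXRuntime this could be called as `benchmark(lambda: session.run([], {'mix': inputs}), audio_seconds=9)`.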
30 |
31 | ## Deployment of Online Models
32 | In speech enhancement, online inference covers more application scenarios and imposes stricter real-time requirements on the model, which makes its deployment more complex. Here, we adopt **streaming inference**, processing the real-time data stream frame by frame. Streaming inference needs appropriate data buffering, data stream management, and an inference pipeline design that together keep the data continuous and the inference stable.
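A core piece of that buffering is collecting audio chunks of arbitrary size and releasing analysis frames at a fixed hop. The pure-Python sketch below assumes a 512-sample frame (inferred from the 257 frequency bins) and a 256-sample hop as in the scripts; `FrameBuffer` is a hypothetical class, not part of this repository, and windowing and the STFT itself are omitted.

```python
class FrameBuffer:
    """Accumulates incoming samples and emits overlapping frames every `hop` samples."""

    def __init__(self, n_fft=512, hop=256):
        self.n_fft = n_fft
        self.hop = hop
        # Start with n_fft - hop zeros so the first frame is ready after one hop.
        self.buf = [0.0] * (n_fft - hop)

    def push(self, chunk):
        """Append a chunk of any length; return the list of complete frames."""
        self.buf.extend(chunk)
        frames = []
        while len(self.buf) >= self.n_fft:
            frames.append(self.buf[:self.n_fft])
            del self.buf[:self.hop]  # keep the overlap for the next frame
        return frames
```

Because frames depend only on the total number of samples received, a microphone callback delivering irregular chunk sizes yields exactly the same frames as one delivering exactly `hop` samples at a time.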
33 |
34 | ### 1. Conversion to Streaming Model
35 | RNNs naturally support streaming inference without additional conversion: the inter-frame LSTM simply carries its hidden and cell states from one frame to the next. Convolutional layers, by contrast, are the main components of the network that must be converted for streaming. In the `online\modules` directory, `convolution.py` defines two streaming operators, a streaming convolution and a streaming transposed convolution, and `convert.py` provides a method that copies the parameter dictionary of the original model into the streaming model.
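The idea behind streaming convolution can be shown with one channel in one dimension: an offline causal convolution along time, zero-padded with `kt - 1` frames at the start (like the `ZeroPad2d` in the offline encoder), produces exactly the same output as a streaming version that caches the last `kt - 1` input frames. This pure-Python sketch is illustrative only; `StreamCausalConv` is a hypothetical name and not the repository's `StreamConv2d`, which operates on 4-D tensors.

```python
def causal_conv_offline(x, w):
    """Offline causal 1-D convolution: left-pad with len(w) - 1 zeros."""
    k = len(w)
    padded = [0.0] * (k - 1) + list(x)
    return [sum(w[j] * padded[t + j] for j in range(k)) for t in range(len(x))]


class StreamCausalConv:
    """Same convolution, one frame per call, with a cache of the last k - 1 inputs."""

    def __init__(self, w):
        self.w = list(w)
        self.cache = [0.0] * (len(w) - 1)  # plays the role of the zero padding

    def step(self, frame):
        window = self.cache + [frame]
        y = sum(wj * xj for wj, xj in zip(self.w, window))
        self.cache = window[1:]  # drop the oldest frame
        return y
```

Feeding `x` frame by frame through `StreamCausalConv(w).step` reproduces `causal_conv_offline(x, w)`, which is why the converted model can reuse the offline weights unchanged.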
36 |
37 | `online\dpcrn_stream.py` converts the offline model into a streaming model and runs streaming inference. Note that the time dimension of the streaming model's input tensor is always 1, i.e., one frame per call.
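Before the first frame arrives, every cache has to be created with the right shape. The sketch below derives the encoder cache shapes from the default `StreamDPCRN` hyperparameters with the standard convolution output-size formula `(F + 2p - k) // s + 1`, and lists the `(1, B*F, C)` hidden/cell caches of the inter-frame LSTMs. It assumes each `StreamConv2d` caches the last `kt - 1` input frames, which matches the `(B,C,1,F)` docstring for `kt = 2`; the function names are hypothetical. Decoder cache shapes depend on the `StreamConvTranspose2d` implementation, which is not shown here, so they are omitted.

```python
def conv_out_width(width, kernel, stride, pad):
    """Standard convolution output size along one axis."""
    return (width + 2 * pad - kernel) // stride + 1


def encoder_cache_shapes(freq_bins=257, in_channels=2,
                         filter_size=(32, 32, 32, 64, 128),
                         kernel_size=((2, 5), (2, 3), (2, 3), (2, 3), (2, 3)),
                         stride_size=((1, 2), (1, 2), (1, 2), (1, 1), (1, 1)),
                         batch=1):
    """Return (cache shapes, output widths) for each encoder layer."""
    shapes, widths = [], []
    c_in, f_in = in_channels, freq_bins
    for c_out, (kt, kf), (_, sf) in zip(filter_size, kernel_size, stride_size):
        shapes.append((batch, c_in, kt - 1, f_in))  # last kt - 1 input frames
        f_in = conv_out_width(f_in, kf, sf, (kf - 1) // 2)  # freq padding as in the model
        widths.append(f_in)
        c_in = c_out
    return shapes, widths


def rnn_cache_shapes(n_dprnn=2, width=33, num_units=128, batch=1):
    """One (h, c) pair per DPRNN block, each of shape (1, B*F, C)."""
    return [(1, batch * width, num_units)] * (2 * n_dprnn)
```

The final encoder width comes out as 33, matching the `width=33` default passed to the DPRNN blocks.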
38 |
39 | ### 2. Conversion to ONNX Model
40 | For streaming models, the time dimension needs no special treatment when converting to ONNX, since it is always 1. However, it is recommended to pass every input tensor, including each cache, to the `forward` function as a separate argument, instead of packing them into lists as done in `online\dpcrn_stream.py`.
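When every cache becomes its own argument, each one also becomes its own ONNX input and needs a unique name. With the default configuration that is 5 encoder caches, 4 RNN caches (h and c for each of the 2 DPRNN blocks) and 5 decoder caches. The pure-Python sketch below generates such names and regroups a flat argument tuple into the list-based call used by `StreamDPCRN.forward`, as an explicit-argument wrapper might do internally. The names `make_input_names`, `split_flat_inputs`, and the input names themselves are hypothetical.

```python
def make_input_names(n_en, n_rnn, n_de):
    """One unique ONNX input name per tensor: the frame plus every cache."""
    return (["mix"]
            + [f"en_cache_{i}" for i in range(n_en)]
            + [f"rnn_cache_{i}" for i in range(n_rnn)]
            + [f"de_cache_{i}" for i in range(n_de)])


def split_flat_inputs(flat, n_en, n_rnn, n_de):
    """Regroup (x, *en, *rnn, *de) into (x, en_caches, rnn_caches, de_caches)."""
    if len(flat) != 1 + n_en + n_rnn + n_de:
        raise ValueError(f"expected {1 + n_en + n_rnn + n_de} inputs, got {len(flat)}")
    x = flat[0]
    en = list(flat[1:1 + n_en])
    rnn = list(flat[1 + n_en:1 + n_en + n_rnn])
    de = list(flat[1 + n_en + n_rnn:])
    return x, en, rnn, de
```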
41 |
42 | `online\dpcrn_stream_onnx.py` exports the streaming model to ONNX, runs inference on it, and measures the inference speed with ONNXRuntime.
43 |
44 | ### 3. Conversion to Engine Model
45 | Similarly, we use the conversion tool `trtexec.exe` provided by TensorRT for model conversion.
46 |
47 | `online\dpcrn_stream_trt.py` converts the streaming ONNX model to an Engine, runs inference on it, and measures the inference speed. The results are shown in the following table.
48 |
49 | | **Model Format** | **Inference Framework** | **Inference Platform** | **Average Inference Time (ms / frame)** | **Maximum Inference Time (ms / frame)** | **Minimum Inference Time (ms / frame)** |
50 | |:----------------:|:----------------------:|:---------------------:|:-------------------------------:|:-------------------------------:|:-------------------------------:|
51 | | ONNX | ONNXRuntime | CPU | 1.0 | 3.1 | 0.9 |
52 | | Engine | TensorRT | CUDA | 2.2 | 4.7 | 1.8 |
53 |
54 | The inference speed evaluation was repeated 1000 times. Here, inference time is measured as the processing time per frame. As the table shows, TensorRT is slower than ONNXRuntime for the streaming model. This is because each call processes only a single frame, so the computation per call is small and the data transfer between the CPU and the GPU adds overhead that outweighs the speed-up of GPU execution. Running TensorRT on CUDA pays off only when CPU inference itself becomes the bottleneck.
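Whether either backend is fast enough depends on the per-frame deadline rather than on which is faster. A new frame arrives every hop, so with the 16 kHz sampling rate and 256-sample hop used by the scripts (assumed here to be the same for the streaming model), each frame must be processed within 16 ms. The short sketch below turns the average times from the table above into real-time factors; the function names are hypothetical.

```python
FS = 16000  # sample rate used by the scripts
HOP = 256   # hop size; assumed to be the same for the streaming model


def frame_budget_ms(hop=HOP, fs=FS):
    """Time between two consecutive frames, i.e. the per-frame deadline."""
    return 1e3 * hop / fs


def real_time_factor(ms_per_frame, hop=HOP, fs=FS):
    """Fraction of the per-frame deadline spent on inference (< 1 means real time)."""
    return ms_per_frame / frame_budget_ms(hop, fs)


if __name__ == "__main__":
    print(f"budget: {frame_budget_ms():.1f} ms per frame")  # 16.0 ms
    for name, ms in [("ONNXRuntime (CPU)", 1.0), ("TensorRT (CUDA)", 2.2)]:
        print(f"{name}: {ms} ms/frame, RTF = {real_time_factor(ms):.4f}")
```

Both backends stay well below the 16 ms budget on average, so the choice between them is about headroom and resource usage rather than feasibility.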
55 |
56 | ## Acknowledgement
57 | The speech enhancement (SE) model used in this repository is [DPCRN](https://arxiv.org/abs/2107.05429), an excellent SE model with high performance and low latency, which ranked 2nd in the DNS3 Challenge. Its author is [Xiaohuai Le](https://github.com/Le-Xiaohuai-speech), my senior, who has taught me a lot, including the streaming conversion method used in this repository.
58 |
59 |
--------------------------------------------------------------------------------
/offline/dpcrn.py:
--------------------------------------------------------------------------------
1 | """
2 | A more elegant implementation of DPCRN.
3 | 1.74 GMac, 787.15 k
4 | """
5 | import torch
6 | import torch.nn as nn
7 |
8 |
9 | class DPRNN(nn.Module):
10 | def __init__(self, numUnits, width, channel, **kwargs):
11 | super(DPRNN, self).__init__(**kwargs)
12 | self.numUnits = numUnits
13 | self.width = width
14 | self.channel = channel
15 |
16 | self.intra_rnn = nn.LSTM(input_size = self.numUnits, hidden_size = self.numUnits//2, batch_first = True, bidirectional = True)
17 | self.intra_fc = nn.Linear(self.numUnits, self.numUnits)
18 | self.intra_ln = nn.LayerNorm((width, numUnits), eps=1e-8)
19 |
20 | self.inter_rnn = nn.LSTM(input_size = self.numUnits, hidden_size = self.numUnits, batch_first = True, bidirectional = False)
21 | self.inter_fc = nn.Linear(self.numUnits, self.numUnits)
22 | self.inter_ln = nn.LayerNorm((width, numUnits), eps=1e-8)
23 |
24 | def forward(self,x):
25 | # x: (B, C, T, F)
26 | ## Intra RNN
27 | x = x.permute(0, 2, 3, 1) # (B,T,F,C)
28 | intra_x = x.reshape(x.shape[0] * x.shape[1], x.shape[2], x.shape[3]) # (B*T,F,C)
29 | intra_x = self.intra_rnn(intra_x)[0] # (B*T,F,C)
30 | intra_x = self.intra_fc(intra_x) # (B*T,F,C)
31 | intra_x = intra_x.reshape(x.shape[0], -1, self.width, self.channel) # (B,T,F,C)
32 | intra_x = self.intra_ln(intra_x)
33 | intra_out = torch.add(x, intra_x)
34 |
35 | ## Inter RNN
36 | x = intra_out.permute(0,2,1,3) # (B,F,T,C)
37 | inter_x = x.reshape(x.shape[0] * x.shape[1], x.shape[2], x.shape[3])
38 | inter_x = self.inter_rnn(inter_x)[0] # (B*F,T,C)
39 | inter_x = self.inter_fc(inter_x) # (B*F,T,C)
40 | inter_x = inter_x.reshape(x.shape[0], self.width, -1, self.channel) # (B,F,T,C)
41 | inter_x = inter_x.permute(0,2,1,3) # (B,T,F,C)
42 | inter_x = self.inter_ln(inter_x)
43 | inter_out = torch.add(intra_out, inter_x)
44 |
45 | dual_out = inter_out.permute(0,3,1,2) # (B,C,T,F)
46 |
47 | return dual_out
48 |
49 |
50 | class Encoder(nn.Module):
51 | def __init__(self,
52 | in_channels,
53 | filter_size,
54 | kernel_size,
55 | stride_size):
56 | super().__init__()
57 | self.N_layers = len(filter_size)
58 |
59 | self.pad_list = nn.ModuleList([])
60 | self.conv_list = nn.ModuleList([])
61 | self.bn_list = nn.ModuleList([])
62 | self.act_list = nn.ModuleList([])
63 |
64 | for i in range(self.N_layers):
65 | Cin = in_channels if i==0 else filter_size[i-1]
66 | self.pad_list.append(nn.ZeroPad2d([(kernel_size[i][1]-1)//2, (kernel_size[i][1]-1)//2, kernel_size[i][0]-1, 0]))
67 | self.conv_list.append(nn.Conv2d(Cin, filter_size[i], kernel_size[i], stride_size[i]))
68 | self.bn_list.append(nn.BatchNorm2d(filter_size[i]))
69 | self.act_list.append(nn.PReLU())
70 |
71 | def forward(self, x):
72 | """
73 | x: (B,C,T,F)
74 | """
75 | en_outs = []
76 | for i in range(self.N_layers):
77 | x = self.pad_list[i](x)
78 | x = self.conv_list[i](x)
79 | x = self.bn_list[i](x)
80 | x = self.act_list[i](x)
81 | en_outs.append(x)
82 | # print(f'en_{i}:', x.shape)
83 | return x, en_outs
84 |
85 |
86 | class Decoder(nn.Module):
87 | def __init__(self,
88 | out_channels,
89 | filter_size,
90 | kernel_size,
91 | stride_size):
92 | super().__init__()
93 | self.N_layers = len(filter_size)
94 |
95 | self.pad_list = nn.ModuleList([])
96 | self.conv_list = nn.ModuleList([])
97 | self.bn_list = nn.ModuleList([])
98 | self.act_list = nn.ModuleList([])
99 |
100 | for i in range(self.N_layers-1, -1, -1):
101 | Cout = out_channels if i==0 else filter_size[i-1]
102 | act = nn.Identity() if i== 0 else nn.PReLU()
103 | self.pad_list.append(nn.ZeroPad2d([0, 0, kernel_size[i][0]-1, 0]))
104 | self.conv_list.append(nn.ConvTranspose2d(filter_size[i]*2, Cout, kernel_size[i], stride_size[i], (kernel_size[i][0]-1, (kernel_size[i][1]-1)//2)))
105 | self.bn_list.append(nn.BatchNorm2d(Cout))
106 | self.act_list.append(act)
107 |
108 | def forward(self, x, en_outs):
109 | """
110 | x: (B,C,T,F)
111 | """
112 | for i in range(self.N_layers):
113 | x = torch.cat([x, en_outs[self.N_layers-1-i]], dim=1)
114 | x = self.pad_list[i](x)
115 | x = self.conv_list[i](x)
116 | x = self.bn_list[i](x)
117 | x = self.act_list[i](x)
118 | # print(f'de_{i}:', x.shape)
119 | return x
120 |
121 |
122 | class DPCRN(nn.Module):
123 | def __init__(self,
124 | in_channels=2,
125 | out_channels=2,
126 | filter_size=[32,32,32,64,128],
127 | kernel_size=[(2,5), (2,3), (2,3), (2,3), (2,3)],
128 | stride_size=[(1,2), (1,2), (1,2), (1,1), (1,1)],
129 | N_dprnn=2,
130 | num_units=128,
131 | width=33,
132 | **kwargs):
133 | super().__init__()
134 | self.N_dprnn = N_dprnn
135 | self.encoder = Encoder(in_channels, filter_size, kernel_size, stride_size)
136 |
137 | self.dprnns = nn.ModuleList([])
138 | for i in range(N_dprnn):
139 | self.dprnns.append(DPRNN(num_units, width, filter_size[-1]))
140 |
141 | self.decoder = Decoder(out_channels, filter_size, kernel_size, stride_size)
142 |
143 | def forward(self, x):
144 | """
145 | x: (B,F,T,2), noisy spectrogram, where B is batch size, F is frequency bins, T is time frames, and 2 is R/I components.
146 | """
147 | x_ref = x
148 | x = x.permute(0, 3, 2, 1) # (B,C,T,F)
149 |
150 | x, en_outs = self.encoder(x)
151 |
152 | for i in range(self.N_dprnn):
153 | x = self.dprnns[i](x)
154 |
155 | x = self.decoder(x, en_outs)
156 |
157 | m = x.permute(0,3,2,1)
158 |
159 | s_real = x_ref[...,0] * m[...,0] - x_ref[...,1] * m[...,1]
160 | s_imag = x_ref[...,1] * m[...,0] + x_ref[...,0] * m[...,1]
161 | s = torch.stack([s_real, s_imag], dim=-1) # (B,F,T,2)
162 |
163 | return s
164 |
165 |
166 |
167 | if __name__ == "__main__":
168 | model = DPCRN().cuda()
169 |
170 | from ptflops import get_model_complexity_info
171 | flops, params = get_model_complexity_info(model, (257, 63, 2), as_strings=True,
172 | print_per_layer_stat=False, verbose=True)
173 | print(flops, params)
174 |
175 | model = model.cpu().eval()
176 | x = torch.randn(1, 257, 63, 2)
177 | y = model(x)
178 | print(y.shape)
--------------------------------------------------------------------------------
/offline/dpcrn_onnx.py:
--------------------------------------------------------------------------------
1 | """
2 | Export DPCRN to ONNX, check it against PyTorch, and benchmark it with ONNXRuntime.
3 | 1.74 GMac, 787.15 k
4 | """
5 | import torch
6 | import time
7 | import onnx
8 | import onnxruntime
9 | import numpy as np
10 | from onnxsim import simplify
11 | from dpcrn import DPCRN
12 |
13 |
14 | ## load model
15 | model = DPCRN().eval() # remember to set `eval` mode!
16 |
17 | ## convert to onnx
18 | file = 'models/dpcrn.onnx'
19 | device = torch.device('cpu')
20 |
21 | time_len = 9 # set the length to 9 s
22 | frame_num = time_len * 16000 // 256 + 1 # compute frame numbers, fs=16000, hop_size=256
23 | x = torch.randn(1, 257, frame_num, 2, device=device)
24 |
25 | torch.onnx.export(model,
26 | (x,),
27 | file,
28 | input_names = ['mix'],
29 | output_names = ['enh'],
30 | opset_version=11,
31 | verbose = False)
32 |
33 | onnx_model = onnx.load(file)
34 | onnx.checker.check_model(onnx_model)
35 |
36 | model_simp, check = simplify(onnx_model)
37 | assert check, "Simplified ONNX model could not be validated"
38 | onnx.save(model_simp, file.split('.onnx')[0] + '_simple.onnx')
39 |
40 | ## run onnx model
41 | # session = onnxruntime.InferenceSession(file, None, providers=['CUDAExecutionProvider'])
42 | session = onnxruntime.InferenceSession(file.split('.onnx')[0]+'_simple.onnx', None, providers=['CPUExecutionProvider'])
43 | inputs = x.cpu().detach().numpy()
44 |
45 | ## execute inference
46 | outputs = session.run([], {'mix': inputs})
47 |
48 | ## check error
49 | y = model(x)
50 | diff = outputs[0] - y.detach().numpy()  # session.run returns a list of outputs
51 | print(">>> The maximum numerical error:", np.abs(diff).max())
52 |
53 | ## test inference speed
54 | T_list = []
55 | for i in range(1000):
56 | tic = time.perf_counter()
57 | outputs = session.run([], {'mix': inputs})
58 | toc = time.perf_counter()
59 | T_list.append((toc-tic) / time_len)
60 | print(">>> inference time: mean: {:.1f}ms, max: {:.1f}ms, min: {:.1f}ms".format(1e3*np.mean(T_list), 1e3*np.max(T_list), 1e3*np.min(T_list)))
61 |
--------------------------------------------------------------------------------
/offline/dpcrn_trt.py:
--------------------------------------------------------------------------------
1 | """
2 | 1. onnx to engine: use `trtexec.exe`.
3 | 2. tensorrt and torch cannot be imported simultaneously!
4 | """
5 | import os
6 | import numpy as np
7 | import tensorrt as trt
8 | import pycuda.driver as cuda
9 | import pycuda.autoinit # must import
10 | from collections import namedtuple
11 |
12 | def onnx2engine(trtexec_path, onnx_path, save_path):
13 | os.system(f"{trtexec_path} --onnx={onnx_path} --saveEngine={save_path}")
14 |
15 |
16 | Bindings = namedtuple("Bindings", ("name", "shape", "host", "device", "size"))
17 |
18 | class TRTModel:
19 | """
20 | Implements inference for the DPCRN TensorRT engine.
21 | """
22 |
23 | def __init__(self, engine_path, dtype=np.float32):
24 | """
25 | Args:
26 | engine_path: The path to the serialized engine to load from disk.
27 | dtype: The datatype used in inference.
28 | """
29 | # init arguments
30 | self.engine_path = engine_path
31 | self.dtype = dtype
32 |
33 | # Load TRT engine
34 | self.logger = trt.Logger(trt.Logger.ERROR)
35 | self.runtime = trt.Runtime(self.logger)
36 | self.engine = self.load_engine(self.runtime, self.engine_path)
37 | self.context = self.engine.create_execution_context()
38 | self.inputs, self.outputs, self.bindings, self.stream = self.allocate_buffers()
39 |
40 | @staticmethod
41 | def load_engine(trt_runtime, engine_path):
42 | with open(engine_path, 'rb') as f:
43 | engine_data = f.read()
44 | engine = trt_runtime.deserialize_cuda_engine(engine_data)
45 | return engine
46 |
47 |
48 | def allocate_buffers(self):
49 | inputs = []
50 | outputs = []
51 | bindings = []
52 | stream = cuda.Stream()
53 |
54 | for i in range(self.engine.num_io_tensors):
55 | name = self.engine.get_tensor_name(i)
56 | shape = self.engine.get_tensor_shape(name)
57 | size = trt.volume(shape)
58 | # print(i, name, ':', shape, size)
59 |
60 | host_mem = cuda.pagelocked_empty(size, self.dtype)
61 | device_mem = cuda.mem_alloc(host_mem.nbytes)
62 |
63 | bindings.append(int(device_mem))
64 |
65 | if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
66 | inputs.append(Bindings(name, shape, host_mem, device_mem, host_mem.nbytes))
67 | else:
68 | outputs.append(Bindings(name, shape, host_mem, device_mem, host_mem.nbytes))
69 |
70 | return inputs, outputs, bindings, stream
71 |
72 |
73 | def __call__(self, x: np.ndarray):
74 | x = x.astype(self.dtype)
75 | np.copyto(self.inputs[0].host, x.ravel())
76 |
77 | # Transfer the noisy data from CPU to CUDA.
78 | cuda.memcpy_htod_async(self.inputs[0].device, self.inputs[0].host, self.stream)
79 | # Execute inference.
80 | self.context.execute_async_v2(bindings=self.bindings, stream_handle=self.stream.handle)
81 | # Transfer the enhanced data from CPU to CUDA.
82 | cuda.memcpy_dtoh_async(self.outputs[-1].host, self.outputs[-1].device, self.stream)
83 |
84 | self.stream.synchronize()
85 |
86 | return self.outputs[-1].host.reshape(self.outputs[-1].shape)
87 |
88 |
89 | if __name__ == "__main__":
90 | import time
91 |
92 | trtexec_path = r'.\models\trtexec.exe'
93 | onnx_path = r'.\models\dpcrn.onnx'
94 | save_path = r'.\models\dpcrn.engine'
95 |
96 | ## Convert to engine
97 | onnx2engine(trtexec_path, onnx_path, save_path)
98 |
99 | ## Load engine model
100 | model = TRTModel(save_path)
101 |
102 | ## Execute inference
103 | time_len = 9 # set the length to 9 s
104 | frame_num = time_len * 16000 // 256 + 1 # compute frame numbers, fs=16000, hop_size=256
105 | x = np.random.randn(1, 257, frame_num, 2)
106 | y = model(x)
107 |
108 | ## Test inference speed
109 | times = np.zeros([1000])
110 | for i in range(len(times)):
111 | tic = time.perf_counter()
112 | outputs = model(x)
113 | toc = time.perf_counter()
114 | times[i] = 1000*((toc-tic) / time_len)
115 |
116 | print(">>> Average Inference Time (ms): ", times.mean())
117 | print(">>> Maximum Inference Time (ms): ", times.max())
118 | print(">>> Minimum Inference Time (ms): ", times.min())
119 |
--------------------------------------------------------------------------------
/offline/models/dpcrn.engine:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xiaobin-Rong/TRT-SE/e35783a695669fd1346ce566dfe73abc7250ee2f/offline/models/dpcrn.engine
--------------------------------------------------------------------------------
/offline/models/dpcrn.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xiaobin-Rong/TRT-SE/e35783a695669fd1346ce566dfe73abc7250ee2f/offline/models/dpcrn.onnx
--------------------------------------------------------------------------------
/offline/models/dpcrn_simple.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xiaobin-Rong/TRT-SE/e35783a695669fd1346ce566dfe73abc7250ee2f/offline/models/dpcrn_simple.onnx
--------------------------------------------------------------------------------
/offline/models/trtexec.exe:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xiaobin-Rong/TRT-SE/e35783a695669fd1346ce566dfe73abc7250ee2f/offline/models/trtexec.exe
--------------------------------------------------------------------------------
/online/dpcrn.py:
--------------------------------------------------------------------------------
1 | """
2 | A more elegant implementation of DPCRN.
3 | 1.74 GMac, 787.15 k
4 | """
5 | import torch
6 | import torch.nn as nn
7 |
8 |
9 | class DPRNN(nn.Module):
10 | def __init__(self, numUnits, width, channel, **kwargs):
11 | super(DPRNN, self).__init__(**kwargs)
12 | self.numUnits = numUnits
13 | self.width = width
14 | self.channel = channel
15 |
16 | self.intra_rnn = nn.LSTM(input_size = self.numUnits, hidden_size = self.numUnits//2, batch_first = True, bidirectional = True)
17 | self.intra_fc = nn.Linear(self.numUnits, self.numUnits)
18 | self.intra_ln = nn.LayerNorm((width, numUnits), eps=1e-8)
19 |
20 | self.inter_rnn = nn.LSTM(input_size = self.numUnits, hidden_size = self.numUnits, batch_first = True, bidirectional = False)
21 | self.inter_fc = nn.Linear(self.numUnits, self.numUnits)
22 | self.inter_ln = nn.LayerNorm((width, numUnits), eps=1e-8)
23 |
24 | def forward(self,x):
25 | # x: (B, C, T, F)
26 | ## Intra RNN
27 | x = x.permute(0, 2, 3, 1) # (B,T,F,C)
28 | intra_x = x.reshape(x.shape[0] * x.shape[1], x.shape[2], x.shape[3]) # (B*T,F,C)
29 | intra_x = self.intra_rnn(intra_x)[0] # (B*T,F,C)
30 | intra_x = self.intra_fc(intra_x) # (B*T,F,C)
31 | intra_x = intra_x.reshape(x.shape[0], -1, self.width, self.channel) # (B,T,F,C)
32 | intra_x = self.intra_ln(intra_x)
33 | intra_out = torch.add(x, intra_x)
34 |
35 | ## Inter RNN
36 | x = intra_out.permute(0,2,1,3) # (B,F,T,C)
37 | inter_x = x.reshape(x.shape[0] * x.shape[1], x.shape[2], x.shape[3])
38 | inter_x = self.inter_rnn(inter_x)[0] # (B*F,T,C)
39 | inter_x = self.inter_fc(inter_x) # (B*F,T,C)
40 | inter_x = inter_x.reshape(x.shape[0], self.width, -1, self.channel) # (B,F,T,C)
41 | inter_x = inter_x.permute(0,2,1,3) # (B,T,F,C)
42 | inter_x = self.inter_ln(inter_x)
43 | inter_out = torch.add(intra_out, inter_x)
44 |
45 | dual_out = inter_out.permute(0,3,1,2) # (B,C,T,F)
46 |
47 | return dual_out
48 |
49 |
50 | class Encoder(nn.Module):
51 | def __init__(self,
52 | in_channels,
53 | filter_size,
54 | kernel_size,
55 | stride_size):
56 | super().__init__()
57 | self.N_layers = len(filter_size)
58 |
59 | self.pad_list = nn.ModuleList([])
60 | self.conv_list = nn.ModuleList([])
61 | self.bn_list = nn.ModuleList([])
62 | self.act_list = nn.ModuleList([])
63 |
64 | for i in range(self.N_layers):
65 | Cin = in_channels if i==0 else filter_size[i-1]
66 | self.pad_list.append(nn.ZeroPad2d([(kernel_size[i][1]-1)//2, (kernel_size[i][1]-1)//2, kernel_size[i][0]-1, 0]))
67 | self.conv_list.append(nn.Conv2d(Cin, filter_size[i], kernel_size[i], stride_size[i]))
68 | self.bn_list.append(nn.BatchNorm2d(filter_size[i]))
69 | self.act_list.append(nn.PReLU())
70 |
71 | def forward(self, x):
72 | """
73 | x: (B,C,T,F)
74 | """
75 | en_outs = []
76 | for i in range(self.N_layers):
77 | x = self.pad_list[i](x)
78 | x = self.conv_list[i](x)
79 | x = self.bn_list[i](x)
80 | x = self.act_list[i](x)
81 | en_outs.append(x)
82 | # print(f'en_{i}:', x.shape)
83 | return x, en_outs
84 |
85 |
86 | class Decoder(nn.Module):
87 | def __init__(self,
88 | out_channels,
89 | filter_size,
90 | kernel_size,
91 | stride_size):
92 | super().__init__()
93 | self.N_layers = len(filter_size)
94 |
95 | self.pad_list = nn.ModuleList([])
96 | self.conv_list = nn.ModuleList([])
97 | self.bn_list = nn.ModuleList([])
98 | self.act_list = nn.ModuleList([])
99 |
100 | for i in range(self.N_layers-1, -1, -1):
101 | Cout = out_channels if i==0 else filter_size[i-1]
102 | act = nn.Identity() if i== 0 else nn.PReLU()
103 | self.pad_list.append(nn.ZeroPad2d([0, 0, kernel_size[i][0]-1, 0]))
104 | self.conv_list.append(nn.ConvTranspose2d(filter_size[i]*2, Cout, kernel_size[i], stride_size[i], (kernel_size[i][0]-1, (kernel_size[i][1]-1)//2)))
105 | self.bn_list.append(nn.BatchNorm2d(Cout))
106 | self.act_list.append(act)
107 |
108 | def forward(self, x, en_outs):
109 | """
110 | x: (B,C,T,F)
111 | """
112 | for i in range(self.N_layers):
113 | x = torch.cat([x, en_outs[self.N_layers-1-i]], dim=1)
114 | x = self.pad_list[i](x)
115 | x = self.conv_list[i](x)
116 | x = self.bn_list[i](x)
117 | x = self.act_list[i](x)
118 | # print(f'de_{i}:', x.shape)
119 | return x
120 |
121 |
122 | class DPCRN(nn.Module):
123 | def __init__(self,
124 | in_channels=2,
125 | out_channels=2,
126 | filter_size=[32,32,32,64,128],
127 | kernel_size=[(2,5), (2,3), (2,3), (2,3), (2,3)],
128 | stride_size=[(1,2), (1,2), (1,2), (1,1), (1,1)],
129 | N_dprnn=2,
130 | num_units=128,
131 | width=33,
132 | **kwargs):
133 | super().__init__()
134 | self.N_dprnn = N_dprnn
135 | self.encoder = Encoder(in_channels, filter_size, kernel_size, stride_size)
136 |
137 | self.dprnns = nn.ModuleList([])
138 | for i in range(N_dprnn):
139 | self.dprnns.append(DPRNN(num_units, width, filter_size[-1]))
140 |
141 | self.decoder = Decoder(out_channels, filter_size, kernel_size, stride_size)
142 |
143 | def forward(self, x):
144 | """
145 | x: (B,F,T,2), noisy spectrogram, where B is batch size, F is frequency bins, T is time frames, and 2 is R/I components.
146 | """
147 | x_ref = x
148 | x = x.permute(0, 3, 2, 1) # (B,C,T,F)
149 |
150 | x, en_outs = self.encoder(x)
151 |
152 | for i in range(self.N_dprnn):
153 | x = self.dprnns[i](x)
154 |
155 | x = self.decoder(x, en_outs)
156 |
157 | m = x.permute(0,3,2,1)
158 |
159 | s_real = x_ref[...,0] * m[...,0] - x_ref[...,1] * m[...,1]
160 | s_imag = x_ref[...,1] * m[...,0] + x_ref[...,0] * m[...,1]
161 | s = torch.stack([s_real, s_imag], dim=-1) # (B,F,T,2)
162 |
163 | return s
164 |
165 |
166 |
167 | if __name__ == "__main__":
168 | model = DPCRN().cuda()
169 |
170 | from ptflops import get_model_complexity_info
171 | flops, params = get_model_complexity_info(model, (257, 63, 2), as_strings=True,
172 | print_per_layer_stat=False, verbose=True)
173 | print(flops, params)
174 |
175 | model = model.cpu().eval()
176 | x = torch.randn(1, 257, 63, 2)
177 | y = model(x)
178 | print(y.shape)
--------------------------------------------------------------------------------
/online/dpcrn_stream.py:
--------------------------------------------------------------------------------
1 | """
2 | A more elegant implementation of streaming DPCRN.
3 | 1.74 GMac, 787.15 k
4 | """
5 | import torch
6 | import torch.nn as nn
7 | from modules.convolution import StreamConv2d, StreamConvTranspose2d
8 | from modules.convert import convert_to_stream
9 |
10 |
11 | class StreamDPRNN(nn.Module):
12 | def __init__(self, numUnits, width, channel, **kwargs):
13 | super().__init__(**kwargs)
14 | self.numUnits = numUnits
15 | self.width = width
16 | self.channel = channel
17 |
18 | self.intra_rnn = nn.LSTM(input_size = self.numUnits, hidden_size = self.numUnits//2, batch_first=True, bidirectional=True)
19 | self.intra_fc = nn.Linear(self.numUnits, self.numUnits)
20 | self.intra_ln = nn.LayerNorm((width, numUnits), eps=1e-8)
21 |
22 | self.inter_rnn = nn.LSTM(input_size = self.numUnits, hidden_size = self.numUnits, batch_first=True, bidirectional=False)
23 | self.inter_fc = nn.Linear(self.numUnits, self.numUnits)
24 | self.inter_ln = nn.LayerNorm((width, numUnits), eps=1e-8)
25 |
26 | def forward(self, x, h_cache, c_cache):
27 | """
28 | x: (B, C, T=1, F)
29 | h_cache: (1, F, C), hidden cache for inter RNN.
30 | c_cache: (1, F, C), cell cache for inter RNN.
31 | """
32 | ## Intra RNN
33 | x = x.permute(0, 2, 3, 1) # (B,T,F,C)
34 | intra_x = x.reshape(x.shape[0] * x.shape[1], x.shape[2], x.shape[3]) # (B*T,F,C)
35 | intra_x = self.intra_rnn(intra_x)[0] # (B*T,F,C)
36 | intra_x = self.intra_fc(intra_x) # (B*T,F,C)
37 | intra_x = intra_x.reshape(x.shape[0], -1, self.width, self.channel) # (B,T,F,C)
38 | intra_x = self.intra_ln(intra_x)
39 | intra_out = torch.add(x, intra_x)
40 |
41 | ## Inter RNN
42 | x = intra_out.permute(0,2,1,3) # (B,F,T,C)
43 | inter_x = x.reshape(x.shape[0] * x.shape[1], x.shape[2], x.shape[3])
44 | inter_x, (h_cache, c_cache) = self.inter_rnn(inter_x, (h_cache, c_cache)) # (B*F,T,C)
45 | inter_x = self.inter_fc(inter_x) # (B*F,T,C)
46 | inter_x = inter_x.reshape(x.shape[0], self.width, -1, self.channel) # (B,F,T,C)
47 | inter_x = inter_x.permute(0,2,1,3) # (B,T,F,C)
48 | inter_x = self.inter_ln(inter_x)
49 | inter_out = torch.add(intra_out, inter_x)
50 |
51 | dual_out = inter_out.permute(0,3,1,2) # (B,C,T,F)
52 |
53 | return dual_out, h_cache, c_cache
54 |
55 |
56 | class StreamEncoder(nn.Module):
57 | def __init__(self,
58 | in_channels,
59 | filter_size,
60 | kernel_size,
61 | stride_size):
62 | super().__init__()
63 | self.N_layers = len(filter_size)
64 |
65 | self.conv_list = nn.ModuleList([])
66 | self.bn_list = nn.ModuleList([])
67 | self.act_list = nn.ModuleList([])
68 |
69 | for i in range(self.N_layers):
70 | Cin = in_channels if i==0 else filter_size[i-1]
71 | self.conv_list.append(StreamConv2d(Cin, filter_size[i], kernel_size[i], stride_size[i], (0, (kernel_size[i][1]-1)//2)))
72 | self.bn_list.append(nn.BatchNorm2d(filter_size[i]))
73 | self.act_list.append(nn.PReLU())
74 |
75 | def forward(self, x, en_caches):
76 | """
77 | x: (B,C,1,F).
78 | en_caches: A list of cache (B,C,1,F) for each conv layer in encoder.
79 | """
80 | en_outs = []
81 | for i in range(self.N_layers):
82 | x, en_caches[i] = self.conv_list[i](x, en_caches[i])
83 | x = self.bn_list[i](x)
84 | x = self.act_list[i](x)
85 | en_outs.append(x)
86 | # print(f'en_{i}:', x.shape)
87 | return x, en_outs, en_caches
88 |
89 |
90 | class StreamDecoder(nn.Module):
91 | def __init__(self,
92 | out_channels,
93 | filter_size,
94 | kernel_size,
95 | stride_size):
96 | super().__init__()
97 | self.N_layers = len(filter_size)
98 |
99 | self.conv_list = nn.ModuleList([])
100 | self.bn_list = nn.ModuleList([])
101 | self.act_list = nn.ModuleList([])
102 |
103 | for i in range(self.N_layers-1, -1, -1):
104 | Cout = out_channels if i==0 else filter_size[i-1]
105 | act = nn.Identity() if i== 0 else nn.PReLU()
106 | self.conv_list.append(StreamConvTranspose2d(filter_size[i]*2, Cout, kernel_size[i], stride_size[i], (kernel_size[i][0]-1, (kernel_size[i][1]-1)//2)))
107 | self.bn_list.append(nn.BatchNorm2d(Cout))
108 | self.act_list.append(act)
109 |
110 | def forward(self, x, en_outs, de_caches):
111 | """
112 | x: (B,C,T,F)
113 | de_caches: A list of cache (B,C,T,F) for each conv layer in decoder.
114 | """
115 | for i in range(self.N_layers):
116 | x = torch.cat([x, en_outs[self.N_layers-1-i]], dim=1)
117 | x, de_caches[i] = self.conv_list[i](x, de_caches[i])
118 | x = self.bn_list[i](x)
119 | x = self.act_list[i](x)
120 | # print(f'de_{i}:', x.shape)
121 | return x, de_caches
122 |
123 |
124 | class StreamDPCRN(nn.Module):
125 | def __init__(self,
126 | in_channels=2,
127 | out_channels=2,
128 | filter_size=[32,32,32,64,128],
129 | kernel_size=[(2,5), (2,3), (2,3), (2,3), (2,3)],
130 | stride_size=[(1,2), (1,2), (1,2), (1,1), (1,1)],
131 | N_dprnn=2,
132 | num_units=128,
133 | width=33,
134 | **kwargs):
135 | super().__init__()
136 | self.N_dprnn = N_dprnn
137 | self.encoder = StreamEncoder(in_channels, filter_size, kernel_size, stride_size)
138 |
139 | self.dprnns = nn.ModuleList([])
140 | for i in range(N_dprnn):
141 | self.dprnns.append(StreamDPRNN(num_units, width, filter_size[-1]))
142 |
143 | self.decoder = StreamDecoder(out_channels, filter_size, kernel_size, stride_size)
144 |
145 | def forward(self, x, en_caches, rnn_caches, de_caches):
146 | """
147 | x: (B,F,T,2)
148 | en_caches: A list of caches (B,C,1,F), one for each conv layer in the encoder.
149 | rnn_caches: A list of caches (1,B*F,C): the hidden and cell states of the inter RNN in each DPRNN block.
150 | de_caches: A list of caches (B,C,1,F), one for each conv layer in the decoder.
151 | """
152 | x_ref = x
153 | x = x.permute(0, 3, 2, 1) # (B,C,T,F)
154 |
155 | x, en_outs, en_caches = self.encoder(x, en_caches)
156 |
157 | for i in range(self.N_dprnn):
158 | x, rnn_caches[2*i], rnn_caches[2*i+1] = self.dprnns[i](x, rnn_caches[2*i], rnn_caches[2*i+1])
159 |
160 | x, de_caches = self.decoder(x, en_outs, de_caches)
161 |
162 | m = x.permute(0,3,2,1)
163 |
164 | s_real = x_ref[...,0] * m[...,0] - x_ref[...,1] * m[...,1]
165 | s_imag = x_ref[...,1] * m[...,0] + x_ref[...,0] * m[...,1]
166 | s = torch.stack([s_real, s_imag], dim=-1) # (B,F,T,2)
167 |
168 | return s, en_caches, rnn_caches, de_caches
169 |
170 |
171 |
172 | if __name__ == "__main__":
173 | from dpcrn import DPCRN
174 | model = DPCRN().eval()
175 | model_stream = StreamDPCRN().eval()
176 | convert_to_stream(model_stream, model)
177 |
178 | x = torch.randn(1, 257, 63, 2)
179 | en_caches = [torch.zeros(1, 2, 1, 257),
180 | torch.zeros(1, 32, 1, 129),
181 | torch.zeros(1, 32, 1, 65),
182 | torch.zeros(1, 32, 1, 33),
183 | torch.zeros(1, 64,1, 33)]
184 |
185 | rnn_caches = [torch.zeros(1,33,128),
186 | torch.zeros(1,33,128),
187 | torch.zeros(1,33,128),
188 | torch.zeros(1,33,128)]
189 |
190 | de_caches = [torch.zeros(1, 256,1, 33),
191 | torch.zeros(1, 128,1, 33),
192 | torch.zeros(1, 64, 1, 33),
193 | torch.zeros(1, 64, 1, 65),
194 | torch.zeros(1, 64, 1, 129)]
195 |
196 | y1 = []
197 | for i in range(x.shape[-2]):
198 | yi, en_caches, rnn_caches, de_caches = model_stream(x[:,:,i:i+1,:], en_caches, rnn_caches, de_caches)
199 | y1.append(yi)
200 | y1 = torch.cat(y1, dim=2)
201 |
202 | ## check errors
203 | y = model(x)
204 | print((y-y1).abs().max())
205 |
--------------------------------------------------------------------------------
/online/dpcrn_stream_onnx.py:
--------------------------------------------------------------------------------
1 | """
2 | A more elegant implementation of DPCRN.
3 | 1.74 GMac, 787.15 k
4 | """
5 | import torch
6 | import torch.nn as nn
7 | from modules.convolution import StreamConv2d, StreamConvTranspose2d
8 | from modules.convert import convert_to_stream
9 |
10 |
11 | class StreamDPRNN(nn.Module):
12 | def __init__(self, numUnits, width, channel, **kwargs):
13 | super().__init__(**kwargs)
14 | self.numUnits = numUnits
15 | self.width = width
16 | self.channel = channel
17 |
18 | self.intra_rnn = nn.LSTM(input_size = self.numUnits, hidden_size = self.numUnits//2, batch_first=True, bidirectional=True)
19 | self.intra_fc = nn.Linear(self.numUnits, self.numUnits)
20 | self.intra_ln = nn.LayerNorm((width, numUnits), eps=1e-8)
21 |
22 | self.inter_rnn = nn.LSTM(input_size = self.numUnits, hidden_size = self.numUnits, batch_first=True, bidirectional=False)
23 | self.inter_fc = nn.Linear(self.numUnits, self.numUnits)
24 | self.inter_ln = nn.LayerNorm((width, numUnits), eps=1e-8)
25 |
26 | def forward(self, x, h_cache, c_cache):
27 | """
28 | x: (B, C, T=1, F)
29 | h_cache: (1, F, C), hidden cache for inter RNN.
30 | c_cache: (1, F, C), cell cache for inter RNN.
31 | """
32 | ## Intra RNN
33 | x = x.permute(0, 2, 3, 1) # (B,T,F,C)
34 | intra_x = x.reshape(x.shape[0] * x.shape[1], x.shape[2], x.shape[3]) # (B*T,F,C)
35 | intra_x = self.intra_rnn(intra_x)[0] # (B*T,F,C)
36 | intra_x = self.intra_fc(intra_x) # (B*T,F,C)
37 | intra_x = intra_x.reshape(x.shape[0], -1, self.width, self.channel) # (B,T,F,C)
38 | intra_x = self.intra_ln(intra_x)
39 | intra_out = torch.add(x, intra_x)
40 |
41 | ## Inter RNN
42 | x = intra_out.permute(0,2,1,3) # (B,F,T,C)
43 | inter_x = x.reshape(x.shape[0] * x.shape[1], x.shape[2], x.shape[3])
44 | inter_x, (h_cache, c_cache) = self.inter_rnn(inter_x, (h_cache, c_cache)) # (B*F,T,C)
45 | inter_x = self.inter_fc(inter_x) # (B*F,T,C)
46 | inter_x = inter_x.reshape(x.shape[0], self.width, -1, self.channel) # (B,F,T,C)
47 | inter_x = inter_x.permute(0,2,1,3) # (B,T,F,C)
48 | inter_x = self.inter_ln(inter_x)
49 | inter_out = torch.add(intra_out, inter_x)
50 |
51 | dual_out = inter_out.permute(0,3,1,2) # (B,C,T,F)
52 |
53 | return dual_out, h_cache, c_cache
54 |
55 |
56 | class StreamEncoder(nn.Module):
57 | def __init__(self,
58 | in_channels,
59 | filter_size,
60 | kernel_size,
61 | stride_size):
62 | super().__init__()
63 | self.N_layers = len(filter_size)
64 |
65 | self.conv_list = nn.ModuleList([])
66 | self.bn_list = nn.ModuleList([])
67 | self.act_list = nn.ModuleList([])
68 |
69 | for i in range(self.N_layers):
70 | Cin = in_channels if i==0 else filter_size[i-1]
71 | self.conv_list.append(StreamConv2d(Cin, filter_size[i], kernel_size[i], stride_size[i], (0, (kernel_size[i][1]-1)//2)))
72 | self.bn_list.append(nn.BatchNorm2d(filter_size[i]))
73 | self.act_list.append(nn.PReLU())
74 |
75 | def forward(self, x, en_caches):
76 | """
77 | x: (B,C,1,F).
78 | en_caches: A list of cache (B,C,1,F) for each conv layer in encoder.
79 | """
80 | en_outs = []
81 | for i in range(self.N_layers):
82 | x, en_caches[i] = self.conv_list[i](x, en_caches[i])
83 | x = self.bn_list[i](x)
84 | x = self.act_list[i](x)
85 | en_outs.append(x)
86 | # print(f'en_{i}:', x.shape)
87 | return x, en_outs, en_caches
88 |
89 |
90 | class StreamDecoder(nn.Module):
91 | def __init__(self,
92 | out_channels,
93 | filter_size,
94 | kernel_size,
95 | stride_size):
96 | super().__init__()
97 | self.N_layers = len(filter_size)
98 |
99 | self.conv_list = nn.ModuleList([])
100 | self.bn_list = nn.ModuleList([])
101 | self.act_list = nn.ModuleList([])
102 |
103 | for i in range(self.N_layers-1, -1, -1):
104 | Cout = out_channels if i==0 else filter_size[i-1]
105 | act = nn.Identity() if i== 0 else nn.PReLU()
106 | self.conv_list.append(StreamConvTranspose2d(filter_size[i]*2, Cout, kernel_size[i], stride_size[i], (kernel_size[i][0]-1, (kernel_size[i][1]-1)//2)))
107 | self.bn_list.append(nn.BatchNorm2d(Cout))
108 | self.act_list.append(act)
109 |
110 | def forward(self, x, en_outs, de_caches):
111 | """
112 | x: (B,C,T,F)
113 | de_caches: A list of cache (B,C,T,F) for each conv layer in decoder.
114 | """
115 | for i in range(self.N_layers):
116 | x = torch.cat([x, en_outs[self.N_layers-1-i]], dim=1)
117 | x, de_caches[i] = self.conv_list[i](x, de_caches[i])
118 | x = self.bn_list[i](x)
119 | x = self.act_list[i](x)
120 | # print(f'de_{i}:', x.shape)
121 | return x, de_caches
122 |
123 |
124 | class StreamDPCRN(nn.Module):
125 | def __init__(self,
126 | in_channels=2,
127 | out_channels=2,
128 | filter_size=[32,32,32,64,128],
129 | kernel_size=[(2,5), (2,3), (2,3), (2,3), (2,3)],
130 | stride_size=[(1,2), (1,2), (1,2), (1,1), (1,1)],
131 | N_dprnn=2,
132 | num_units=128,
133 | width=33,
134 | **kwargs):
135 | super().__init__()
136 | self.encoder = StreamEncoder(in_channels, filter_size, kernel_size, stride_size)
137 |
138 | self.dprnns = nn.ModuleList([])
139 | for i in range(N_dprnn):
140 | self.dprnns.append(StreamDPRNN(num_units, width, filter_size[-1]))
141 |
142 | self.decoder = StreamDecoder(out_channels, filter_size, kernel_size, stride_size)
143 |
144 | def forward(self, x,
145 | en_cache1, en_cache2, en_cache3, en_cache4, en_cache5,
146 | rnn_cache1, rnn_cache2, rnn_cache3, rnn_cache4,
147 | de_cache1, de_cache2, de_cache3, de_cache4, de_cache5):
148 | """
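149 | When exporting to ONNX, each input must be passed as a separate tensor rather than inside a list.
150 | x: (B,F,T,2)
151 | caches: the encoder, RNN, and decoder caches, with the same shapes as in dpcrn_stream.py.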
152 | """
153 | x_ref = x
154 | x = x.permute(0, 3, 2, 1) # (B,C,T,F)
155 |
156 | x, en_outs, [en_cache1, en_cache2, en_cache3, en_cache4, en_cache5] = self.encoder(x, [en_cache1, en_cache2, en_cache3, en_cache4, en_cache5])
157 |
158 | x, rnn_cache1, rnn_cache2 = self.dprnns[0](x, rnn_cache1, rnn_cache2)
159 | x, rnn_cache3, rnn_cache4 = self.dprnns[1](x, rnn_cache3, rnn_cache4)
160 |
161 | x, [de_cache1, de_cache2, de_cache3, de_cache4, de_cache5] = self.decoder(x, en_outs, [de_cache1, de_cache2, de_cache3, de_cache4, de_cache5])
162 |
163 | m = x.permute(0,3,2,1)
164 |
165 | s_real = x_ref[...,0] * m[...,0] - x_ref[...,1] * m[...,1]
166 | s_imag = x_ref[...,1] * m[...,0] + x_ref[...,0] * m[...,1]
167 | s = torch.stack([s_real, s_imag], dim=-1) # (B,F,T,2)
168 |
169 | return s, en_cache1, en_cache2, en_cache3, en_cache4, en_cache5,\
170 | rnn_cache1, rnn_cache2, rnn_cache3, rnn_cache4,\
171 | de_cache1, de_cache2, de_cache3, de_cache4, de_cache5
172 |
173 |
174 |
175 | if __name__ == "__main__":
176 | from dpcrn import DPCRN
177 | model = DPCRN().eval()
178 | model_stream = StreamDPCRN().eval()
179 | convert_to_stream(model_stream, model)
180 |
181 | x = torch.randn(1, 257, 1000, 2)
182 | en_cache1 = torch.zeros(1, 2, 1, 257)
183 | en_cache2 = torch.zeros(1, 32, 1, 129)
184 | en_cache3 = torch.zeros(1, 32, 1, 65)
185 | en_cache4 = torch.zeros(1, 32, 1, 33)
186 | en_cache5 = torch.zeros(1, 64,1, 33)
187 |
188 | rnn_cache1 = torch.zeros(1, 33, 128)
189 | rnn_cache2 = torch.zeros(1, 33, 128)
190 | rnn_cache3 = torch.zeros(1, 33, 128)
191 | rnn_cache4 = torch.zeros(1, 33, 128)
192 |
193 | de_cache1 = torch.zeros(1, 256,1, 33)
194 | de_cache2 = torch.zeros(1, 128,1, 33)
195 | de_cache3 = torch.zeros(1, 64, 1, 33)
196 | de_cache4 = torch.zeros(1, 64, 1, 65)
197 | de_cache5 = torch.zeros(1, 64, 1, 129)
198 |
199 | y1 = []
200 | for i in range(x.shape[-2]):
201 | yi, en_cache1, en_cache2, en_cache3, en_cache4, en_cache5,\
202 | rnn_cache1, rnn_cache2, rnn_cache3, rnn_cache4,\
203 | de_cache1, de_cache2, de_cache3, de_cache4, de_cache5 = \
204 | model_stream(x[:,:,i:i+1,:], en_cache1, en_cache2, en_cache3, en_cache4, en_cache5,
205 | rnn_cache1, rnn_cache2, rnn_cache3, rnn_cache4,
206 | de_cache1, de_cache2, de_cache3, de_cache4, de_cache5)
207 | y1.append(yi)
208 | y1 = torch.cat(y1, dim=2)
209 |
210 | ## check streaming errors
211 | y = model(x)
212 | print((y-y1).abs().max())
213 |
214 |
215 | import time
216 | import onnx
217 | import onnxruntime
218 | import numpy as np
219 | from onnxsim import simplify
220 | ## convert to onnx
221 | file = 'models/dpcrn.onnx'
222 | device = torch.device('cpu')
223 | input = torch.randn(1, 257, 1, 2, device=device)
224 | torch.onnx.export(model_stream,
225 | (input, en_cache1, en_cache2, en_cache3, en_cache4, en_cache5,
226 | rnn_cache1, rnn_cache2, rnn_cache3, rnn_cache4,
227 | de_cache1, de_cache2, de_cache3, de_cache4, de_cache5),
228 | file,
229 | input_names = ['mix', 'en_cache1', 'en_cache2', 'en_cache3', 'en_cache4', 'en_cache5',
230 | 'rnn_cache1', 'rnn_cache2', 'rnn_cache3', 'rnn_cache4',
231 | 'de_cache1', 'de_cache2', 'de_cache3', 'de_cache4', 'de_cache5'],
232 | output_names = ['enh', 'en_cache1_out', 'en_cache2_out', 'en_cache3_out', 'en_cache4_out', 'en_cache5_out',
233 | 'rnn_cache1_out', 'rnn_cache2_out', 'rnn_cache3_out', 'rnn_cache4_out',
234 | 'de_cache1_out', 'de_cache2_out', 'de_cache3_out', 'de_cache4_out', 'de_cache5_out'],
235 | opset_version=11,
236 | verbose = False)
237 |
238 | onnx_model = onnx.load(file)
239 | onnx.checker.check_model(onnx_model)
240 |
241 | model_simp, check = simplify(onnx_model)
242 | assert check, "Simplified ONNX model could not be validated"
243 | onnx.save(model_simp, file.split('.onnx')[0] + '_simple.onnx')
244 |
245 | ## run onnx model
246 | # session = onnxruntime.InferenceSession(file, None, providers=['CUDAExecutionProvider'])
247 | session = onnxruntime.InferenceSession(file.split('.onnx')[0]+'_simple.onnx', None, providers=['CPUExecutionProvider'])
248 | input = x.cpu().detach().numpy()
249 | en_cache1 = torch.zeros(1, 2, 1, 257).numpy()
250 | en_cache2 = torch.zeros(1, 32, 1, 129).numpy()
251 | en_cache3 = torch.zeros(1, 32, 1, 65).numpy()
252 | en_cache4 = torch.zeros(1, 32, 1, 33).numpy()
253 | en_cache5 = torch.zeros(1, 64,1, 33).numpy()
254 |
255 | rnn_cache1 = torch.zeros(1, 33, 128).numpy()
256 | rnn_cache2 = torch.zeros(1, 33, 128).numpy()
257 | rnn_cache3 = torch.zeros(1, 33, 128).numpy()
258 | rnn_cache4 = torch.zeros(1, 33, 128).numpy()
259 |
260 | de_cache1 = torch.zeros(1, 256,1, 33).numpy()
261 | de_cache2 = torch.zeros(1, 128,1, 33).numpy()
262 | de_cache3 = torch.zeros(1, 64, 1, 33).numpy()
263 | de_cache4 = torch.zeros(1, 64, 1, 65).numpy()
264 | de_cache5 = torch.zeros(1, 64, 1, 129).numpy()
265 |
266 | T_list = []
267 | outputs = []
268 |
269 | for i in range(input.shape[-2]):
270 | tic = time.perf_counter()
271 |
272 | out_i, en_cache1, en_cache2, en_cache3, en_cache4, en_cache5,\
273 | rnn_cache1, rnn_cache2, rnn_cache3, rnn_cache4,\
274 | de_cache1, de_cache2, de_cache3, de_cache4, de_cache5 \
275 | = session.run([], {'mix': input[..., i:i+1, :],
276 | 'en_cache1': en_cache1, 'en_cache2':en_cache2, 'en_cache3':en_cache3, 'en_cache4':en_cache4, 'en_cache5':en_cache5,
277 | 'rnn_cache1':rnn_cache1, 'rnn_cache2':rnn_cache2, 'rnn_cache3':rnn_cache3, 'rnn_cache4':rnn_cache4,
278 | 'de_cache1':de_cache1, 'de_cache2':de_cache2, 'de_cache3':de_cache3, 'de_cache4':de_cache4, 'de_cache5':de_cache5})
279 |
280 | toc = time.perf_counter()
281 | T_list.append(toc-tic)
282 | outputs.append(out_i)
283 | outputs = np.concatenate(outputs, axis=2)
284 | ## check onnx errors
285 | print(np.abs(outputs - y.detach().numpy()).max())
286 |
287 | ## evaluate inference speed
288 | print(">>> inference time: mean: {:.1f}ms, max: {:.1f}ms, min: {:.1f}ms".format(1e3*np.mean(T_list), 1e3*np.max(T_list), 1e3*np.min(T_list)))
--------------------------------------------------------------------------------
/online/dpcrn_stream_trt.py:
--------------------------------------------------------------------------------
1 | """
2 | 1. ONNX to engine: use `trtexec.exe`.
3 | 2. tensorrt and torch cannot be imported in the same script!
4 | 3. During streaming inference with the engine, if outputs=[] and each frame's result is appended,
5 | every element of outputs ends up holding the last frame, because __call__ returns a view of the same page-locked host buffer, which is overwritten on every call.
6 | The solution is to preallocate outputs=np.zeros((1, F, T, 2)) and assign values frame by frame (appending out_i.copy() also works).
7 | """
8 | import os
9 | import numpy as np
10 | import tensorrt as trt
11 | import pycuda.driver as cuda
12 | import pycuda.autoinit # must import
13 | from collections import namedtuple
14 |
15 |
16 | def onnx2engine(trtexec_path, onnx_path, save_path):
17 | os.system(f"{trtexec_path} --onnx={onnx_path} --saveEngine={save_path}")
18 |
19 |
20 | Bindings = namedtuple("Bindings", ("name", "shape", "host", "device", "size"))
21 |
22 | class TRTModel:
23 | """
24 | Implements inference for the streaming DPCRN TensorRT engine.
25 | """
26 |
27 | def __init__(self, engine_path, dtype=np.float32):
28 | """
29 | Args:
30 | engine_path: The path to the serialized engine to load from disk.
31 | dtype: The datatype used in inference.
32 | """
33 | # init arguments
34 | self.engine_path = engine_path
35 | self.dtype = dtype
36 |
37 | # Load TRT engine
38 | self.logger = trt.Logger(trt.Logger.ERROR)
39 | self.runtime = trt.Runtime(self.logger)
40 | self.engine = self.load_engine(self.runtime, self.engine_path)
41 | self.context = self.engine.create_execution_context()
42 | self.inputs, self.outputs, self.bindings, self.stream = self.allocate_buffers()
43 |
44 | @staticmethod
45 | def load_engine(trt_runtime, engine_path):
46 | with open(engine_path, 'rb') as f:
47 | engine_data = f.read()
48 | engine = trt_runtime.deserialize_cuda_engine(engine_data)
49 | return engine
50 |
51 |
52 | def allocate_buffers(self):
53 | inputs = []
54 | outputs = []
55 | bindings = []
56 | stream = cuda.Stream()
57 |
58 | for i in range(self.engine.num_io_tensors):
59 | name = self.engine.get_tensor_name(i)
60 | shape = self.engine.get_tensor_shape(name)
61 | size = trt.volume(shape)
62 | # print(i, name, ':', shape, size)
63 | ## 0 mix : (1, 257, 1, 2) 514
64 | ## 1 en_cache1 : (1, 2, 1, 257) 514
65 | ## 2 en_cache2 : (1, 32, 1, 129) 4128
66 | ## 3 en_cache3 : (1, 32, 1, 65) 2080
67 | ## 4 en_cache4 : (1, 32, 1, 33) 1056
68 | ## 5 en_cache5 : (1, 64, 1, 33) 2112
69 | ## 6 rnn_cache1 : (1, 33, 128) 4224
70 | ## 7 rnn_cache2 : (1, 33, 128) 4224
71 | ## 8 rnn_cache3 : (1, 33, 128) 4224
72 | ## 9 rnn_cache4 : (1, 33, 128) 4224
73 | ## 10 de_cache1 : (1, 256, 1, 33) 8448
74 | ## 11 de_cache2 : (1, 128, 1, 33) 4224
75 | ## 12 de_cache3 : (1, 64, 1, 33) 2112
76 | ## 13 de_cache4 : (1, 64, 1, 65) 4160
77 | ## 14 de_cache5 : (1, 64, 1, 129) 8256
78 | ## 15 en_cache1_out : (1, 2, 1, 257) 514
79 | ## 16 en_cache2_out : (1, 32, 1, 129) 4128
80 | ## 17 en_cache3_out : (1, 32, 1, 65) 2080
81 | ## 18 en_cache4_out : (1, 32, 1, 33) 1056
82 | ## 19 en_cache5_out : (1, 64, 1, 33) 2112
83 | ## 20 rnn_cache1_out : (1, 33, 128) 4224
84 | ## 21 rnn_cache2_out : (1, 33, 128) 4224
85 | ## 22 rnn_cache3_out : (1, 33, 128) 4224
86 | ## 23 rnn_cache4_out : (1, 33, 128) 4224
87 | ## 24 de_cache1_out : (1, 256, 1, 33) 8448
88 | ## 25 de_cache2_out : (1, 128, 1, 33) 4224
105 | ## 26 de_cache3_out : (1, 64, 1, 33) 2112
106 | ## 27 de_cache4_out : (1, 64, 1, 65) 4160
107 | ## 28 de_cache5_out : (1, 64, 1, 129) 8256
108 | ## 29 enh : (1, 257, 1, 2) 514
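109 | host_mem = cuda.pagelocked_zeros(size, self.dtype)
110 | device_mem = cuda.mem_alloc(host_mem.nbytes)
111 | cuda.memcpy_htod(device_mem, host_mem) # zero-initialize: mem_alloc leaves device memory uninitialized, which would corrupt the caches of the first frame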
112 | bindings.append(int(device_mem))
113 |
114 | if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
115 | inputs.append(Bindings(name, shape, host_mem, device_mem, host_mem.nbytes))
116 | else:
117 | outputs.append(Bindings(name, shape, host_mem, device_mem, host_mem.nbytes))
118 |
119 | return inputs, outputs, bindings, stream
120 |
121 |
122 | def __call__(self, x: np.ndarray):
123 | x = x.astype(self.dtype)
124 | np.copyto(self.inputs[0].host, x.ravel())
125 |
126 | # Transfer the current frame of noisy data from CPU to CUDA.
127 | cuda.memcpy_htod_async(self.inputs[0].device, self.inputs[0].host, self.stream)
128 | # Execute inference.
129 | self.context.execute_async_v2(bindings=self.bindings, stream_handle=self.stream.handle)
130 |
131 | # Copy cache_out to cache_in within CUDA.
132 | for i in range(1, 15):
133 | # print(self.inputs[i].name, self.outputs[i-1].name)
134 | assert(self.outputs[i-1].size == self.inputs[i].size)
135 | cuda.memcpy_dtod_async(self.inputs[i].device, self.outputs[i-1].device, self.outputs[i-1].size, self.stream)
136 |
137 | # Transfer the current frame of enhanced data from CUDA to CPU.
138 | cuda.memcpy_dtoh_async(self.outputs[-1].host, self.outputs[-1].device, self.stream)
139 |
140 | self.stream.synchronize()
141 |
142 | return self.outputs[-1].host.reshape(self.outputs[-1].shape)
143 |
144 |
145 | if __name__ == "__main__":
146 | import time
147 |
148 | trtexec_path = r'.\models\trtexec.exe'
149 | onnx_path = r'.\models\dpcrn.onnx'
150 | save_path = r'.\models\dpcrn.engine'
151 |
152 | ## Convert to engine
153 | onnx2engine(trtexec_path, onnx_path, save_path)
154 |
155 | ## Load engine model
156 | model = TRTModel(save_path)
157 |
158 | x = np.random.randn(1, 257, 1000, 2)
159 |
160 | times = np.zeros([x.shape[-2]])
161 | outputs = np.zeros([1, 257, x.shape[-2], 2]) # [1, F, T, 2]
162 | for i in range(x.shape[-2]):
163 | tic = time.perf_counter()
164 | out_i = model(x[:,:, i:i+1,:])
165 | toc = time.perf_counter()
166 |
167 | outputs[:,:,i:i+1,:] = out_i
168 | times[i] = 1000*(toc-tic)
169 |
170 | print("Average Inference Time (ms): ", times.mean())
171 | print("Maximum Inference Time (ms): ", times.max())
172 | print("Minimum Inference Time (ms): ", times.min())
173 |
--------------------------------------------------------------------------------
/online/models/dpcrn.engine:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xiaobin-Rong/TRT-SE/e35783a695669fd1346ce566dfe73abc7250ee2f/online/models/dpcrn.engine
--------------------------------------------------------------------------------
/online/models/dpcrn.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xiaobin-Rong/TRT-SE/e35783a695669fd1346ce566dfe73abc7250ee2f/online/models/dpcrn.onnx
--------------------------------------------------------------------------------
/online/models/dpcrn_simple.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xiaobin-Rong/TRT-SE/e35783a695669fd1346ce566dfe73abc7250ee2f/online/models/dpcrn_simple.onnx
--------------------------------------------------------------------------------
/online/models/trtexec.exe:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xiaobin-Rong/TRT-SE/e35783a695669fd1346ce566dfe73abc7250ee2f/online/models/trtexec.exe
--------------------------------------------------------------------------------
/online/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xiaobin-Rong/TRT-SE/e35783a695669fd1346ce566dfe73abc7250ee2f/online/modules/__init__.py
--------------------------------------------------------------------------------
/online/modules/convert.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | def convert_to_stream(stream_model, model):
4 | state_dict = model.state_dict()
5 | new_state_dict = stream_model.state_dict()
6 |
7 | for key in stream_model.state_dict().keys():
8 | if key in state_dict.keys():
9 | new_state_dict[key] = state_dict[key]
10 |
11 | elif key.replace('Conv1d.', '') in state_dict.keys():
12 | new_state_dict[key] = state_dict[key.replace('Conv1d.', '')]
13 |
14 | elif key.replace('Conv2d.', '') in state_dict.keys():
15 | new_state_dict[key] = state_dict[key.replace('Conv2d.', '')]
16 |
17 | elif key.replace('ConvTranspose2d.', '') in state_dict.keys():
18 | new_state_dict[key] = state_dict[key.replace('ConvTranspose2d.', '')]
19 |
20 | else:
21 | raise ValueError(f'Key {key} not found in the source state_dict.')
22 |
23 | stream_model.load_state_dict(new_state_dict)
24 |
--------------------------------------------------------------------------------
/online/modules/convolution.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Sat Dec 3 17:32:08 2022
4 | Modified on Tue Jan 9 17:47:18 2024
5 |
6 | @author: Xiaohuai Le, Xiaobin Rong
7 | """
8 | import torch
9 | import torch.nn as nn
10 | from typing import List, Tuple, Union
11 |
12 | """
13 | When exporting to ONNX, make sure every cache is passed as a tensor, not as a list.
14 | """
15 |
16 | class StreamConv1d(nn.Module):
17 | def __init__(self,
18 | in_channels: int,
19 | out_channels: int,
20 | kernel_size: int,
21 | stride: int=1,
22 | padding: int=0,
23 | dilation: int=1,
24 | groups: int=1,
25 | bias: bool=True,
26 | *args, **kargs):
27 | super(StreamConv1d, self).__init__(*args, **kargs)
28 |
29 | assert padding == 0, "Padding must be 0: causal streaming takes its left context from the cache."
30 |
31 | self.Conv1d = nn.Conv1d(in_channels = in_channels,
32 | out_channels = out_channels,
33 | kernel_size = kernel_size,
34 | stride = stride,
35 | padding = padding,
36 | dilation = dilation,
37 | groups = groups,
38 | bias = bias)
39 |
40 | def forward(self, x, cache):
41 | """
42 | x: [bs, C, 1]
43 | cache: [bs, C, (kernel_size-1)*dilation]
44 | """
45 | inp = torch.cat([cache, x], dim=-1)
46 | oup = self.Conv1d(inp)
47 | out_cache = inp[..., 1:]
48 | return oup, out_cache
49 |
50 |
51 | class StreamConv2d(nn.Module):
52 | def __init__(self,
53 | in_channels: int,
54 | out_channels: int,
55 | kernel_size: Union[int, Tuple[int, int]],
56 | stride: Union[int, Tuple[int, int]] = 1,
57 | padding: Union[str, int, Tuple[int, int]] = 0,
58 | dilation: Union[int, Tuple[int, int]] = 1,
59 | groups: int = 1,
60 | bias: bool = True,
61 | *args, **kargs):
62 | super().__init__(*args, **kargs)
63 | """
64 | kernel_size = [T_size, F_size] by default
65 | """
66 | if type(padding) is int:
67 | self.T_pad = padding
68 | self.F_pad = padding
69 | elif type(padding) in [list, tuple]:
70 | self.T_pad, self.F_pad = padding
71 | else:
72 | raise ValueError('Invalid padding size.')
73 |
74 | assert self.T_pad == 0, "Time-axis padding must be 0: causal streaming takes its left context from the cache."
75 |
76 | self.Conv2d = nn.Conv2d(in_channels = in_channels,
77 | out_channels = out_channels,
78 | kernel_size = kernel_size,
79 | stride = stride,
80 | padding = padding,
81 | dilation = dilation,
82 | groups = groups,
83 | bias = bias)
84 |
85 | def forward(self, x, cache):
86 | """
87 | x: [bs, C, 1, F]
88 | cache: [bs, C, (T_size-1)*T_dilation, F]
89 | """
90 | inp = torch.cat([cache, x], dim=2)
91 | outp = self.Conv2d(inp)
92 | out_cache = inp[:,:, 1:]
93 | return outp, out_cache
94 |
95 |
96 | class StreamConvTranspose2d(nn.Module):
97 | def __init__(self,
98 | in_channels: int,
99 | out_channels: int,
100 | kernel_size: Union[int, Tuple[int, int]],
101 | stride: Union[int, Tuple[int, int]] = 1,
102 | padding: Union[str, int, Tuple[int, int]] = 0,
103 | dilation: Union[int, Tuple[int, int]] = 1,
104 | groups: int = 1,
105 | bias: bool = True,
106 | *args, **kargs):
107 | super().__init__(*args, **kargs)
108 | """
109 | kernel_size = [T_size, F_size] by default
110 | stride = [T_stride, F_stride] and assert T_stride == 1
111 | """
112 | if type(kernel_size) is int:
113 | self.T_size = kernel_size
114 | self.F_size = kernel_size
115 | elif type(kernel_size) in [list, tuple]:
116 | self.T_size, self.F_size = kernel_size
117 | else:
118 | raise ValueError('Invalid kernel size.')
119 |
120 | if type(stride) is int:
121 | self.T_stride = stride
122 | self.F_stride = stride
123 | elif type(stride) in [list, tuple]:
124 | self.T_stride, self.F_stride = stride
125 | else:
126 | raise ValueError('Invalid stride size.')
127 |
128 | assert self.T_stride == 1
129 |
130 | if type(padding) is int:
131 | self.T_pad = padding
132 | self.F_pad = padding
133 | elif type(padding) in [list, tuple]:
134 | self.T_pad, self.F_pad = padding
135 | else:
136 | raise ValueError('Invalid padding size.')
137 |
138 | if type(dilation) is int:
139 | self.T_dilation = dilation
140 | self.F_dilation = dilation
141 | elif type(dilation) in [list, tuple]:
142 | self.T_dilation, self.F_dilation = dilation
143 | else:
144 | raise ValueError('Invalid dilation size.')
145 |
146 | assert self.T_pad == (self.T_size-1) * self.T_dilation, "Time-axis padding must equal (T_size-1)*T_dilation so that each step outputs exactly one frame."
147 |
148 | self.ConvTranspose2d = nn.ConvTranspose2d(in_channels = in_channels,
149 | out_channels = out_channels,
150 | kernel_size = kernel_size,
151 | stride = stride,
152 | padding = padding,
153 | dilation = dilation,
154 | groups = groups,
155 | bias = bias)
156 |
157 | def forward(self, x, cache):
158 | """
159 | x: [bs, C, 1, F]
160 | cache: [bs, C, (T_size-1)*T_dilation, F]
161 | """
162 | inp = torch.cat([cache, x], dim=2)
163 | outp = self.ConvTranspose2d(inp)
164 | out_cache = inp[:,:, 1:]
165 | return outp, out_cache
166 |
167 |
168 | if __name__ == '__main__':
169 | from convert import convert_to_stream
170 |
171 | ### test Conv1d Stream
172 | Sconv = StreamConv1d(1, 1, 3)
173 | Conv = nn.Conv1d(1, 1, 3)
174 | convert_to_stream(Sconv, Conv)
175 |
176 | test_input = torch.randn([1, 1, 10])
177 | with torch.no_grad():
178 | ## Non-Streaming
179 | test_out1 = Conv(torch.nn.functional.pad(test_input, [2,0]))
180 |
181 | ## Streaming
182 | cache = torch.zeros([1, 1, 2])
183 | test_out2 = []
184 | for i in range(10):
185 | out, cache = Sconv(test_input[..., i:i+1], cache)
186 | test_out2.append(out)
187 | test_out2 = torch.cat(test_out2, dim=-1)
188 | print(">>> Streaming Conv1d error:", (test_out1 - test_out2).abs().max())
189 |
190 | ### test Conv2d Stream
191 | Sconv = StreamConv2d(1, 1, [3,3])
192 | Conv = nn.Conv2d(1, 1, (3,3))
193 | convert_to_stream(Sconv, Conv)
194 |
195 | test_input = torch.randn([1,1,10,6])
196 |
197 | with torch.no_grad():
198 | ## Non-Streaming
199 | test_out1 = Conv(torch.nn.functional.pad(test_input,[0,0,2,0]))
200 |
201 | ## Streaming
202 | cache = torch.zeros([1,1,2,6])
203 | test_out2 = []
204 | for i in range(10):
205 | out, cache = Sconv(test_input[:,:, i:i+1], cache)
206 | test_out2.append(out)
207 | test_out2 = torch.cat(test_out2, dim=2)
208 | print(">>> Streaming Conv2d error:", (test_out1 - test_out2).abs().max())
209 |
210 |
211 | ### test ConvTranspose2d Stream
212 | kt = 3 # kernel size along T axis
213 | dt = 2 # dilation along T axis
214 | pt = (kt-1) * dt # padding along T axis
215 | DeConv = torch.nn.ConvTranspose2d(4, 8, (kt,3), stride=(1,2), padding=(pt,1), dilation=(dt,2), groups=2)
216 | SDeconv = StreamConvTranspose2d(4, 8, (kt,3), stride=(1,2), padding=(pt,1), dilation=(dt,2), groups=2)
217 | convert_to_stream(SDeconv, DeConv)
218 |
219 | test_input = torch.randn([1, 4, 100, 6])
220 | with torch.no_grad():
221 | ## Non-Streaming
222 | test_out1 = DeConv(nn.functional.pad(test_input, [0,0,pt,0])) # causal padding!
223 |
224 | ## Streaming
225 | test_out2 = []
226 | cache = torch.zeros([1, 4, pt, 6])
227 | for i in range(100):
228 | out, cache = SDeconv(test_input[:,:, i:i+1], cache)
229 | test_out2.append(out)
230 | test_out2 = torch.cat(test_out2, dim=2)
231 |
232 | print(">>> Streaming ConvTranspose2d error:", (test_out1 - test_out2).abs().max())
233 |
234 |
235 |
--------------------------------------------------------------------------------
/readme/README_zh.md:
--------------------------------------------------------------------------------
1 | # Deploying Deep Learning Speech Enhancement Models with TensorRT
2 |
3 | For TensorRT installation, see the [TensorRT installation guide](./TRTSETUP_zh.md).
4 |
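5 | Deployment of speech enhancement models falls into two categories: **offline inference** and **online inference**. Offline inference runs the model on data prepared in advance, usually a batch of samples or a long speech signal. It has no real-time requirement, so efficient inference methods and resource-allocation strategies can be used freely.
6 |
7 | Online inference runs the model on speech data generated in real time, such as a continuous signal captured by a microphone. It requires low latency and high throughput to meet real-time constraints.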
8 |
9 | ## 离线模型的部署
10 | ### 1. 转换为 ONNX 模型
11 | 对于离线模型,ONNX 的导出非常简单。唯一需要注意的是输入形状的时间维度的设置。尽管 `torch.onnx.export` 支持动态维度,但考虑到实际应用场景中该需求不大,我们选择固定时间维度为 563,对应 9 s 长度的音频数据。当离线处理时,若音频不足 9 s,对其进行补 0 处理;若音频大于 9 s,对其以 9 s 分段按批次处理。
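The pad-or-split rule above can be sketched in NumPy as follows. This is a minimal illustration, not code from `offline\dpcrn_onnx.py`: it works on raw waveforms rather than STFT frames and assumes a 16 kHz sample rate, which this README does not state. The names `segment_audio` and `merge_segments` are hypothetical.

```python
import numpy as np

SAMPLE_RATE = 16_000               # assumption: sample rate is not stated in this README
SEGMENT_SECONDS = 9                # fixed export length described above
SEG_LEN = SAMPLE_RATE * SEGMENT_SECONDS


def segment_audio(x: np.ndarray, seg_len: int = SEG_LEN) -> np.ndarray:
    """Zero-pad a 1-D signal to a multiple of seg_len and split it into a batch."""
    n_seg = max(1, -(-len(x) // seg_len))          # ceiling division, at least one segment
    padded = np.zeros(n_seg * seg_len, dtype=x.dtype)
    padded[: len(x)] = x                           # zeros fill the tail of the last segment
    return padded.reshape(n_seg, seg_len)          # shape: (batch, samples)


def merge_segments(batch: np.ndarray, orig_len: int) -> np.ndarray:
    """Concatenate processed segments and drop the zero padding."""
    return batch.reshape(-1)[:orig_len]
```

Each row of the returned array is one 9 s segment. Segments are enhanced independently, so no context is carried across segment boundaries; `merge_segments` simply concatenates the rows and trims the padding.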
12 |
13 | `offline\dpcrn_onnx.py` exports the ONNX model, runs inference with it, and measures its inference speed on ONNXRuntime.
14 |
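15 | ### 2. Converting to an Engine Model
16 | We use `trtexec.exe`, the official conversion tool shipped with TensorRT, to convert the model from ONNX to the Engine format used by TensorRT. The tool is located in the `bin` directory of the TensorRT package and is used as follows:
17 | ```
18 | trtexec.exe --onnx=[onnx_path] --saveEngine=[save_path]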
19 | ```
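When several models need converting, the same command can be scripted from Python. The sketch below only wraps the two flags shown above with the standard `subprocess` module. `trtexec_command` and `build_engine` are hypothetical helper names, and the sketch assumes `trtexec.exe` is on `PATH` or that its full path is passed in.

```python
import subprocess
from pathlib import Path


def trtexec_command(onnx_path, engine_path, trtexec="trtexec.exe"):
    """Build the argument list for converting an ONNX model to a TensorRT engine."""
    return [str(trtexec), f"--onnx={onnx_path}", f"--saveEngine={engine_path}"]


def build_engine(onnx_path, engine_path, trtexec="trtexec.exe"):
    """Run trtexec; check=True raises CalledProcessError if the conversion fails."""
    subprocess.run(trtexec_command(onnx_path, engine_path, trtexec), check=True)
    return Path(engine_path)
```

For example, `build_engine(r"offline\models\dpcrn.onnx", r"offline\models\dpcrn.engine")` would rebuild the offline engine and raise an exception instead of failing silently.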
20 |
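21 | `offline\dpcrn_trt.py` exports the Engine model, runs inference with it, and measures its inference speed. The results are listed in the table below.
22 |
23 | | **Model format** | **Inference framework** | **Platform** | **Mean inference time (ms)** | **Max inference time (ms)** | **Min inference time (ms)** |
24 | |:-----------:|:-----------:|:-----------:|:----------------------:|:----------------------:|:----------------------:|
25 | | ONNX | ONNXRuntime | CPU | 8.6 | 21.0 | 7.6 |
26 | | Engine | TensorRT | CUDA | 2.2 | 5.1 | 1.9 |
27 |
28 | The CPU used for inference is a 13th Gen Intel(R) Core(TM) i9-13900HX @ 2.20 GHz, and the CUDA device is an NVIDIA GeForce RTX 4080 Laptop GPU. Inference was repeated 1000 times to obtain the mean, maximum, and minimum times. Here, inference time is defined as the audio processing time divided by the total audio duration. TensorRT is nearly 4 times faster than ONNXRuntime (2.2 ms vs. 8.6 ms on average).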
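A timing loop of the kind described above might look like the sketch below. It is an assumption about the method, not the code in `offline\dpcrn_trt.py`: it times an arbitrary zero-argument callable `infer` with `time.perf_counter`, discards a few warm-up runs, and divides each wall-clock time in milliseconds by the audio duration in seconds.

```python
import statistics
import time


def benchmark(infer, audio_seconds, repeats=1000, warmup=10):
    """Time `infer()` and return mean/max/min of (processing ms / audio seconds)."""
    for _ in range(warmup):                  # untimed runs absorb one-off start-up costs
        infer()
    scores = []
    for _ in range(repeats):
        start = time.perf_counter()
        infer()
        elapsed_ms = (time.perf_counter() - start) * 1000.0
        scores.append(elapsed_ms / audio_seconds)   # normalize by audio length
    return {"mean": statistics.mean(scores), "max": max(scores), "min": min(scores)}
```

For instance, `benchmark(lambda: session.run(None, feed), audio_seconds=9.0)` would time one 9 s batch, where `session` and `feed` stand for an existing `onnxruntime.InferenceSession` and its input dictionary.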
29 |
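30 | ## Deploying the Online Model
31 | In speech enhancement, online inference has a wider range of applications and places stricter real-time demands on the model, so deploying it is correspondingly more complex. Here we use **streaming inference**, running the model frame by frame on the real-time data stream. Implementing streaming inference requires a suitable data-buffering mechanism, data-stream management, and a pipelined inference design to keep the data continuous and the inference stable.
32 |
33 | ### 1. Converting to a Streaming Model
34 | RNNs are naturally suited to streaming inference and need no extra conversion; convolutional layers, by contrast, are the main part of the network that must be converted for streaming. In `online\modules`, `convolution.py` implements two streaming operators, streaming convolution and streaming transposed convolution, and `convert.py` provides a method that copies the original model's parameter dictionary into the streaming model.
35 |
36 | `online\dpcrn_stream.py` converts the model to its streaming form and runs streaming inference. Note that for the streaming model, the time dimension of the input tensor is always 1.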
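The cache idea behind the streaming convolution can be shown with a one-dimensional NumPy sketch. A causal convolution over time with kernel length `K` needs the previous `K - 1` frames, so the streaming version keeps exactly those frames in a cache and returns the updated cache on every call. This is not the repository's `StreamConv2d`: it omits the frequency axis, dilation, and groups (with dilation `d` the cache would hold `(K - 1) * d` frames, matching `pt = (kt-1) * dt` in the `online\modules` tests), and the function names are hypothetical.

```python
import numpy as np


def causal_conv_offline(x, w):
    """Causal 1-D convolution over time. x: (T, C_in), w: (K, C_in, C_out) -> (T, C_out)."""
    k = w.shape[0]
    pad = np.zeros((k - 1, x.shape[1]), dtype=x.dtype)
    xp = np.concatenate([pad, x], axis=0)          # left-pad K-1 frames: no look-ahead
    return np.stack([np.einsum("kc,kco->o", xp[t:t + k], w) for t in range(x.shape[0])])


def causal_conv_step(frame, cache, w):
    """One streaming step. frame: (C_in,), cache: previous K-1 frames, shape (K-1, C_in)."""
    window = np.concatenate([cache, frame[None]], axis=0)   # (K, C_in)
    out = np.einsum("kc,kco->o", window, w)
    return out, window[1:]                         # new cache drops the oldest frame
```

Feeding frames one at a time through `causal_conv_step`, starting from an all-zero cache, reproduces `causal_conv_offline` on the whole sequence, which is the same equivalence the `StreamConv2d` test checks in PyTorch.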
37 |
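38 | ### 2. Converting to an ONNX Model
39 | For the streaming model, the time dimension needs no special handling when exporting to ONNX, but it is best to declare every input tensor explicitly in the `forward` function rather than passing them as a list, as `online\dpcrn_stream.py` does.
40 |
41 | `online\dpcrn_stream_onnx.py` exports the streaming ONNX model, runs inference with it, and measures its inference speed on ONNXRuntime.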
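At run time, a streaming model with explicit cache inputs is driven by a loop that feeds each step's output caches back in as the next step's inputs. The sketch below shows that loop independently of any runtime: `step` is any callable returning `(output, *new_caches)`. `onnxruntime.InferenceSession.run(None, feed)` returns all graph outputs as a list, so it fits this convention if the first output is the enhanced frame and the rest are the updated caches. That output ordering and the input names in the comment are assumptions, not taken from `online\dpcrn_stream_onnx.py`.

```python
def run_stream(step, frames, init_caches):
    """Run `step` frame by frame, feeding each call's new caches into the next call.

    `step(frame, *caches)` must return a sequence `(output, *new_caches)`.
    Returns the list of per-frame outputs and the final caches.
    """
    caches = list(init_caches)
    outputs = []
    for frame in frames:
        out, *caches = step(frame, *caches)   # new caches replace the old ones
        outputs.append(out)
    return outputs, caches


# Hypothetical ONNXRuntime adapter; the input names below are assumptions:
# import onnxruntime as ort
# sess = ort.InferenceSession("dpcrn.onnx")
# def step(frame, conv_cache, rnn_cache):
#     return sess.run(None, {"mix": frame, "conv_cache": conv_cache, "rnn_cache": rnn_cache})
```

The initial caches are normally all-zero arrays of the shapes the model expects, which corresponds to silence before the first frame.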
42 |
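43 | ### 3. Converting to an Engine Model
44 | We again use the official TensorRT conversion tool `trtexec.exe` to convert the model.
45 |
46 | `online\dpcrn_stream_trt.py` exports the streaming Engine model, runs inference with it, and measures its inference speed. The results are listed in the table below.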
47 |
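48 | | **Model format** | **Inference framework** | **Platform** | **Mean inference time (ms)** | **Max inference time (ms)** | **Min inference time (ms)** |
49 | |:-----------:|:-----------:|:-----------:|:----------------------:|:----------------------:|:----------------------:|
50 | | ONNX | ONNXRuntime | CPU | 1.0 | 3.1 | 0.9 |
51 | | Engine | TensorRT | CUDA | 2.2 | 4.7 | 1.8 |
52 |
53 | The measurement was repeated 1000 times, and here inference time is defined as the audio processing time divided by the total number of audio frames. Using TensorRT is actually slower than using ONNXRuntime: a streaming model is called once per frame, and with TensorRT each call must also move data between the CPU and the GPU, which takes a noticeable share of the time for such a small per-frame workload. Running inference on CUDA with TensorRT only pays off when CPU inference speed becomes the bottleneck.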
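The per-frame numbers can be put in context with a short calculation. Assuming the streaming model uses the same frame hop as the offline export (563 frames for 9 s of audio, about 16 ms per frame) and that the table reports milliseconds per frame, both back ends run far faster than real time:

```python
SEGMENT_MS = 9_000   # 9 s of audio in the offline export ...
N_FRAMES = 563       # ... corresponds to 563 frames


def real_time_factor(ms_per_frame, n_frames=N_FRAMES, segment_ms=SEGMENT_MS):
    """Processing time per frame divided by the audio duration one frame represents."""
    frame_ms = segment_ms / n_frames       # about 16 ms of audio per frame
    return ms_per_frame / frame_ms


for name, ms in {"ONNXRuntime (CPU)": 1.0, "TensorRT (CUDA)": 2.2}.items():
    print(f"{name}: RTF = {real_time_factor(ms):.3f}")
# ONNXRuntime (CPU): RTF = 0.063
# TensorRT (CUDA): RTF = 0.138
```

A real-time factor well below 1 means the back end keeps up with the stream; the choice between them matters mainly when the per-frame budget is tight.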
54 |
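55 | ## Acknowledgements
56 | The speech enhancement model used in this repository is [DPCRN](https://arxiv.org/abs/2107.05429), an excellent high-performance, low-latency SE model that ranked second in the DNS3 challenge. Its author is my senior labmate [Xiaohuai Le (乐笑怀)](https://github.com/Le-Xiaohuai-speech), who taught me a great deal, including the streaming conversion method used in this repository.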
57 |
--------------------------------------------------------------------------------
/readme/TRTSETUP.md:
--------------------------------------------------------------------------------
1 | # TensorRT Installation Tutorial
2 | The Chinese version of the document can be found in: [TensorRT 安装教程](./TRTSETUP_zh.md)
3 |
4 | TensorRT is a deep learning inference engine developed by NVIDIA, designed to optimize and accelerate the inference process of deep learning models. It is specifically designed for deploying and efficiently executing deep learning models on NVIDIA GPUs. Models trained using frameworks such as PyTorch and TensorFlow can be converted to the TensorRT format and then utilized with the TensorRT inference engine to improve the model's runtime speed on GPUs. TensorRT is an ideal framework for deploying deep learning models on GPUs.
5 |
6 | This tutorial describes how to install and configure TensorRT on Windows. The GPU used here is an `NVIDIA GeForce RTX 4080 Laptop GPU`, and the software environment consists of the following components:
7 | * CUDA 11.7
8 | * cuDNN 8.8
9 | * TensorRT 8.5
10 | * Python 3.8.13
11 |
12 | ## 1. Environment Setup
13 | TensorRT releases are built against specific CUDA and cuDNN versions, so make sure CUDA and cuDNN are already installed and that their versions are compatible.
14 | * CUDA download link: https://developer.nvidia.com/cuda-downloads
15 | * cuDNN download link: https://developer.nvidia.com/rdp/cudnn-download
16 |
17 | ### 1.1 Check CUDA Installation
18 | Open the command prompt and enter `nvcc -V`. If you see the following information, it means CUDA is installed:
19 | ```
20 | Cuda compilation tools, release 11.7, V11.7.64
21 | Build cuda_11.7.r11.7/compiler.31294372_0
22 | ```
23 | Furthermore, navigate to the installation path of CUDA, which is by default `C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA`. Within this path, go to `v11.7\extras\demo_suite`, where `v11.7` represents the version of CUDA downloaded. In this directory, open the command prompt and run `.\deviceQuery.exe` and `.\bandwidthTest.exe`. If both tests pass, it indicates that CUDA is functioning correctly.
24 |
25 | Note: If CUDA is not installed, download the appropriate version for your GPU model from the official NVIDIA website.
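The `nvcc -V` check above can also be scripted, for example to confirm that the installed release matches the one a TensorRT package was built for. The sketch below uses only the standard library. The helper names are hypothetical, and the regular expression assumes the `release X.Y` wording shown in the sample output.

```python
import re
import subprocess


def parse_nvcc_release(text):
    """Extract the CUDA release (e.g. '11.7') from `nvcc -V` output, or None."""
    m = re.search(r"release (\d+\.\d+)", text)
    return m.group(1) if m else None


def installed_cuda_release():
    """Run `nvcc -V`; return None if nvcc is missing or exits with an error."""
    try:
        result = subprocess.run(["nvcc", "-V"], capture_output=True, text=True, check=True)
    except (FileNotFoundError, subprocess.CalledProcessError):
        return None
    return parse_nvcc_release(result.stdout)
```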
26 |
27 | ### 1.2 Check cuDNN Installation
28 | Installing cuDNN involves copying a series of files to the CUDA installation directory:
29 | * Download the cuDNN installation package corresponding to your CUDA version from the official website. Registration and login are required on the website to access the downloads. Extract the files locally.
30 | * Move the files in the `bin` directory of cuDNN to the `bin` directory of CUDA (`C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\bin`).
31 | * Move the files in the `include` directory of cuDNN to the `include` directory of CUDA.
32 | * Move the files in the `lib\x64` directory of cuDNN to the `lib\x64` directory of CUDA.
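The copy steps above can be sketched with the standard library as follows. The helper name `install_cudnn` is hypothetical. It copies rather than moves, so the extracted cuDNN folder stays intact, and writing into `C:\Program Files` normally requires an administrator prompt.

```python
import shutil
from pathlib import Path

# Default install location used in this tutorial; adjust to your CUDA version.
CUDA_ROOT = Path(r"C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7")


def install_cudnn(cudnn_root, cuda_root=CUDA_ROOT):
    """Copy every file in cuDNN's bin, include and lib/x64 folders into CUDA's."""
    cudnn_root, cuda_root = Path(cudnn_root), Path(cuda_root)
    copied = []
    for sub in (Path("bin"), Path("include"), Path("lib", "x64")):
        dst_dir = cuda_root / sub
        dst_dir.mkdir(parents=True, exist_ok=True)
        for src in (cudnn_root / sub).iterdir():
            if src.is_file():
                copied.append(Path(shutil.copy2(src, dst_dir / src.name)))
    return copied
```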
33 |
34 | ## 2. Install TensorRT
35 | ### 2.1 Download
36 | TensorRT can be downloaded from the official website https://developer.nvidia.com/tensorrt. Registration and login are required on the website to access the downloads. For Windows systems, simply download the compressed package and extract it to the local machine, making sure to select the TensorRT version that corresponds to your CUDA version.
37 |
38 | ### 2.2 Configuration
39 | #### 2.2.1 File Configuration
40 | The configuration process for TensorRT is similar to that of cuDNN:
41 | * Copy the files in the `include` directory of the TensorRT installation package to the `include` directory of CUDA.
42 | * Copy all `.lib` files in the `lib` directory of the TensorRT installation package to the `lib\x64` directory of CUDA.
43 | * Copy all `.dll` files in the `lib` directory of the TensorRT installation package to the `bin` directory of CUDA.
44 | * Add the path to the `bin` directory of TensorRT to the environment variables.
45 | * Add the paths to the `include`, `lib`, and `bin` directories of CUDA to the environment variables.
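The file-configuration steps can also be scripted with the standard library. The helper below is hypothetical and not part of TensorRT: it routes `.lib` files to `lib\x64` and `.dll` files to `bin` as the list describes, and leaves the environment-variable steps to be done by hand.

```python
import shutil
from pathlib import Path


def install_tensorrt_files(trt_root, cuda_root):
    """Copy TensorRT headers and libraries into a CUDA toolkit folder.

    include/*   -> <cuda>/include
    lib/*.lib   -> <cuda>/lib/x64   (import libraries used at link time)
    lib/*.dll   -> <cuda>/bin       (runtime DLLs)
    """
    trt_root, cuda_root = Path(trt_root), Path(cuda_root)
    rules = [
        (trt_root / "include", "*", cuda_root / "include"),
        (trt_root / "lib", "*.lib", cuda_root / "lib" / "x64"),
        (trt_root / "lib", "*.dll", cuda_root / "bin"),
    ]
    copied = []
    for src_dir, pattern, dst_dir in rules:
        dst_dir.mkdir(parents=True, exist_ok=True)
        for src in src_dir.glob(pattern):
            if src.is_file():
                copied.append(Path(shutil.copy2(src, dst_dir / src.name)))
    return copied
```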
46 |
47 | #### 2.2.2 Install tensorrt
48 | * Navigate to the `TensorRT-8.5.1.7\python` directory, which contains `whl` files of the tensorrt package for different Python versions. Since the Python version in the virtual environment is 3.8, install the wheel tagged `cp38`. Open a terminal in this directory, activate the virtual environment, and run:
49 | ```
50 | pip install tensorrt-8.5.1.7-cp38-none-win_amd64.whl
51 | ```
52 | * Navigate to the `TensorRT-8.5.1.7\graphsurgeon` directory, open a terminal there, activate the virtual environment, and run:
53 | ```
54 | pip install graphsurgeon-0.4.6-py2.py3-none-any.whl
55 | ```
56 | * Navigate to the `TensorRT-8.5.1.7\onnx_graphsurgeon` directory, open a terminal there, activate the virtual environment, and run:
57 | ```
58 | pip install onnx_graphsurgeon-0.3.12-py2.py3-none-any.whl
59 | ```
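Picking the wheel whose `cpXY` tag matches the running interpreter can be automated with a small helper. This sketch is hypothetical and relies only on the wheel filename convention `{name}-{version}-{python tag}-{abi tag}-{platform}.whl`, in which the Python tag is the third field from the end.

```python
import sys


def matching_wheel(filenames, version_info=sys.version_info):
    """Return the first .whl filename built for the given CPython version, else None."""
    tag = f"cp{version_info[0]}{version_info[1]}"
    for name in filenames:
        if not name.endswith(".whl"):
            continue
        fields = name[: -len(".whl")].split("-")
        # The Python tag is third from the end; it may list several tags joined by ".".
        if len(fields) >= 5 and tag in fields[-3].split("."):
            return name
    return None
```

Under Python 3.8, `matching_wheel(os.listdir(r"TensorRT-8.5.1.7\python"))` would pick `tensorrt-8.5.1.7-cp38-none-win_amd64.whl`.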
60 | After the installation is complete, enter the Python environment and print the version information. If no error is reported, the installation succeeded.
61 | ```
62 | import tensorrt as trt
63 | print(trt.__version__)
64 | assert trt.Builder(trt.Logger())
65 | ```
66 |
67 | #### 2.2.3 Install PyCUDA
68 | * Go to the PyCUDA download URL: https://www.lfd.uci.edu/~gohlke/pythonlibs/#pycuda and download the appropriate PyCUDA `.whl` file based on your CUDA and Python versions.
69 | * Navigate to the directory where the `.whl` file is located. Open a terminal or command prompt at that path and activate your virtual environment.
70 | * Run the following command:
71 | ```
72 | pip install pycuda-2022.1+cuda116-cp38-cp38-win_amd64.whl
73 | ```
74 |
75 | After completing the above configurations, you may still encounter an error when using the TensorRT compilation tool to convert ONNX models: `Could not locate zlibwapi.dll. Please make sure it is in your library path!`
76 |
77 | To resolve this issue:
78 | * First, download the zlib package that provides `zlibwapi.dll` and extract its contents.
79 | * Navigate to the `dll_x64` folder.
80 | * Copy the `zlibwapi.lib` file and paste it into `C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\lib\x64`.
81 | * Copy the `zlibwapi.dll` file and paste it into `C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\bin`.
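Copying the DLL into CUDA's `bin` folder works because that folder was added to the environment variables in step 2.2.1, and Windows searches the directories on `PATH` when loading a DLL. A quick way to confirm the DLL is visible is to scan `PATH` yourself; the helper below is a hypothetical standard-library sketch, not a TensorRT tool.

```python
import os
from pathlib import Path


def find_on_path(filename, path_value=None):
    """Return the first directory listed in PATH that contains `filename`, or None."""
    if path_value is None:
        path_value = os.environ.get("PATH", "")
    for entry in path_value.split(os.pathsep):
        entry = entry.strip().strip('"')          # Windows PATH entries may be quoted
        if entry and (Path(entry) / filename).is_file():
            return Path(entry)
    return None


if __name__ == "__main__":
    print(find_on_path("zlibwapi.dll"))           # None means the DLL is not on PATH
```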
82 |
83 |
84 | ## Reference
85 | [1] [Installing CUDA / cuDNN on Windows](https://zhuanlan.zhihu.com/p/99880204?from_voters_page=true)
86 | [2] [How to confirm that CUDA and cuDNN are installed successfully on Windows](https://blog.csdn.net/qq_35768355/article/details/132985948)
87 | [3] [Installing TensorRT](https://blog.csdn.net/weixin_51691064/article/details/130403978)
88 | [4] [Installing zlibwapi.dll for TensorRT](https://blog.csdn.net/weixin_42166222/article/details/130625663)
89 | [5] [TensorRT installation notes](https://blog.csdn.net/qq_37541097/article/details/114847600)
90 | [6] [Installing and using PyCUDA](https://blog.csdn.net/qq_41910905/article/details/109650182)
91 |
--------------------------------------------------------------------------------
/readme/TRTSETUP_zh.md:
--------------------------------------------------------------------------------
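1 | # TensorRT Installation Tutorial
2 |
3 | TensorRT is a deep learning inference engine developed by NVIDIA to optimize and accelerate the inference of deep learning models. It is designed for deploying and efficiently executing deep learning models on NVIDIA GPUs. Models trained with frameworks such as PyTorch and TensorFlow can be converted to the TensorRT format and then run with the TensorRT inference engine, which speeds up the model on the GPU. For deep learning models deployed on GPUs, TensorRT is an ideal inference framework.
4 |
5 | This tutorial describes how to install and configure TensorRT on Windows. The GPU in this machine is an `NVIDIA GeForce RTX 4080 Laptop GPU`, and the environment is: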
6 | * CUDA 11.7
7 | * cuDNN 8.8
8 | * TensorRT 8.5
9 | * Python 3.8.13
10 |
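11 | ## 1 Environment Setup
12 | TensorRT versions correspond to specific CUDA and cuDNN versions, so first make sure CUDA and cuDNN are installed and check their versions.
13 | * CUDA download link: https://developer.nvidia.com/cuda-downloads
14 | * cuDNN download link: https://developer.nvidia.com/rdp/cudnn-download
15 |
16 | ### 1.1 Check CUDA
17 | Open a command prompt and enter `nvcc -V`. If you see the following information, CUDA is installed.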
18 | ```
19 | Cuda compilation tools, release 11.7, V11.7.64
20 | Build cuda_11.7.r11.7/compiler.31294372_0
21 | ```
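22 | Next, go to the CUDA installation path, which by default is `C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA`, and then to its `v11.7\extras\demo_suite` subdirectory, where v11.7 is the installed CUDA version. Open a command prompt in this directory and run `.\deviceQuery.exe` and `.\bandwidthTest.exe`. If both tests pass, CUDA is working correctly.
23 |
24 | P.S. If CUDA is not installed correctly, download the appropriate CUDA version for your GPU model from the official website.
25 |
26 | ### 1.2 Check cuDNN
27 | Installing cuDNN consists of copying a set of files into the CUDA installation directory:
28 | * Download the cuDNN package that matches your CUDA version from the official website (registration and login are required), and extract it locally;
29 | * Move the files in cuDNN's `bin` directory to CUDA's `bin` directory (`C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\bin`);
30 | * Move the files in cuDNN's `include` directory to CUDA's `include` directory;
31 | * Move the files in cuDNN's `lib\x64` directory to CUDA's `lib\x64` directory.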
32 |
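33 | ## 2 Install TensorRT
34 | ### 2.1 Download
35 | TensorRT can be downloaded from the official website https://developer.nvidia.com/tensorrt (registration and login are required). On Windows, simply download the archive and extract it locally, making sure to choose the TensorRT version that matches your CUDA version.
36 |
37 | ### 2.2 Configuration
38 | #### 2.2.1 File Configuration
39 | TensorRT is configured in a similar way to cuDNN:
40 | * Copy the files in the `include` directory of the TensorRT package to CUDA's `include` directory;
41 | * Copy all `.lib` files in the `lib` directory of the TensorRT package to CUDA's `lib\x64` directory;
42 | * Copy all `.dll` files in the `lib` directory of the TensorRT package to CUDA's `bin` directory;
43 | * Add TensorRT's `bin` path to the environment variables;
44 | * Add CUDA's `include`, `lib`, and `bin` paths to the environment variables.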
45 |
46 | #### 2.2.2 Install tensorrt
47 | * Go to the `TensorRT-8.5.1.7\python` directory, which contains tensorrt `whl` files for different Python versions. The Python version in our virtual environment is 3.8, so install the wheel tagged `cp38`.
48 | Open a terminal in this directory, activate the virtual environment, and run:
49 | ```
50 | pip install tensorrt-8.5.1.7-cp38-none-win_amd64.whl
51 | ```
52 | * Go to the `TensorRT-8.5.1.7\graphsurgeon` directory, open a terminal there, activate the virtual environment, and run:
53 | ```
54 | pip install graphsurgeon-0.4.6-py2.py3-none-any.whl
55 | ```
56 | * Go to the `TensorRT-8.5.1.7\onnx_graphsurgeon` directory, open a terminal there, activate the virtual environment, and run:
57 | ```
58 | pip install onnx_graphsurgeon-0.3.12-py2.py3-none-any.whl
59 | ```
60 | After installation, enter the Python environment and print the version information. If no error is reported, the installation succeeded.
61 | ```
62 | import tensorrt as trt
63 | print(trt.__version__)
64 | assert trt.Builder(trt.Logger())
65 | ```
66 |
67 | #### 2.2.3 Install pycuda
68 | * Download the PyCUDA `whl` file that matches your CUDA and Python versions from https://www.lfd.uci.edu/~gohlke/pythonlibs/#pycuda.
69 | * Go to the directory containing the `whl` file, open a terminal there, activate the virtual environment, and run:
70 | ```
71 | pip install pycuda-2022.1+cuda116-cp38-cp38-win_amd64.whl
72 | ```
73 |
74 | Even after all of the configuration above is complete, converting an ONNX model with the TensorRT conversion tool may still fail with: `Could not locate zlibwapi.dll. Please make sure it is in your library path!`
75 | The solution is:
76 | * First download the zlib package, extract it, and go to the `dll_x64` folder;
77 | * Put `zlibwapi.lib` in `C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\lib\x64`;
78 | * Put `zlibwapi.dll` in `C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\bin`.
79 |
80 |
81 | ## References
82 | [1] [Installing CUDA / cuDNN on Windows](https://zhuanlan.zhihu.com/p/99880204?from_voters_page=true)
83 | [2] [How to confirm that CUDA and cuDNN are installed successfully on Windows](https://blog.csdn.net/qq_35768355/article/details/132985948)
84 | [3] [Installing TensorRT](https://blog.csdn.net/weixin_51691064/article/details/130403978)
85 | [4] [Installing zlibwapi.dll for TensorRT](https://blog.csdn.net/weixin_42166222/article/details/130625663)
86 | [5] [TensorRT installation notes](https://blog.csdn.net/qq_37541097/article/details/114847600)
87 | [6] [Installing and using PyCUDA](https://blog.csdn.net/qq_41910905/article/details/109650182)
88 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.22.4
2 | onnx==1.14.1
3 | onnxruntime==1.15.1
4 | onnxruntime_gpu==1.16.0
5 | onnxsim==0.4.33
6 | ptflops==0.7
7 | pycuda==2022.2.2
8 | tensorrt==8.5.1.7
9 | torch==1.9.0+cu111
10 |
--------------------------------------------------------------------------------