├── 116_ori.png
├── README.md
├── cmp_debug.py
├── main_onnxrun.py
├── main_opencv.py
├── main_pytorch.py
└── net.py

/116_ori.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hpc203/Face-Parsing-pytorch-opencv-onnxruntime/eb8cd949f0859962c62245d3268b9f51b01af70e/116_ori.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Face-Parsing-pytorch-opencv-onnxruntime

Face parsing with BiSeNet, implemented three ways: with PyTorch, with OpenCV's dnn module, and with onnxruntime. The project also compares the inputs and outputs that the same network receives and produces under each of the three libraries.

Download the .pth and .onnx files from Baidu Netdisk:

Link: https://pan.baidu.com/s/1VGm7wsfCMw_RH7V_3ODuhg
Extraction code: fza0

The PyTorch entry point is main_pytorch.py, the OpenCV one is main_opencv.py, and the onnxruntime one is main_onnxrun.py. Each script saves the network's input and output to .npy files as it runs. After running all three, run cmp_debug.py, which compares the input and output of the same network across the three frameworks.

BiSeNet is a semantic segmentation network, and face parsing is at heart segmentation of the different facial regions, i.e. pixel-level classification. Running cmp_debug.py shows that the PyTorch output differs from both the OpenCV and the onnxruntime outputs, while the OpenCV and onnxruntime outputs differ only around the tenth decimal place and can be treated as equal. So what causes the OpenCV and onnxruntime outputs to diverge from the PyTorch output? Judging from the visualized results, the PyTorch script produces the correct output. The ONNX export code lives in net.py; readers are invited to keep debugging and track down the cause.
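One preprocessing difference is worth checking first (the sketch below is an addition to this README, not one of the repository scripts): main_pytorch.py resizes with PIL's Image.BILINEAR, while main_opencv.py and main_onnxrun.py resize with cv2.resize and INTER_LINEAR, and the two bilinear implementations are known to produce slightly different pixels, so the three networks may not receive exactly the same input.

```python
import cv2
import numpy as np
from PIL import Image

rgb = cv2.cvtColor(cv2.imread('116_ori.png'), cv2.COLOR_BGR2RGB)

# The resize used by main_opencv.py / main_onnxrun.py ...
by_cv = cv2.resize(rgb, (512, 512), interpolation=cv2.INTER_LINEAR).astype(np.int32)
# ... and the one used by main_pytorch.py.
by_pil = np.asarray(Image.fromarray(rgb).resize((512, 512), Image.BILINEAR), dtype=np.int32)

print('max abs pixel difference between the two resizers:', np.abs(by_cv - by_pil).max())
```

A nonzero difference here means at least part of the output gap comes from preprocessing rather than from the ONNX conversion itself.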
--------------------------------------------------------------------------------
/cmp_debug.py:
--------------------------------------------------------------------------------
import numpy as np

# Load the network inputs and outputs saved by the three main_*.py scripts.
torch_input = np.load('torch_input.npy')
torch_out = np.load('torch_out.npy')

cv_input = np.load('cv_input.npy')
cv_out = np.load('cv_out.npy')

onnxrun_input = np.load('onnxrun_input.npy')
onnxrun_out = np.load('onnxrun_out.npy')

def compare(name_a, a, name_b, b, kind):
    # Report exact equality; otherwise print the mean absolute difference
    # (the absolute value keeps positive and negative errors from cancelling).
    if np.array_equal(a, b):
        print(f'the {kind}s of {name_a} and {name_b} are identical')
    else:
        print(f'mean absolute difference between the {kind}s of {name_a} and {name_b}:',
              np.mean(np.abs(a - b)))

compare('pytorch', torch_input, 'opencv', cv_input, 'input')
compare('pytorch', torch_input, 'onnxruntime', onnxrun_input, 'input')
compare('opencv', cv_input, 'onnxruntime', onnxrun_input, 'input')

compare('pytorch', torch_out, 'opencv', cv_out, 'output')
compare('pytorch', torch_out, 'onnxruntime', onnxrun_out, 'output')
compare('opencv', cv_out, 'onnxruntime', onnxrun_out, 'output')
--------------------------------------------------------------------------------
/main_onnxrun.py:
--------------------------------------------------------------------------------
import argparse
import cv2
import numpy as np
import onnxruntime as ort

class face_parse:
    def __init__(self):
        self.net = ort.InferenceSession('my_param.onnx')
        # ImageNet normalization constants, broadcast over the HWC image.
        self.mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 1, 3)
        self.std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 1, 3)
        # One RGB color per parsing class (index 0 is background).
        self.part_colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0],
                            [255, 0, 85], [255, 0, 170],
                            [0, 255, 0], [85, 255, 0], [170, 255, 0],
                            [0, 255, 85], [0, 255, 170],
                            [0, 0, 255], [85, 0, 255], [170, 0, 255],
                            [0, 85, 255], [0, 170, 255],
                            [255, 255, 0], [255, 255, 85], [255, 255, 170],
                            [255, 0, 255], [255, 85, 255], [255, 170, 255],
                            [0, 255, 255], [85, 255, 255], [170, 255, 255]]

    def vis_parsing_maps(self, parsing_anno, stride):
        # Color every class region; background (class 0) stays white.
        vis_parsing_anno = parsing_anno.copy().astype(np.uint8)
        vis_parsing_anno = cv2.resize(vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST)
        vis_parsing_anno_color = np.zeros((vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + 255

        num_of_class = np.max(vis_parsing_anno)
        for pi in range(1, num_of_class + 1):
            index = np.where(vis_parsing_anno == pi)
            vis_parsing_anno_color[index[0], index[1], :] = self.part_colors[pi]
        vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8)
        return vis_parsing_anno_color

    def parse(self, srcimg):
        # Preprocess: BGR->RGB, resize to the network's 512x512 input,
        # scale to [0, 1], then apply ImageNet mean/std normalization.
        img = cv2.cvtColor(srcimg, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
        img = img.astype(np.float32)
        img /= 255.0
        img = (img - self.mean) / self.std

        # HWC -> NCHW blob, the layout the ONNX model expects.
        blob = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0).astype(np.float32)
        np.save('onnxrun_input.npy', blob)
        out = self.net.run(None, {'input': blob})[0]
        np.save('onnxrun_out.npy', out)
        parsing = out.squeeze(0).argmax(0)
        vis_parsing_anno_color = self.vis_parsing_maps(parsing, stride=1)
        return vis_parsing_anno_color

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Face Parse')
    parser.add_argument('--imgpath', default='116_ori.png', type=str, help='A path to an image to use for display.')
    args = parser.parse_args()

    model = face_parse()
    srcimg = cv2.imread(args.imgpath)
    vis_parsing_anno_color = model.parse(srcimg)
    vis_parsing_anno_color = cv2.cvtColor(vis_parsing_anno_color, cv2.COLOR_RGB2BGR)
    cv2.namedWindow('vis_parsing_anno_color', cv2.WINDOW_NORMAL)
    cv2.imshow('vis_parsing_anno_color', vis_parsing_anno_color)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
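# Note (a sketch, not in the original script): rather than hard-coding the
# input name 'input', the name recorded in the ONNX graph can be queried from
# the session, which keeps parse() working if the export names ever change:
#
#   input_name = self.net.get_inputs()[0].name
#   out = self.net.run(None, {input_name: blob})[0]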
--------------------------------------------------------------------------------
/main_opencv.py:
--------------------------------------------------------------------------------
import argparse
import cv2
import numpy as np

class face_parse:
    def __init__(self):
        self.net = cv2.dnn.readNet('my_param.onnx')
        # ImageNet normalization constants, broadcast over the HWC image.
        self.mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 1, 3)
        self.std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 1, 3)
        # One RGB color per parsing class (index 0 is background).
        self.part_colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0],
                            [255, 0, 85], [255, 0, 170],
                            [0, 255, 0], [85, 255, 0], [170, 255, 0],
                            [0, 255, 85], [0, 255, 170],
                            [0, 0, 255], [85, 0, 255], [170, 0, 255],
                            [0, 85, 255], [0, 170, 255],
                            [255, 255, 0], [255, 255, 85], [255, 255, 170],
                            [255, 0, 255], [255, 85, 255], [255, 170, 255],
                            [0, 255, 255], [85, 255, 255], [170, 255, 255]]

    def vis_parsing_maps(self, parsing_anno, stride):
        # Color every class region; background (class 0) stays white.
        vis_parsing_anno = parsing_anno.copy().astype(np.uint8)
        vis_parsing_anno = cv2.resize(vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST)
        vis_parsing_anno_color = np.zeros((vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + 255

        num_of_class = np.max(vis_parsing_anno)
        for pi in range(1, num_of_class + 1):
            index = np.where(vis_parsing_anno == pi)
            vis_parsing_anno_color[index[0], index[1], :] = self.part_colors[pi]
        vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8)
        return vis_parsing_anno_color

    def parse(self, srcimg):
        # Same preprocessing as main_onnxrun.py: BGR->RGB, resize to 512x512,
        # scale to [0, 1], ImageNet mean/std normalization.
        img = cv2.cvtColor(srcimg, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
        img = img.astype(np.float32)
        img /= 255.0
        img = (img - self.mean) / self.std
        blob = cv2.dnn.blobFromImage(img)
        np.save('cv_input.npy', blob)
        self.net.setInput(blob)
        out = self.net.forward()
        np.save('cv_out.npy', out)
        parsing = out.squeeze(0).argmax(0)
        vis_parsing_anno_color = self.vis_parsing_maps(parsing, stride=1)
        return vis_parsing_anno_color

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Face Parse')
    parser.add_argument('--imgpath', default='116_ori.png', type=str, help='A path to an image to use for display.')
    args = parser.parse_args()

    model = face_parse()
    srcimg = cv2.imread(args.imgpath)
    vis_parsing_anno_color = model.parse(srcimg)
    vis_parsing_anno_color = cv2.cvtColor(vis_parsing_anno_color, cv2.COLOR_RGB2BGR)
    cv2.namedWindow('vis_parsing_anno_color', cv2.WINDOW_NORMAL)
    cv2.imshow('vis_parsing_anno_color', vis_parsing_anno_color)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
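# Note (a sketch, not in the original script): called with no extra arguments,
# cv2.dnn.blobFromImage only reorders the already-normalized HWC image into an
# NCHW float blob, so it should match the manual transpose in main_onnxrun.py;
# this is why cmp_debug.py finds the opencv and onnxruntime inputs
# (near-)identical:
#
#   blob_manual = np.expand_dims(img.transpose(2, 0, 1), axis=0)
#   print(np.abs(blob - blob_manual).max())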
--------------------------------------------------------------------------------
/main_pytorch.py:
--------------------------------------------------------------------------------
from net import BiSeNet
import torch
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import cv2
import argparse

def vis_parsing_maps(im, parsing_anno, stride):
    # One RGB color per parsing class (index 0 is background).
    part_colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0],
                   [255, 0, 85], [255, 0, 170],
                   [0, 255, 0], [85, 255, 0], [170, 255, 0],
                   [0, 255, 85], [0, 255, 170],
                   [0, 0, 255], [85, 0, 255], [170, 0, 255],
                   [0, 85, 255], [0, 170, 255],
                   [255, 255, 0], [255, 255, 85], [255, 255, 170],
                   [255, 0, 255], [255, 85, 255], [255, 170, 255],
                   [0, 255, 255], [85, 255, 255], [170, 255, 255]]

    im = np.array(im)
    vis_im = im.copy().astype(np.uint8)
    vis_parsing_anno = parsing_anno.copy().astype(np.uint8)
    vis_parsing_anno = cv2.resize(vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST)
    vis_parsing_anno_color = np.zeros((vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + 255

    num_of_class = np.max(vis_parsing_anno)

    for pi in range(1, num_of_class + 1):
        index = np.where(vis_parsing_anno == pi)
        vis_parsing_anno_color[index[0], index[1], :] = part_colors[pi]

    vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8)
    # Blend the color map over the source image for display.
    vis_im = cv2.addWeighted(cv2.cvtColor(vis_im, cv2.COLOR_RGB2BGR), 0.4, vis_parsing_anno_color, 0.6, 0)
    return vis_parsing_anno_color, vis_im

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Face Parse')
    parser.add_argument('--imgpath', default='116_ori.png', type=str, help='A path to an image to use for display.')
    args = parser.parse_args()
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    n_classes = 19
    net = BiSeNet(n_classes)
    net.to(device)
    net.load_state_dict(torch.load('my_params.pth', map_location=device))
    net.eval()

    # ToTensor scales to [0, 1]; Normalize applies the ImageNet mean/std,
    # matching the numpy preprocessing in main_opencv.py / main_onnxrun.py.
    to_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    img = Image.open(args.imgpath).convert('RGB')  # ensure 3 channels even for RGBA/grayscale PNGs
    with torch.no_grad():
        image = img.resize((512, 512), Image.BILINEAR)
        img = to_tensor(image)
        img = torch.unsqueeze(img, 0)
        img = img.to(device)
        np.save('torch_input.npy', img.cpu().numpy())
        out = net(img)
        np.save('torch_out.npy', out.cpu().numpy())
        parsing = out.squeeze(0).cpu().numpy().argmax(0)
    vis_parsing_anno_color, vis_im = vis_parsing_maps(image, parsing, stride=1)

    cv2.namedWindow('vis_parsing_anno_color', cv2.WINDOW_NORMAL)
    cv2.imshow('vis_parsing_anno_color', vis_parsing_anno_color)
    cv2.namedWindow('vis_im', cv2.WINDOW_NORMAL)
    cv2.imshow('vis_im', vis_im)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
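# Debugging sketch (not part of the original script): to separate the effect of
# preprocessing from the effect of the ONNX conversion itself, feed the exact
# tensor saved here to onnxruntime and compare against the saved PyTorch output:
#
#   import onnxruntime as ort
#   sess = ort.InferenceSession('my_param.onnx')
#   onnx_out = sess.run(None, {'input': np.load('torch_input.npy')})[0]
#   print('max abs output difference:', np.abs(onnx_out - np.load('torch_out.npy')).max())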
--------------------------------------------------------------------------------
/net.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

def conv3x3(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)

class BasicBlock(nn.Module):
    def __init__(self, in_chan, out_chan, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(in_chan, out_chan, stride)
        self.bn1 = nn.BatchNorm2d(out_chan)
        self.conv2 = conv3x3(out_chan, out_chan)
        self.bn2 = nn.BatchNorm2d(out_chan)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = None
        if in_chan != out_chan or stride != 1:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_chan, out_chan,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_chan),
            )

    def forward(self, x):
        residual = self.conv1(x)
        residual = F.relu(self.bn1(residual))
        residual = self.conv2(residual)
        residual = self.bn2(residual)

        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x)

        out = shortcut + residual
        out = self.relu(out)
        return out

def create_layer_basic(in_chan, out_chan, bnum, stride=1):
    layers = [BasicBlock(in_chan, out_chan, stride=stride)]
    for i in range(bnum - 1):
        layers.append(BasicBlock(out_chan, out_chan, stride=1))
    return nn.Sequential(*layers)

class Resnet18(nn.Module):
    def __init__(self):
        super(Resnet18, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
        self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
        self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
        self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(self.bn1(x))
        x = self.maxpool(x)

        x = self.layer1(x)
        feat8 = self.layer2(x)        # 1/8
        feat16 = self.layer3(feat8)   # 1/16
        feat32 = self.layer4(feat16)  # 1/32
        return feat8, feat16, feat32

class ConvBNReLU(nn.Module):
    def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1):
        super(ConvBNReLU, self).__init__()
        self.conv = nn.Conv2d(in_chan, out_chan, kernel_size=ks,
                              stride=stride, padding=padding, bias=False)
        self.bn = nn.BatchNorm2d(out_chan)

    def forward(self, x):
        x = self.conv(x)
        x = F.relu(self.bn(x))
        return x

class BiSeNetOutput(nn.Module):
    def __init__(self, in_chan, mid_chan, n_classes):
        super(BiSeNetOutput, self).__init__()
        self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
        self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)

    def forward(self, x):
        x = self.conv(x)
        x = self.conv_out(x)
        return x

class AttentionRefinementModule(nn.Module):
    def __init__(self, in_chan, out_chan):
        super(AttentionRefinementModule, self).__init__()
        self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
        self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False)
        self.bn_atten = nn.BatchNorm2d(out_chan)
        self.sigmoid_atten = nn.Sigmoid()

    def forward(self, x):
        feat = self.conv(x)
        # Global average pooling, written as torch.mean so it exports cleanly to ONNX.
        # atten = F.avg_pool2d(feat, feat.size()[2:])
        # atten = F.adaptive_avg_pool2d(feat, (1, 1))
        atten = torch.mean(feat, (2, 3), keepdim=True)
        atten = self.conv_atten(atten)
        atten = self.bn_atten(atten)
        atten = self.sigmoid_atten(atten)
        out = torch.mul(feat, atten)
        return out

class ContextPath(nn.Module):
    def __init__(self):
        super(ContextPath, self).__init__()
        self.resnet = Resnet18()
        self.arm16 = AttentionRefinementModule(256, 128)
        self.arm32 = AttentionRefinementModule(512, 128)
        self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
        self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
        self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)

    def forward(self, x):
        feat8, feat16, feat32 = self.resnet(x)

        # Global context: average pooling (torch.mean, for ONNX export),
        # then upsample and fuse back into the 1/32 feature.
        # avg = F.avg_pool2d(feat32, feat32.size()[2:])
        # avg = F.adaptive_avg_pool2d(feat32, (1, 1))
        avg = torch.mean(feat32, (2, 3), keepdim=True)
        avg = self.conv_avg(avg)
        avg_up = F.interpolate(avg, size=(int(feat32.shape[2]), int(feat32.shape[3])), mode='nearest')

        feat32_arm = self.arm32(feat32)
        feat32_sum = torch.add(feat32_arm, avg_up)
        feat32_up = F.interpolate(feat32_sum, size=(int(feat16.shape[2]), int(feat16.shape[3])), mode='nearest')
        feat32_up = self.conv_head32(feat32_up)

        feat16_arm = self.arm16(feat16)
        feat16_sum = torch.add(feat16_arm, feat32_up)
        feat16_up = F.interpolate(feat16_sum, size=(int(feat8.shape[2]), int(feat8.shape[3])), mode='nearest')
        feat16_up = self.conv_head16(feat16_up)
        return feat8, feat16_up  # both at 1/8 resolution

### Not used: the 1/8 ResNet feature (feat8) replaces the spatial path, since it has the same size.
class SpatialPath(nn.Module):
    def __init__(self):
        super(SpatialPath, self).__init__()
        self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
        self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
        self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
        self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)

    def forward(self, x):
        feat = self.conv1(x)
        feat = self.conv2(feat)
        feat = self.conv3(feat)
        feat = self.conv_out(feat)
        return feat

class FeatureFusionModule(nn.Module):
    def __init__(self, in_chan, out_chan):
        super(FeatureFusionModule, self).__init__()
        self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
        # Channel-attention bottleneck: squeeze to out_chan//4, then restore.
        self.conv1 = nn.Conv2d(out_chan, out_chan // 4, kernel_size=1,
                               stride=1, padding=0, bias=False)
        self.conv2 = nn.Conv2d(out_chan // 4, out_chan, kernel_size=1,
                               stride=1, padding=0, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, fsp, fcp):
        fcat = torch.cat([fsp, fcp], dim=1)
        feat = self.convblk(fcat)
        # Global average pooling, written as torch.mean for ONNX export.
        # atten = F.avg_pool2d(feat, feat.size()[2:])
        # atten = F.adaptive_avg_pool2d(feat, (1, 1))
        atten = torch.mean(feat, (2, 3), keepdim=True)
        atten = self.conv1(atten)
        atten = self.relu(atten)
        atten = self.conv2(atten)
        atten = self.sigmoid(atten)
        feat_atten = torch.mul(feat, atten)
        feat_out = torch.add(feat_atten, feat)
        return feat_out

class BiSeNet(nn.Module):
    def __init__(self, n_classes):
        super(BiSeNet, self).__init__()
        self.cp = ContextPath()
        self.ffm = FeatureFusionModule(256, 256)
        self.conv_out = BiSeNetOutput(256, 256, n_classes)

    def forward(self, x):
        # feat_res8, the raw 1/8 ResNet feature, stands in for the spatial path (see note on SpatialPath).
        feat_res8, feat_cp8 = self.cp(x)
        feat_fuse = self.ffm(feat_res8, feat_cp8)
        feat_out = self.conv_out(feat_fuse)
        feat_out = F.interpolate(feat_out, size=(int(x.shape[2]), int(x.shape[3])), mode='bilinear', align_corners=True)
        return feat_out

if __name__ == "__main__":
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    net = BiSeNet(19).to(device)
    # NOTE: as written, this exports randomly initialized weights. To export the
    # trained model, load the checkpoint first, e.g.:
    #   net.load_state_dict(torch.load('my_params.pth', map_location=device))
    net.eval()

    with torch.no_grad():
        in_ten = torch.randn(2, 3, 512, 512).to(device)
        out = net(in_ten)
        print(out.shape)

        inputs = torch.randn(1, 3, 512, 512).to(device)
        torch.onnx.export(net, inputs, 'my_param.onnx', verbose=False, opset_version=12, input_names=["input"], output_names=["out"])
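        # Sanity-check sketch (not part of the original script): validate the
        # exported graph with the onnx package before handing it to OpenCV or
        # onnxruntime:
        #
        #   import onnx
        #   onnx.checker.check_model(onnx.load('my_param.onnx'))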
--------------------------------------------------------------------------------