"""Post-training slimming of YOLOv5 checkpoints (FP32 -> FP16).

Usage::

    python slim.py --in_weights last.pt --out_weights slim_model.pt --device 0

The script loads a trained YOLOv5 checkpoint and takes its ``'model'`` entry.
It fuses the model and casts every parameter and buffer to half precision with
``model.half()``. It then saves the whole model object with ``torch.save``.

Although the project describes this as "post-training quantization", no
integer quantization happens. The size reduction comes from two things:

* storing FP16 instead of FP32 tensors;
* writing out only the model, not other checkpoint entries (such as optimizer
  state, if present).

In the author's run the weight file shrank to 43 MB (screenshot:
https://img-blog.csdnimg.cn/20201206114606779.png).

NOTE(review): ``torch.load`` unpickles the full YOLOv5 model object. The
YOLOv5 ``models`` package therefore presumably has to be importable (e.g. run
this from the YOLOv5 repository root). Confirm for your setup.

More on training and deploying the model (Chinese-language blog post,
"Training YOLOv5 with PyTorch and quantizing/compressing it, VOC-format
dataset"):
https://blog.csdn.net/weixin_44936889/article/details/110732476
"""
import argparse
import inspect
import os

import torch
import torch.nn as nn


def autopad(k, p=None):
    """Return the 'same' padding for kernel size *k* (exact for odd kernels).

    Args:
        k: kernel size, an int or a sequence of ints (one per spatial dim).
        p: explicit padding; returned unchanged when it is not None.

    Returns:
        ``k // 2`` for an int kernel, a list of per-dim halves for a
        sequence, or *p* itself when given.
    """
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
    return p


class Conv(nn.Module):
    """Standard convolution block: Conv2d (no bias) -> BatchNorm2d -> Hardswish.

    Args:
        c1: input channels.
        c2: output channels.
        k: kernel size.
        s: stride.
        p: padding; ``None`` picks 'same' padding via ``autopad``.
        g: number of groups.
        act: apply Hardswish when True, identity otherwise.
    """

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):
        super(Conv, self).__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p),
                              groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.Hardswish() if act else nn.Identity()

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))

    def fuseforward(self, x):
        # Skips the BatchNorm. Presumably used once a fuse step elsewhere has
        # folded BN into the conv weights (TODO confirm).
        return self.act(self.conv(x))


class Ensemble(nn.ModuleList):
    """A list of models whose predictions are averaged."""

    def __init__(self):
        super(Ensemble, self).__init__()

    def forward(self, x, augment=False):
        """Return ``(mean of each member's first output, None)``.

        Each member is called as ``module(x, augment)`` and only element
        ``[0]`` of its result is kept. The kept outputs are stacked along a
        new leading dimension and averaged over it.
        """
        y = [module(x, augment)[0] for module in self]
        # Alternatives: torch.stack(y).max(0)[0] (max ensemble) or
        # torch.cat(y, 1) (NMS ensemble).
        y = torch.stack(y).mean(0)
        return y, None  # (inference output, train output)


def _load_checkpoint(path, map_location=None):
    """Unpickle a file written by ``torch.save`` and return its contents.

    YOLOv5 checkpoints pickle the full model object. PyTorch 2.6 made
    ``weights_only=True`` the default, which rejects such files, so this
    passes ``weights_only=False`` explicitly. PyTorch releases older than
    1.13 do not accept the argument, hence the signature check.

    SECURITY: unpickling can execute arbitrary code -- only load trusted files.
    """
    kwargs = {'map_location': map_location}
    if 'weights_only' in inspect.signature(torch.load).parameters:
        kwargs['weights_only'] = False
    return torch.load(path, **kwargs)


def attempt_load(weights, map_location=None):
    """Load one checkpoint (or a list of them) for FP32 inference.

    Args:
        weights: a checkpoint path, or a list of paths to build an ensemble.
        map_location: forwarded to ``torch.load``.

    Returns:
        The single model (``ckpt['model'].float().fuse().eval()``) when one
        path is given. Otherwise an ``Ensemble`` that copies ``names`` and
        ``stride`` from its last member.

    Raises:
        ValueError: if *weights* is an empty list.
    """
    paths = weights if isinstance(weights, list) else [weights]
    if not paths:
        raise ValueError('attempt_load: no weights given')

    # Imported here so the rest of this module is usable without tqdm.
    from tqdm import tqdm

    model = Ensemble()
    for w in paths:
        ckpt = _load_checkpoint(w, map_location=map_location)
        model.append(ckpt['model'].float().fuse().eval())

    # Compatibility updates.
    for m in tqdm(model.modules()):
        if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6]:
            m.inplace = True  # pytorch 1.7.0 compatibility
        elif type(m).__name__ == 'Conv':
            # Match by name: unpickled modules are instances of YOLOv5's own
            # Conv class, not the local copy defined in this file, so an
            # identity check against ``Conv`` would never match.
            m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatibility

    if len(model) == 1:
        return model[-1]
    print('Ensemble created with %s\n' % weights)
    for k in ['names', 'stride']:
        setattr(model, k, getattr(model[-1], k))
    return model


def select_device(device='', batch_size=None):
    """Return the torch device to use.

    Args:
        device: ``'cpu'``, ``''`` (auto), or a CUDA id list such as ``'0'`` or
            ``'0,1,2,3'``. A CUDA id list is exported as
            ``CUDA_VISIBLE_DEVICES``, so the first listed GPU becomes
            ``cuda:0``.
        batch_size: when given and several GPUs are visible, it must be a
            multiple of the GPU count.

    Returns:
        ``torch.device('cuda:0')`` when CUDA is used, else
        ``torch.device('cpu')``.
    """
    cpu_request = device.lower() == 'cpu'
    if device and not cpu_request:
        os.environ['CUDA_VISIBLE_DEVICES'] = device
        assert torch.cuda.is_available(), \
            'CUDA unavailable, invalid device %s requested' % device

    cuda = not cpu_request and torch.cuda.is_available()
    if cuda and batch_size:
        ng = torch.cuda.device_count()
        if ng > 1:
            assert batch_size % ng == 0, \
                'batch-size %g not multiple of GPU count %g' % (batch_size, ng)

    return torch.device('cuda:0' if cuda else 'cpu')


def _size_kb(path):
    """Return the size of the file at *path* in KiB (getsize is in bytes)."""
    return os.path.getsize(path) / 1024


def main():
    """Command-line entry point: load, cast to FP16, save, report sizes."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--in_weights', type=str,
                        default='last.pt', help='initial weights path')
    parser.add_argument('--out_weights', type=str,
                        default='slim_model.pt', help='output weights path')
    parser.add_argument('--device', type=str, default='0', help='device')
    opt = parser.parse_args()

    device = select_device(opt.device)
    model = attempt_load(opt.in_weights, map_location=device)
    model.to(device).eval()
    model.half()

    torch.save(model, opt.out_weights)
    print('done.')

    print('-[INFO] before: {:.1f} kb, after: {:.1f} kb'.format(
        _size_kb(opt.in_weights), _size_kb(opt.out_weights)))


if __name__ == '__main__':
    main()