├── findplate ├── data │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ └── dataset.cpython-36.pyc │ └── dataset.py ├── utils │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ └── visualize.cpython-36.pyc │ └── visualize.py ├── plate.csv ├── char.csv ├── models │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── alexnet.cpython-36.pyc │ │ ├── resnet34.cpython-36.pyc │ │ ├── resnet50.cpython-36.pyc │ │ ├── squeezenet.cpython-36.pyc │ │ ├── basic_module.cpython-36.pyc │ │ └── squeezenet_gray.cpython-36.pyc │ ├── squeezenet.py │ ├── squeezenet_gray.py │ └── basic_module.py ├── checkpoints │ ├── squeezenet_char.pth │ └── squeezenet_plate.pth ├── config.py └── testnetwork.py ├── 演示动画.gif ├── README.md ├── requirements.txt ├── gui.py ├── network.py └── main.py /findplate/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /findplate/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .visualize import Visualizer -------------------------------------------------------------------------------- /findplate/plate.csv: -------------------------------------------------------------------------------- 1 | label_idx,label_name 2 | 0,has 3 | 1,no 4 | -------------------------------------------------------------------------------- /演示动画.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CescMessi/carplaterecognize/HEAD/演示动画.gif -------------------------------------------------------------------------------- /findplate/char.csv: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CescMessi/carplaterecognize/HEAD/findplate/char.csv -------------------------------------------------------------------------------- /findplate/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .squeezenet import SqueezeNet 2 | from .squeezenet_gray import SqueezeNetGray -------------------------------------------------------------------------------- /findplate/checkpoints/squeezenet_char.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CescMessi/carplaterecognize/HEAD/findplate/checkpoints/squeezenet_char.pth -------------------------------------------------------------------------------- /findplate/checkpoints/squeezenet_plate.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CescMessi/carplaterecognize/HEAD/findplate/checkpoints/squeezenet_plate.pth -------------------------------------------------------------------------------- /findplate/data/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CescMessi/carplaterecognize/HEAD/findplate/data/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /findplate/data/__pycache__/dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CescMessi/carplaterecognize/HEAD/findplate/data/__pycache__/dataset.cpython-36.pyc -------------------------------------------------------------------------------- /findplate/models/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CescMessi/carplaterecognize/HEAD/findplate/models/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /findplate/models/__pycache__/alexnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CescMessi/carplaterecognize/HEAD/findplate/models/__pycache__/alexnet.cpython-36.pyc -------------------------------------------------------------------------------- /findplate/models/__pycache__/resnet34.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CescMessi/carplaterecognize/HEAD/findplate/models/__pycache__/resnet34.cpython-36.pyc -------------------------------------------------------------------------------- /findplate/models/__pycache__/resnet50.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CescMessi/carplaterecognize/HEAD/findplate/models/__pycache__/resnet50.cpython-36.pyc -------------------------------------------------------------------------------- /findplate/utils/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CescMessi/carplaterecognize/HEAD/findplate/utils/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /findplate/utils/__pycache__/visualize.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CescMessi/carplaterecognize/HEAD/findplate/utils/__pycache__/visualize.cpython-36.pyc -------------------------------------------------------------------------------- /findplate/models/__pycache__/squeezenet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CescMessi/carplaterecognize/HEAD/findplate/models/__pycache__/squeezenet.cpython-36.pyc -------------------------------------------------------------------------------- /findplate/models/__pycache__/basic_module.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CescMessi/carplaterecognize/HEAD/findplate/models/__pycache__/basic_module.cpython-36.pyc -------------------------------------------------------------------------------- /findplate/models/__pycache__/squeezenet_gray.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CescMessi/carplaterecognize/HEAD/findplate/models/__pycache__/squeezenet_gray.cpython-36.pyc -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # carplaterecognize 2 | 使用pytorch和opencv的简易识别车牌程序 3 | 4 | 使用opencv找到类似车牌的物体,接着使用模型判断是否为车牌。若为车牌,将车牌图像拉伸至标准形状,对字符进行分割,每个字符单独使用模型进行识别。 5 | 6 | 模型训练代码已经包含,使用的是简单的squeezenet,可以自行修改。 7 | 8 | ![演示图片](./演示动画.gif) 9 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.16.2 2 | tqdm>=4.31.1 3 | torchvision>=0.2.2 4 | torchnet>=0.0.4 5 | visdom>=0.1.8.8 6 | fire>=0.1.3 7 | opencv_python>=4.1.0.25 8 | ipdb>=0.12 9 | torch>=1.0.1 10 | Pillow>=6.0.0 11 | -------------------------------------------------------------------------------- /gui.py: -------------------------------------------------------------------------------- 1 | from tkinter import * 2 | from tkinter.filedialog import askopenfilename 3 | from main import recognition 4 | 5 | def select(): 6 | file_path = askopenfilename(title=u'选择文件', filetypes=[(".JPG", ".jpg")]) 7 | number.set(recognition(file_path)) 8 | 9 | 10 | root = Tk() 11 | root.title('车牌检测') 12 | number = StringVar() 13 | Label(root, text= '车牌号').grid(row = 0, column = 0) 14 | Entry(root, textvariable = number).grid(row = 0, column = 1) 15 | Button(root, text = '打开图片', command = select).grid(row = 1, column = 0) 16 | Button(root, text = '退出', command = root.quit).grid(row = 1, column = 3) 17 | root.mainloop() 18 | -------------------------------------------------------------------------------- /findplate/models/squeezenet.py: -------------------------------------------------------------------------------- 1 | from torchvision.models import squeezenet1_1 2 | from findplate.models.basic_module import BasicModule 3 | from torch import nn 4 | from torch.optim import Adam 5 | from findplate.config import opt 6 | 7 | class SqueezeNet(BasicModule): 8 | def __init__(self, num_classes=2): 9 | super(SqueezeNet, self).__init__() 10 | self.model_name = 'squeezenet' 11 | self.model = squeezenet1_1(pretrained=False) 12 | # 修改 原始的num_class: 预训练模型是1000分类 13 | self.model.num_classes = num_classes 14 | self.model.classifier = nn.Sequential( 15 | nn.Dropout(p=0.5), 16 | nn.Conv2d(512, num_classes, 1), 17 | nn.ReLU(inplace=True), 18 | nn.AvgPool2d(13, stride=1) 19 | ) 20 | 21 | def forward(self,x): 22 | return self.model(x) 23 | 24 | def get_optimizer(self, lr, weight_decay): 25 | # 因为使用了预训练模型,我们只需要训练后面的分类 26 | # 前面的特征提取部分可以保持不变 27 | return Adam(self.model.classifier.parameters(), lr, weight_decay=weight_decay) -------------------------------------------------------------------------------- /findplate/models/squeezenet_gray.py: -------------------------------------------------------------------------------- 1 | from torchvision.models import squeezenet1_1 2 | from findplate.models.basic_module import BasicModule 3 | from torch import nn 4 | from torch.optim import Adam 5 | from findplate.config import opt 6 | 7 | class SqueezeNetGray(BasicModule): 8 | def __init__(self, num_classes=65): 9 | super(SqueezeNetGray, self).__init__() 10 | self.model_name = 'squeezenet_gray' 11 | self.model = squeezenet1_1(pretrained=False) 12 | # 修改 原始的num_class: 预训练模型是1000分类 13 | self.model.num_classes = num_classes 14 | self.model.classifier = nn.Sequential( 15 | nn.Dropout(p=0.5), 16 | nn.Conv2d(512, num_classes, 1), 17 | nn.ReLU(inplace=True), 18 | nn.AvgPool2d(13, stride=1) 19 | ) 20 | 21 | def forward(self,x): 22 | return self.model(x) 23 | 24 | def get_optimizer(self, lr, weight_decay): 25 | # 因为使用了预训练模型,我们只需要训练后面的分类 26 | # 前面的特征提取部分可以保持不变 27 | return Adam(self.model.classifier.parameters(), lr, weight_decay=weight_decay) -------------------------------------------------------------------------------- /findplate/models/basic_module.py: -------------------------------------------------------------------------------- 1 | 2 | import torch as t 3 | import time 4 | 5 | 6 | class BasicModule(t.nn.Module): 7 | """ 8 | 封装了nn.Module,主要是提供了save和load两个方法 9 | """ 10 | 11 | def __init__(self): 12 | super(BasicModule,self).__init__() 13 | self.model_name=str(type(self))# 默认名字 14 | 15 | def load(self, path): 16 | """ 17 | 可加载指定路径的模型 18 | """ 19 | self.load_state_dict(t.load(path, map_location='cpu')) 20 | 21 | def save(self, name=None): 22 | """ 23 | 保存模型,默认使用“模型名字+时间”作为文件名 24 | """ 25 | if name is None: 26 | prefix = './findplate/checkpoints/' + self.model_name + '_' 27 | name = time.strftime(prefix + '%m%d_%H%M%S.pth') 28 | t.save(self.state_dict(), name) 29 | return name 30 | 31 | def get_optimizer(self, lr, weight_decay): 32 | return t.optim.Adam(self.parameters(), lr=lr, weight_decay=weight_decay) 33 | 34 | 35 | class Flat(t.nn.Module): 36 | """ 37 | 把输入reshape成(batch_size,dim_length) 38 | """ 39 | 40 | def __init__(self): 41 | super(Flat, self).__init__() 42 | #self.size = size 43 | 44 | def forward(self, x): 45 | return x.view(x.size(0), -1) 46 | -------------------------------------------------------------------------------- /findplate/config.py: -------------------------------------------------------------------------------- 1 | 2 | import warnings 3 | import torch as t 4 | 5 | class DefaultConfig(object): 6 | env = 'default' # visdom 环境 7 | vis_port =8097 # visdom 端口 8 | model = 'SqueezeNet' # 使用的模型,名字必须与models/__init__.py中的名字一致 9 | classifier_num = 2 # 分类器最终的分类数量 10 | gray = False # 读取图片是否为灰度图 11 | 12 | train_data_root = './imgs/images/cnn_plate_train/' # 训练集存放路径 13 | test_data_root = './data/test/' # 测试集存放路径 14 | load_model_path = None # 加载预训练的模型的路径,为None代表不加载 15 | 16 | batch_size = 16 # batch size 17 | use_gpu = True # user GPU or not 18 | num_workers = 0 # how many workers for loading data 19 | print_freq = 20 # print info every N batch 20 | 21 | debug_file = '/tmp/debug' # if os.path.exists(debug_file): enter ipdb 22 | result_file = 'result.csv' 23 | id_file = './findplate/plate.csv' 24 | 25 | max_epoch = 100 26 | lr = 0.001 # initial learning rate 27 | lr_decay = 0.5 # when val_loss increase, lr = lr*lr_decay 28 | weight_decay = 0e-5 # 损失函数 29 | 30 | 31 | def _parse(self, kwargs): 32 | """ 33 | 根据字典kwargs 更新 config参数 34 | """ 35 | for k, v in kwargs.items(): 36 | if not hasattr(self, k): 37 | warnings.warn("Warning: opt has not attribut %s" % k) 38 | setattr(self, k, v) 39 | 40 | self.device =t.device('cuda') if self.use_gpu else t.device('cpu') 41 | 42 | 43 | print('user config:') 44 | for k, v in self.__class__.__dict__.items(): 45 | if not k.startswith('_'): 46 | print(k, getattr(self, k)) 47 | 48 | opt = DefaultConfig() 49 | -------------------------------------------------------------------------------- /findplate/utils/visualize.py: -------------------------------------------------------------------------------- 1 | 2 | import visdom 3 | import time 4 | import numpy as np 5 | 6 | 7 | class Visualizer(object): 8 | """ 9 | 封装了visdom的基本操作,但是你仍然可以通过`self.vis.function` 10 | 调用原生的visdom接口 11 | """ 12 | 13 | def __init__(self, env='default', **kwargs): 14 | self.vis = visdom.Visdom(env=env,use_incoming_socket=False, **kwargs) 15 | 16 | # 画的第几个数,相当于横座标 17 | # 保存(’loss',23) 即loss的第23个点 18 | self.index = {} 19 | self.log_text = '' 20 | 21 | def reinit(self, env='default', **kwargs): 22 | """ 23 | 修改visdom的配置 24 | """ 25 | self.vis = visdom.Visdom(env=env, **kwargs) 26 | return self 27 | 28 | def plot_many(self, d): 29 | """ 30 | 一次plot多个 31 | @params d: dict (name,value) i.e. ('loss',0.11) 32 | """ 33 | for k, v in d.items(): 34 | self.plot(k, v) 35 | 36 | def img_many(self, d): 37 | for k, v in d.items(): 38 | self.img(k, v) 39 | 40 | def plot(self, name, y, **kwargs): 41 | """ 42 | self.plot('loss',1.00) 43 | """ 44 | x = self.index.get(name, 0) 45 | self.vis.line(Y=np.array([y]), X=np.array([x]), 46 | win=name, 47 | opts=dict(title=name), 48 | update=None if x == 0 else 'append', 49 | **kwargs 50 | ) 51 | self.index[name] = x + 1 52 | 53 | def img(self, name, img_, **kwargs): 54 | """ 55 | self.img('input_img',t.Tensor(64,64)) 56 | self.img('input_imgs',t.Tensor(3,64,64)) 57 | self.img('input_imgs',t.Tensor(100,1,64,64)) 58 | self.img('input_imgs',t.Tensor(100,3,64,64),nrows=10) 59 | 60 | !!!don‘t ~~self.img('input_imgs',t.Tensor(100,64,64),nrows=10)~~!!! 61 | """ 62 | self.vis.images(img_.cpu().numpy(), 63 | win=name, 64 | opts=dict(title=name), 65 | **kwargs 66 | ) 67 | 68 | def log(self, info, win='log_text'): 69 | """ 70 | self.log({'loss':1,'lr':0.0001}) 71 | """ 72 | 73 | self.log_text += ('[{time}] {info}
'.format( 74 | time=time.strftime('%m%d_%H%M%S'), 75 | info=info)) 76 | self.vis.text(self.log_text, win) 77 | 78 | def __getattr__(self, name): 79 | return getattr(self.vis, name) 80 | -------------------------------------------------------------------------------- /findplate/testnetwork.py: -------------------------------------------------------------------------------- 1 | 2 | from findplate.config import opt 3 | import os 4 | import sys 5 | import torch as t 6 | from findplate import models 7 | from findplate.data.dataset import MyDataset 8 | from torch.utils.data import DataLoader 9 | from torchnet import meter 10 | from findplate.utils.visualize import Visualizer 11 | from tqdm import tqdm 12 | from torchvision import transforms as T 13 | 14 | def resource_path(relative_path): 15 | try: 16 | base_path = sys._MEIPASS 17 | except Exception: 18 | base_path = os.path.abspath(".") 19 | 20 | return os.path.join(base_path, relative_path) 21 | 22 | # 判断是否为车牌 23 | @t.no_grad() 24 | def detect(img): 25 | # 载入模型和参数 26 | model = getattr(models, opt.model)().eval() 27 | model.load(resource_path('findplate/checkpoints/squeezenet_plate.pth')) 28 | # 归一化 29 | normalize = T.Normalize(mean=[0.485, 0.456, 0.406], 30 | std=[0.229, 0.224, 0.225]) 31 | # 变换 32 | transforms = T.Compose([ 33 | T.Resize(224), 34 | T.CenterCrop(224), 35 | T.ToTensor(), 36 | normalize 37 | ]) 38 | inputdata = transforms(img) 39 | inputdata = t.unsqueeze(inputdata, 0) 40 | # 将图像喂入模型,获取标签 41 | score = model(inputdata) 42 | id_label_dict = row_csv2dict(resource_path('findplate/plate.csv')) 43 | label = score.max(dim = 1)[1].detach().tolist() 44 | label = [id_label_dict[str(i)] for i in label] 45 | return label 46 | 47 | # 识别字符 48 | @t.no_grad() 49 | def identify(img_array): 50 | model = getattr(models, 'SqueezeNetGray')().eval() 51 | model.load(resource_path('findplate/checkpoints/squeezenet_char.pth')) 52 | normalize = T.Normalize(mean=[0.485, 0.456, 0.406], 53 | std=[0.229, 0.224, 0.225]) 54 | transforms = T.Compose([ 55 | T.Resize(224), 56 | T.CenterCrop(224), 57 | T.ToTensor(), 58 | normalize 59 | ]) 60 | 61 | # 将多个图像合并到一个tensor中 62 | flag = 0 63 | for img in img_array: 64 | data = transforms(img) 65 | data = t.unsqueeze(data, 0) 66 | if flag == 0: 67 | inputdata = data 68 | flag = 1 69 | else: 70 | inputdata = t.cat((inputdata, data), 0) 71 | 72 | 73 | score = model(inputdata) 74 | id_label_dict = row_csv2dict(resource_path('findplate/char.csv')) 75 | label = score.max(dim = 1)[1].detach().tolist() 76 | label = [id_label_dict[str(i)] for i in label] 77 | return label 78 | 79 | 80 | 81 | 82 | def row_csv2dict(csv_file): 83 | import csv 84 | dict_club={} 85 | with open(csv_file)as f: 86 | reader=csv.reader(f,delimiter=',') 87 | for row in reader: 88 | dict_club[row[0]]=row[1] 89 | return dict_club 90 | 91 | -------------------------------------------------------------------------------- /findplate/data/dataset.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | from PIL import Image 4 | from torch.utils import data 5 | import numpy as np 6 | from torchvision import transforms as T 7 | from torchvision.datasets import ImageFolder 8 | import random 9 | from findplate.config import opt 10 | 11 | 12 | class MyDataset(data.Dataset): 13 | 14 | def __init__(self, root, transforms=None, train=True, test=False): 15 | """ 16 | 主要目标: 获取所有图片的地址,并根据训练,验证,测试划分数据 17 | """ 18 | self.test = test 19 | 20 | if self.test: 21 | imgs = [os.path.join(root, img) for img in os.listdir(root)] 22 | else: 23 | dataset = ImageFolder(root) 24 | self.data_classes = dataset.classes 25 | imgs = [dataset.imgs[i][0] for i in range(len(dataset.imgs))] 26 | labels = [dataset.imgs[i][1] for i in range(len(dataset.imgs))] 27 | imgs_num = len(imgs) 28 | 29 | if self.test: 30 | self.imgs = imgs 31 | 32 | # 按7:3的比例划分训练集和验证集 33 | elif train: 34 | self.imgs = [] 35 | self.labels = [] 36 | for i in range(imgs_num): 37 | if random.random()<0.7: 38 | self.imgs.append(imgs[i]) 39 | self.labels.append(labels[i]) 40 | else: 41 | self.imgs = [] 42 | self.labels = [] 43 | for i in range(imgs_num): 44 | if random.random()>0.7: 45 | self.imgs.append(imgs[i]) 46 | self.labels.append(labels[i]) 47 | if transforms is None: 48 | normalize = T.Normalize(mean=[0.485, 0.456, 0.406], 49 | std=[0.229, 0.224, 0.225]) 50 | if self.test or not train: 51 | self.transforms = T.Compose([ 52 | T.Resize(224), 53 | T.CenterCrop(224), 54 | T.ToTensor(), 55 | normalize 56 | ]) 57 | else: 58 | self.transforms = T.Compose([ 59 | T.Resize(256), 60 | T.RandomResizedCrop(224), 61 | T.RandomHorizontalFlip(), 62 | T.ToTensor(), 63 | normalize 64 | ]) 65 | 66 | def id_to_class(self, index): 67 | return self.data_classes(index) 68 | 69 | def __getitem__(self, index): 70 | """ 71 | 一次返回一张图片的数据 72 | """ 73 | img_path = self.imgs[index] 74 | if self.test: 75 | # label = self.imgs[index].split('.')[-2].split('/')[-1] 76 | label = img_path.split('/')[-1] 77 | else: 78 | label = self.labels[index] 79 | data = Image.open(img_path) 80 | if opt.gray == True: 81 | dataRGB = data.convert('RGB') 82 | dataRGB = self.transforms(dataRGB) 83 | return dataRGB, label 84 | 85 | data = self.transforms(data) 86 | return data, label 87 | 88 | def __len__(self): 89 | return len(self.imgs) 90 | -------------------------------------------------------------------------------- /network.py: -------------------------------------------------------------------------------- 1 | 2 | from findplate.config import opt 3 | import os 4 | import torch as t 5 | from findplate import models 6 | from findplate.data.dataset import MyDataset 7 | from torch.utils.data import DataLoader 8 | from torchnet import meter 9 | from findplate.utils.visualize import Visualizer 10 | from tqdm import tqdm 11 | from torchvision import transforms as T 12 | 13 | 14 | 15 | 16 | 17 | def write_csv(results,file_name,col1_name,col2_name): 18 | import csv 19 | with open(file_name,'w',newline='') as f: 20 | writer = csv.writer(f) 21 | writer.writerow([col1_name,col2_name]) 22 | writer.writerows(results) 23 | 24 | 25 | def train(**kwargs): 26 | opt._parse(kwargs) 27 | vis = Visualizer(opt.env,port = opt.vis_port) 28 | 29 | # step1: configure model 30 | model = getattr(models, opt.model)() 31 | if opt.load_model_path: 32 | model.load(opt.load_model_path) 33 | model.to(opt.device) 34 | 35 | # step2: data 36 | train_data = MyDataset(opt.train_data_root,train=True) 37 | val_data = MyDataset(opt.train_data_root,train=False) 38 | train_dataloader = DataLoader(train_data,opt.batch_size, 39 | shuffle=True,num_workers=opt.num_workers) 40 | val_dataloader = DataLoader(val_data,opt.batch_size, 41 | shuffle=False,num_workers=opt.num_workers) 42 | # write id and classes into csv file 43 | data_id_to_class = [] 44 | label_idx = 0 45 | for label_name in train_data.data_classes: 46 | data_id_to_class.append([label_idx, label_name]) 47 | label_idx += 1 48 | print(data_id_to_class) 49 | id_file_name = opt.id_file 50 | write_csv(data_id_to_class,id_file_name,'label_idx','label_name') 51 | 52 | # step3: criterion and optimizer 53 | criterion = t.nn.CrossEntropyLoss() 54 | lr = opt.lr 55 | optimizer = model.get_optimizer(lr, opt.weight_decay) 56 | 57 | # step4: meters 58 | loss_meter = meter.AverageValueMeter() 59 | confusion_matrix = meter.ConfusionMeter(opt.classifier_num) 60 | previous_loss = 1e10 61 | 62 | # train 63 | for epoch in range(opt.max_epoch): 64 | 65 | loss_meter.reset() 66 | confusion_matrix.reset() 67 | 68 | for ii,(data,label) in tqdm(enumerate(train_dataloader)): 69 | 70 | # train model 71 | input = data.to(opt.device) 72 | target = label.to(opt.device) 73 | 74 | 75 | optimizer.zero_grad() 76 | score = model(input) 77 | loss = criterion(score,target) 78 | loss.backward() 79 | optimizer.step() 80 | 81 | 82 | # meters update and visualize 83 | loss_meter.add(loss.item()) 84 | # detach 一下更安全保险 85 | confusion_matrix.add(score.detach(), target.detach()) 86 | 87 | if (ii + 1)%opt.print_freq == 0: 88 | vis.plot('loss', loss_meter.value()[0]) 89 | 90 | # 进入debug模式 91 | if os.path.exists(opt.debug_file): 92 | import ipdb; 93 | ipdb.set_trace() 94 | 95 | 96 | model.save() 97 | 98 | # validate and visualize 99 | val_cm,val_accuracy = val(model,val_dataloader) 100 | 101 | vis.plot('val_accuracy',val_accuracy) 102 | vis.log("epoch:{epoch},lr:{lr},loss:{loss},train_cm:{train_cm},val_cm:{val_cm}".format( 103 | epoch = epoch,loss = loss_meter.value()[0],val_cm = str(val_cm.value()),train_cm=str(confusion_matrix.value()),lr=lr)) 104 | 105 | # update learning rate 106 | if loss_meter.value()[0] > previous_loss: 107 | lr = lr * opt.lr_decay 108 | # 第二种降低学习率的方法:不会有moment等信息的丢失 109 | for param_group in optimizer.param_groups: 110 | param_group['lr'] = lr 111 | 112 | 113 | previous_loss = loss_meter.value()[0] 114 | 115 | @t.no_grad() 116 | def val(model,dataloader): 117 | """ 118 | 计算模型在验证集上的准确率等信息 119 | """ 120 | model.eval() 121 | confusion_matrix = meter.ConfusionMeter(opt.classifier_num) 122 | for ii, (val_input, label) in tqdm(enumerate(dataloader)): 123 | val_input = val_input.to(opt.device) 124 | score = model(val_input) 125 | confusion_matrix.add(score.detach().squeeze(), label.type(t.LongTensor)) 126 | 127 | model.train() 128 | cm_value = confusion_matrix.value() 129 | cm_value_sum = 0 130 | for i in range(opt.classifier_num): 131 | cm_value_sum += cm_value[i][i] 132 | accuracy = 100. * (cm_value_sum) / (cm_value.sum()) 133 | return confusion_matrix, accuracy 134 | 135 | def help(): 136 | """ 137 | 打印帮助的信息: python file.py help 138 | """ 139 | 140 | print(""" 141 | usage : python file.py [--args=value] 142 | := train | test | help 143 | example: 144 | python {0} train --env='env0701' --lr=0.01 145 | python {0} test --dataset='path/to/dataset/root/' 146 | python {0} help 147 | avaiable args:""".format(__file__)) 148 | 149 | from inspect import getsource 150 | source = (getsource(opt.__class__)) 151 | print(source) 152 | 153 | if __name__=='__main__': 154 | import fire 155 | fire.Fire() 156 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | from PIL import Image 4 | from findplate.testnetwork import detect 5 | from findplate.testnetwork import identify 6 | 7 | 8 | # 图像预处理 9 | def preprocess(img): 10 | # 将图片转换为HSV颜色空间 11 | hsv_img = cv2.cvtColor(img,cv2.COLOR_BGR2HSV) 12 | # 车牌照为蓝色,设置蓝色的hsv阈值,提取出图片中的蓝色区域 13 | h, s, v = hsv_img[:, :, 0], hsv_img[:, :, 1], hsv_img[:, :, 2] 14 | plate_color_img = (((h > 100) & (h < 124))) & (s > 120) & (v > 60) 15 | # 将图片数据格式转为8UC1的二值图 16 | plate_color_img = plate_color_img.astype('uint8') * 255 17 | # 对图片进行膨胀处理,使车牌成为一个整体 18 | element = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5)) 19 | plate_color_img = cv2.dilate(plate_color_img, element, iterations = 1) 20 | return plate_color_img 21 | 22 | # 找到车牌位置 23 | def findPlate(plate_color_img, im): 24 | # 在膨胀后的二值图像中寻找所有的轮廓,并存入数组 25 | contours, hierarchy = cv2.findContours(plate_color_img,cv2.RETR_EXTERNAL,cv2.CHAIN_APPROX_SIMPLE) 26 | regions = [] 27 | # 遍历轮廓 28 | for contour in contours: 29 | area = cv2.contourArea(contour) 30 | # 去除面积很小的轮廓 31 | if (area < (1/500 * plate_color_img.shape[0] * plate_color_img.shape[1]) ): 32 | continue 33 | 34 | # 获取轮廓的最小外接矩形 35 | rect = cv2.minAreaRect(contour) 36 | rect_point = cv2.boxPoints(rect) 37 | rect_point = np.int0(rect_point) 38 | 39 | # 将矩形顶点重新排序,左上角开始顺时针排序 40 | k = 0 41 | min_point = rect_point[0][0] + rect_point[0][1] 42 | for i in range(len(rect_point)): 43 | if (rect_point[i][0] + rect_point[i][1] < min_point): 44 | min_point = rect_point[i][0] + rect_point[i][1] 45 | k = i 46 | 47 | new_rect = [rect_point[k], rect_point[(k+1)%4], rect_point[(k+2)%4], rect_point[(k+3)%4]] 48 | 49 | # 通过仿射变换对车牌图片进行校正,存入新图像 50 | plate_img = np.zeros((140,440,3), np.uint8) 51 | pts1 = np.float32(new_rect) 52 | pts2 = np.float32([[0,0],[440,0],[440,140],[0,140]]) 53 | matrix = cv2.getPerspectiveTransform(pts1, pts2) 54 | plate_img = cv2.warpPerspective(im, matrix, (440,140)) 55 | 56 | # 将图像转为PIL图像,喂入神经网络检测该区域是否为车牌 57 | detect_image = Image.fromarray(cv2.cvtColor(plate_img,cv2.COLOR_BGR2RGB)) 58 | result = detect(detect_image) 59 | print(result) 60 | if (result[0] == 'has'): 61 | return rect_point, plate_img 62 | 63 | return rect_point, plate_img 64 | 65 | # 拆分字符 66 | def getChar(plate_binary): 67 | plate_height, plate_width = plate_binary.shape[:2] 68 | 69 | # 将二值图像中像素投影到y轴计数 70 | y_white_pixels = [0 for x in range(plate_height)] 71 | for i in range(plate_height): 72 | for j in range(plate_width): 73 | if (plate_binary[i,j] == 255): 74 | y_white_pixels[i] += 1 75 | 76 | # 通过占行像素的比例去除边框和杂质 77 | if (y_white_pixels[i] < 0.1*plate_width or y_white_pixels[i] > 0.8*plate_width ): 78 | y_white_pixels[i] = 0 79 | 80 | 81 | # 选取最长的投影作为字符位置 82 | flag = 0 83 | index = 0 84 | y_lenth = 0 85 | y_white_list = [] 86 | for i in range(plate_height): 87 | if y_white_pixels[i] != 0: 88 | if flag == 0: 89 | index = i 90 | flag = 1 91 | y_lenth += 1 92 | elif flag == 1: 93 | flag = 0 94 | y_white_list.append([index, y_lenth]) 95 | y_lenth = 0 96 | y_white_list.sort(key=lambda x:x[1], reverse=True) 97 | y_top = y_white_list[0][0] 98 | y_bottom = y_top + y_white_list[0][1] - 1 99 | y_crop_img = plate_binary[y_top:y_bottom, :] 100 | cv2.imshow('yci',y_crop_img) 101 | # cv2.waitKey() 102 | 103 | # 将像素对x轴投影,选取最长的7个投影 104 | x_white_pixels = [0 for x in range(plate_width)] 105 | for i in range(plate_width): 106 | for j in range(y_crop_img.shape[0]): 107 | if (y_crop_img[j,i] == 255): 108 | x_white_pixels[i] += 1 109 | 110 | flag = 0 111 | index = 0 112 | x_lenth = 0 113 | x_white_list = [] 114 | for i in range(plate_width): 115 | if x_white_pixels[i] >= 6: 116 | if flag == 0: 117 | index = i 118 | flag = 1 119 | x_lenth += 1 120 | # 添加图像边缘的投影 121 | if i == plate_width - 1: 122 | x_white_list.append([index, x_lenth]) 123 | elif flag == 1: 124 | flag = 0 125 | x_white_list.append([index, x_lenth]) 126 | x_lenth = 0 127 | print(x_white_list) 128 | 129 | # 去除中间的点 130 | for x in x_white_list: 131 | flag = 0 132 | if x[1] < 20: 133 | for i in range(x[1]): 134 | if x_white_pixels[x[0]+i] > 0.5 * y_crop_img.shape[0]: 135 | flag = 1 136 | break 137 | if flag == 0: 138 | x[1] = 0 139 | print(x_white_list) 140 | 141 | # 最左边是省份代号,长度必定大于30,但“川”字需要特殊处理 142 | flag = 0 143 | for i in range(len(x_white_list)): 144 | x = x_white_list 145 | if x[i][1] < 30: 146 | if flag == 0: 147 | if x[i+1][1] < 30 and x[i+2][1] < 30 and x[i+2][0]+x[i+2][1]-x[i][0] < 55: 148 | x_white_list[i][1] = x[i+2][0]+x[i+2][1]-x[i][0] 149 | x_white_list[i+1][1] = 0 150 | x_white_list[i+2][1] = 0 151 | flag = 1 152 | else: 153 | x_white_list[i][1] = 0 154 | else: 155 | x_white_list[i][1] = 0 156 | else: 157 | break 158 | 159 | x_white_list.sort(key=lambda x:x[1], reverse=True) 160 | x_char_list = x_white_list[:7] 161 | x_char_list.sort() 162 | print(x_char_list) 163 | 164 | 165 | # 将每个字符存入单独的图像中 166 | img_array = [] 167 | for x_char in x_char_list: 168 | img_array.append(y_crop_img[:,x_char[0]:x_char[0]+x_char[1]]) 169 | for i in range(len(img_array)): 170 | cv2.imshow(str(i), img_array[i]) 171 | 172 | 173 | new_img_array = [makeImgSquare(x) for x in img_array] 174 | 175 | 176 | 177 | 178 | pil_array = [Image.fromarray(cv2.cvtColor(x,cv2.COLOR_GRAY2RGB)) for x in new_img_array] 179 | result = ''.join(identify(pil_array)) 180 | return result 181 | 182 | def makeImgSquare(img): 183 | height, width = img.shape[:2] 184 | square_length = height 185 | new_img = np.zeros((square_length, square_length, 1), np.uint8) 186 | for i in range(square_length): 187 | for j in range(width): 188 | col = j + int((square_length-width) / 2) 189 | new_img[i,col] = img[i,j] 190 | new_img = cv2.resize(new_img, (20,20), interpolation=cv2.INTER_LINEAR) 191 | return new_img 192 | 193 | def recognition(path): 194 | im = cv2.imread(path) 195 | # im = cv2.imread('./imgs/pictures/42.jpg') 196 | height, width = im.shape[:2] 197 | plate_color_img = preprocess(im) 198 | # cv2.imshow('pci',plate_color_img) 199 | rect, plate = findPlate(plate_color_img, im) 200 | cv2.drawContours(im,[rect],-1,(0,255,0),3) 201 | cv2.imshow('im',im) 202 | cv2.imshow('plate', plate) 203 | 204 | 205 | 206 | plate_binary = cv2.cvtColor(plate,cv2.COLOR_BGR2GRAY) 207 | ret, plate_binary = cv2.threshold(plate_binary, 0, 255, cv2.THRESH_OTSU) 208 | cv2.imshow('binary', plate_binary) 209 | 210 | result = getChar(plate_binary) 211 | print(result) 212 | 213 | # cv2.waitKey() 214 | return result 215 | 216 | if __name__ == "__main__": 217 | path = input('Please input path:') 218 | recognition(path) --------------------------------------------------------------------------------