├── findplate
├── data
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-36.pyc
│ │ └── dataset.cpython-36.pyc
│ └── dataset.py
├── utils
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-36.pyc
│ │ └── visualize.cpython-36.pyc
│ └── visualize.py
├── plate.csv
├── char.csv
├── models
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-36.pyc
│ │ ├── alexnet.cpython-36.pyc
│ │ ├── resnet34.cpython-36.pyc
│ │ ├── resnet50.cpython-36.pyc
│ │ ├── squeezenet.cpython-36.pyc
│ │ ├── basic_module.cpython-36.pyc
│ │ └── squeezenet_gray.cpython-36.pyc
│ ├── squeezenet.py
│ ├── squeezenet_gray.py
│ └── basic_module.py
├── checkpoints
│ ├── squeezenet_char.pth
│ └── squeezenet_plate.pth
│   ├── config.py
│   └── testnetwork.py
├── 演示动画.gif
├── README.md
├── requirements.txt
├── gui.py
├── network.py
└── main.py
/findplate/data/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/findplate/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .visualize import Visualizer
--------------------------------------------------------------------------------
/findplate/plate.csv:
--------------------------------------------------------------------------------
1 | label_idx,label_name
2 | 0,has
3 | 1,no
4 |
--------------------------------------------------------------------------------
/演示动画.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CescMessi/carplaterecognize/HEAD/演示动画.gif
--------------------------------------------------------------------------------
/findplate/char.csv:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CescMessi/carplaterecognize/HEAD/findplate/char.csv
--------------------------------------------------------------------------------
/findplate/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .squeezenet import SqueezeNet
2 | from .squeezenet_gray import SqueezeNetGray
--------------------------------------------------------------------------------
/findplate/checkpoints/squeezenet_char.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CescMessi/carplaterecognize/HEAD/findplate/checkpoints/squeezenet_char.pth
--------------------------------------------------------------------------------
/findplate/checkpoints/squeezenet_plate.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CescMessi/carplaterecognize/HEAD/findplate/checkpoints/squeezenet_plate.pth
--------------------------------------------------------------------------------
/findplate/data/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CescMessi/carplaterecognize/HEAD/findplate/data/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/findplate/data/__pycache__/dataset.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CescMessi/carplaterecognize/HEAD/findplate/data/__pycache__/dataset.cpython-36.pyc
--------------------------------------------------------------------------------
/findplate/models/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CescMessi/carplaterecognize/HEAD/findplate/models/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/findplate/models/__pycache__/alexnet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CescMessi/carplaterecognize/HEAD/findplate/models/__pycache__/alexnet.cpython-36.pyc
--------------------------------------------------------------------------------
/findplate/models/__pycache__/resnet34.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CescMessi/carplaterecognize/HEAD/findplate/models/__pycache__/resnet34.cpython-36.pyc
--------------------------------------------------------------------------------
/findplate/models/__pycache__/resnet50.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CescMessi/carplaterecognize/HEAD/findplate/models/__pycache__/resnet50.cpython-36.pyc
--------------------------------------------------------------------------------
/findplate/utils/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CescMessi/carplaterecognize/HEAD/findplate/utils/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/findplate/utils/__pycache__/visualize.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CescMessi/carplaterecognize/HEAD/findplate/utils/__pycache__/visualize.cpython-36.pyc
--------------------------------------------------------------------------------
/findplate/models/__pycache__/squeezenet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CescMessi/carplaterecognize/HEAD/findplate/models/__pycache__/squeezenet.cpython-36.pyc
--------------------------------------------------------------------------------
/findplate/models/__pycache__/basic_module.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CescMessi/carplaterecognize/HEAD/findplate/models/__pycache__/basic_module.cpython-36.pyc
--------------------------------------------------------------------------------
/findplate/models/__pycache__/squeezenet_gray.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CescMessi/carplaterecognize/HEAD/findplate/models/__pycache__/squeezenet_gray.cpython-36.pyc
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # carplaterecognize
2 | 使用pytorch和opencv的简易识别车牌程序
3 |
4 | 使用opencv找到类似车牌的物体,接着使用模型判断是否为车牌。若为车牌,将车牌图像拉伸至标准形状,对字符进行分割,每个字符单独使用模型进行识别。
5 |
6 | 模型训练代码已经包含,使用的是简单的squeezenet,可以自行修改。
7 |
8 | 
9 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy>=1.16.2
2 | tqdm>=4.31.1
3 | torchvision>=0.2.2
4 | torchnet>=0.0.4
5 | visdom>=0.1.8.8
6 | fire>=0.1.3
7 | opencv_python>=4.1.0.25
8 | ipdb>=0.12
9 | torch>=1.0.1
10 | Pillow>=6.0.0
11 |
--------------------------------------------------------------------------------
/gui.py:
--------------------------------------------------------------------------------
1 | from tkinter import *
2 | from tkinter.filedialog import askopenfilename
3 | from main import recognition
4 |
def select():
    """Open a file dialog, run plate recognition on the chosen image,
    and display the recognized number in the entry widget."""
    chosen_path = askopenfilename(title=u'选择文件', filetypes=[(".JPG", ".jpg")])
    number.set(recognition(chosen_path))
8 |
9 |
# Build the minimal Tk window: a label + entry that shows the recognized
# plate number, and two buttons (open an image / quit).
root = Tk()
root.title('车牌检测')
number = StringVar()  # holds the recognized plate string shown in the Entry
Label(root, text= '车牌号').grid(row = 0, column = 0)
Entry(root, textvariable = number).grid(row = 0, column = 1)
# "打开图片" triggers select() which runs the whole recognition pipeline.
Button(root, text = '打开图片', command = select).grid(row = 1, column = 0)
Button(root, text = '退出', command = root.quit).grid(row = 1, column = 3)
root.mainloop()
18 |
--------------------------------------------------------------------------------
/findplate/models/squeezenet.py:
--------------------------------------------------------------------------------
1 | from torchvision.models import squeezenet1_1
2 | from findplate.models.basic_module import BasicModule
3 | from torch import nn
4 | from torch.optim import Adam
5 | from findplate.config import opt
6 |
class SqueezeNet(BasicModule):
    """SqueezeNet 1.1 with the classifier head replaced to emit
    ``num_classes`` outputs (default 2: plate / not-plate)."""

    def __init__(self, num_classes=2):
        super(SqueezeNet, self).__init__()
        self.model_name = 'squeezenet'
        self.model = squeezenet1_1(pretrained=False)
        # The stock head targets 1000 ImageNet classes; swap in our own.
        self.model.num_classes = num_classes
        head = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Conv2d(512, num_classes, 1),
            nn.ReLU(inplace=True),
            nn.AvgPool2d(13, stride=1),
        )
        self.model.classifier = head

    def forward(self, x):
        return self.model(x)

    def get_optimizer(self, lr, weight_decay):
        """Optimize only the replaced classifier head; the feature
        extractor is left as initialized."""
        classifier_params = self.model.classifier.parameters()
        return Adam(classifier_params, lr, weight_decay=weight_decay)
--------------------------------------------------------------------------------
/findplate/models/squeezenet_gray.py:
--------------------------------------------------------------------------------
1 | from torchvision.models import squeezenet1_1
2 | from findplate.models.basic_module import BasicModule
3 | from torch import nn
4 | from torch.optim import Adam
5 | from findplate.config import opt
6 |
class SqueezeNetGray(BasicModule):
    """SqueezeNet 1.1 variant used for character recognition, with a
    ``num_classes``-way classifier head (default 65 plate characters)."""

    def __init__(self, num_classes=65):
        super(SqueezeNetGray, self).__init__()
        self.model_name = 'squeezenet_gray'
        self.model = squeezenet1_1(pretrained=False)
        # Replace the stock 1000-class ImageNet head with our own.
        self.model.num_classes = num_classes
        head = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Conv2d(512, num_classes, 1),
            nn.ReLU(inplace=True),
            nn.AvgPool2d(13, stride=1),
        )
        self.model.classifier = head

    def forward(self, x):
        return self.model(x)

    def get_optimizer(self, lr, weight_decay):
        """Optimize only the replaced classifier head; the feature
        extractor is left as initialized."""
        classifier_params = self.model.classifier.parameters()
        return Adam(classifier_params, lr, weight_decay=weight_decay)
--------------------------------------------------------------------------------
/findplate/models/basic_module.py:
--------------------------------------------------------------------------------
1 |
2 | import torch as t
3 | import time
4 |
5 |
class BasicModule(t.nn.Module):
    """Thin wrapper around ``nn.Module`` that adds ``save``/``load``
    checkpoint helpers and a default optimizer factory."""

    def __init__(self):
        super(BasicModule, self).__init__()
        # Default name used in checkpoint file names; subclasses override it.
        self.model_name = str(type(self))

    def load(self, path):
        """Load a state dict from ``path``, always mapped onto the CPU."""
        state = t.load(path, map_location='cpu')
        self.load_state_dict(state)

    def save(self, name=None):
        """Persist the state dict; the default file name is
        "model name + timestamp" under ./findplate/checkpoints/."""
        if name is None:
            pattern = './findplate/checkpoints/' + self.model_name + '_' + '%m%d_%H%M%S.pth'
            name = time.strftime(pattern)
        t.save(self.state_dict(), name)
        return name

    def get_optimizer(self, lr, weight_decay):
        """Default optimizer: Adam over all parameters of the module."""
        return t.optim.Adam(self.parameters(), lr=lr, weight_decay=weight_decay)
33 |
34 |
class Flat(t.nn.Module):
    """Reshape the input to ``(batch_size, -1)``."""

    def __init__(self):
        super(Flat, self).__init__()

    def forward(self, x):
        batch = x.size(0)
        return x.view(batch, -1)
46 |
--------------------------------------------------------------------------------
/findplate/config.py:
--------------------------------------------------------------------------------
1 |
2 | import warnings
3 | import torch as t
4 |
class DefaultConfig(object):
    """Global configuration object; attributes are overridden at runtime
    through ``opt._parse(kwargs)`` (driven by command-line args)."""

    env = 'default'            # visdom environment name
    vis_port = 8097            # visdom port
    model = 'SqueezeNet'       # model to use; must match a name in models/__init__.py
    classifier_num = 2         # number of output classes of the classifier
    gray = False               # whether images are read as grayscale

    train_data_root = './imgs/images/cnn_plate_train/'  # training-set location
    test_data_root = './data/test/'                     # test-set location
    load_model_path = None     # path of a pretrained model; None = train from scratch

    batch_size = 16            # batch size
    use_gpu = True             # use GPU or not
    num_workers = 0            # how many workers for loading data
    print_freq = 20            # print info every N batches

    debug_file = '/tmp/debug'  # if os.path.exists(debug_file): enter ipdb
    result_file = 'result.csv'
    id_file = './findplate/plate.csv'  # csv mapping label index -> label name

    max_epoch = 100
    lr = 0.001                 # initial learning rate
    lr_decay = 0.5             # when val_loss increases: lr = lr * lr_decay
    weight_decay = 0e-5        # L2 regularization strength for the optimizer

    def _parse(self, kwargs):
        """Update config attributes from ``kwargs`` and select the device.

        Unknown keys trigger a warning but are still set, preserving the
        original permissive behaviour.
        """
        for k, v in kwargs.items():
            if not hasattr(self, k):
                # Fixed message: was "opt has not attribut %s".
                warnings.warn("Warning: opt has no attribute %s" % k)
            setattr(self, k, v)

        self.device = t.device('cuda') if self.use_gpu else t.device('cpu')

        # Echo the effective configuration for reproducibility.
        print('user config:')
        for k, v in self.__class__.__dict__.items():
            if not k.startswith('_'):
                print(k, getattr(self, k))
47 |
48 | opt = DefaultConfig()
49 |
--------------------------------------------------------------------------------
/findplate/utils/visualize.py:
--------------------------------------------------------------------------------
1 |
2 | import visdom
3 | import time
4 | import numpy as np
5 |
6 |
class Visualizer(object):
    """
    Wraps the basic visdom operations; the raw visdom interface is still
    reachable through ``self.vis.function`` or plain attribute access
    (see ``__getattr__``).
    """

    def __init__(self, env='default', **kwargs):
        self.vis = visdom.Visdom(env=env, use_incoming_socket=False, **kwargs)

        # Per-plot x coordinate: ('loss', 23) means the 23rd loss point.
        self.index = {}
        self.log_text = ''

    def reinit(self, env='default', **kwargs):
        """Reconfigure the underlying visdom connection."""
        self.vis = visdom.Visdom(env=env, **kwargs)
        return self

    def plot_many(self, d):
        """
        Plot several named values at once.
        @params d: dict (name, value) i.e. ('loss', 0.11)
        """
        for k, v in d.items():
            self.plot(k, v)

    def img_many(self, d):
        """Show several named images at once (same contract as plot_many)."""
        for k, v in d.items():
            self.img(k, v)

    def plot(self, name, y, **kwargs):
        """
        Append one point to the named line plot, e.g. self.plot('loss', 1.00).
        """
        x = self.index.get(name, 0)
        self.vis.line(Y=np.array([y]), X=np.array([x]),
                      win=name,
                      opts=dict(title=name),
                      update=None if x == 0 else 'append',
                      **kwargs
                      )
        self.index[name] = x + 1

    def img(self, name, img_, **kwargs):
        """
        self.img('input_img',t.Tensor(64,64))
        self.img('input_imgs',t.Tensor(3,64,64))
        self.img('input_imgs',t.Tensor(100,1,64,64))
        self.img('input_imgs',t.Tensor(100,3,64,64),nrows=10)

        !!!don't ~~self.img('input_imgs',t.Tensor(100,64,64),nrows=10)~~!!!
        """
        self.vis.images(img_.cpu().numpy(),
                        win=name,
                        opts=dict(title=name),
                        **kwargs
                        )

    def log(self, info, win='log_text'):
        """
        Append a timestamped message to the log pane,
        e.g. self.log({'loss':1,'lr':0.0001}).
        """
        # NOTE(review): the original source contained a broken (unterminated)
        # string literal here; '<br>' restores the intended HTML line break
        # for visdom's text pane, which renders HTML.
        self.log_text += ('[{time}] {info} <br>'.format(
            time=time.strftime('%m%d_%H%M%S'),
            info=info))
        self.vis.text(self.log_text, win)

    def __getattr__(self, name):
        # Delegate any unknown attribute to the raw visdom client.
        return getattr(self.vis, name)
80 |
--------------------------------------------------------------------------------
/findplate/testnetwork.py:
--------------------------------------------------------------------------------
1 |
2 | from findplate.config import opt
3 | import os
4 | import sys
5 | import torch as t
6 | from findplate import models
7 | from findplate.data.dataset import MyDataset
8 | from torch.utils.data import DataLoader
9 | from torchnet import meter
10 | from findplate.utils.visualize import Visualizer
11 | from tqdm import tqdm
12 | from torchvision import transforms as T
13 |
def resource_path(relative_path):
    """Resolve ``relative_path`` against the PyInstaller bundle directory
    (``sys._MEIPASS``) when running frozen, otherwise against the current
    working directory."""
    if hasattr(sys, '_MEIPASS'):
        base_path = sys._MEIPASS
    else:
        base_path = os.path.abspath(".")
    return os.path.join(base_path, relative_path)
21 |
# Decide whether the candidate region is a licence plate.
@t.no_grad()
def detect(img):
    """Classify ``img`` with the configured plate model and return the
    list of predicted label names (e.g. ['has'] or ['no'])."""
    # Load the model and its checkpoint weights, in eval mode.
    model = getattr(models, opt.model)().eval()
    model.load(resource_path('findplate/checkpoints/squeezenet_plate.pth'))

    # Same normalization statistics the model was trained with.
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])
    preprocess_fn = T.Compose([
        T.Resize(224),
        T.CenterCrop(224),
        T.ToTensor(),
        normalize,
    ])

    # Feed the single image to the model as a batch of one.
    batch = preprocess_fn(img).unsqueeze(0)
    score = model(batch)
    id_label_dict = row_csv2dict(resource_path('findplate/plate.csv'))
    indices = score.max(dim = 1)[1].detach().tolist()
    return [id_label_dict[str(i)] for i in indices]
46 |
# Recognize the characters cut out of the plate.
@t.no_grad()
def identify(img_array):
    """Classify each character image in ``img_array`` with the grayscale
    character model and return the list of predicted label names.

    Returns an empty list for empty input (the original raised
    ``UnboundLocalError`` because ``inputdata`` was never assigned).
    """
    if not img_array:
        return []

    model = getattr(models, 'SqueezeNetGray')().eval()
    model.load(resource_path('findplate/checkpoints/squeezenet_char.pth'))

    # Same normalization statistics the model was trained with.
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])
    transforms = T.Compose([
        T.Resize(224),
        T.CenterCrop(224),
        T.ToTensor(),
        normalize,
    ])

    # Batch all character images into one (N, C, H, W) tensor; this
    # replaces the original flag-driven cat loop.
    inputdata = t.stack([transforms(img) for img in img_array], dim=0)

    score = model(inputdata)
    id_label_dict = row_csv2dict(resource_path('findplate/char.csv'))
    label = score.max(dim = 1)[1].detach().tolist()
    return [id_label_dict[str(i)] for i in label]
78 |
79 |
80 |
81 |
def row_csv2dict(csv_file):
    """Read a two-column csv and return {first_column: second_column}."""
    import csv
    mapping = {}
    with open(csv_file) as handle:
        for record in csv.reader(handle, delimiter=','):
            mapping[record[0]] = record[1]
    return mapping
90 |
91 |
--------------------------------------------------------------------------------
/findplate/data/dataset.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | from PIL import Image
4 | from torch.utils import data
5 | import numpy as np
6 | from torchvision import transforms as T
7 | from torchvision.datasets import ImageFolder
8 | import random
9 | from findplate.config import opt
10 |
11 |
class MyDataset(data.Dataset):
    """Image dataset over an ``ImageFolder`` tree (train/val) or a flat
    directory of files (test)."""

    def __init__(self, root, transforms=None, train=True, test=False):
        """
        Collect the image paths under ``root`` and split them into the
        train / validation / test subsets.
        """
        self.test = test

        if self.test:
            # Test mode: just list the files in the directory.
            imgs = [os.path.join(root, img) for img in os.listdir(root)]
        else:
            dataset = ImageFolder(root)
            self.data_classes = dataset.classes
            imgs = [sample[0] for sample in dataset.imgs]
            labels = [sample[1] for sample in dataset.imgs]
        imgs_num = len(imgs)

        if self.test:
            self.imgs = imgs
        # Split train/validation roughly 7:3.
        # NOTE(review): the split is random per instantiation, so a train
        # and a val dataset built separately may overlap — confirm intended.
        elif train:
            self.imgs = []
            self.labels = []
            for i in range(imgs_num):
                if random.random() < 0.7:
                    self.imgs.append(imgs[i])
                    self.labels.append(labels[i])
        else:
            self.imgs = []
            self.labels = []
            for i in range(imgs_num):
                if random.random() > 0.7:
                    self.imgs.append(imgs[i])
                    self.labels.append(labels[i])

        if transforms is None:
            normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225])
            if self.test or not train:
                # Deterministic transform for test / validation.
                self.transforms = T.Compose([
                    T.Resize(224),
                    T.CenterCrop(224),
                    T.ToTensor(),
                    normalize
                ])
            else:
                # Augmented transform for training.
                self.transforms = T.Compose([
                    T.Resize(256),
                    T.RandomResizedCrop(224),
                    T.RandomHorizontalFlip(),
                    T.ToTensor(),
                    normalize
                ])
        else:
            # Fixed: the original silently dropped a caller-supplied
            # transform, leaving self.transforms unset (AttributeError
            # in __getitem__).
            self.transforms = transforms

    def id_to_class(self, index):
        """Map a numeric label back to its class name.

        Fixed: ``self.data_classes`` is a list, so it must be indexed
        with ``[]`` — the original called it like a function, raising
        ``TypeError``.
        """
        return self.data_classes[index]

    def __getitem__(self, index):
        """
        Return one (image_tensor, label) pair.
        In test mode the label is the file name instead of a class index.
        """
        img_path = self.imgs[index]
        if self.test:
            label = img_path.split('/')[-1]
        else:
            label = self.labels[index]
        data = Image.open(img_path)
        if opt.gray == True:
            # Grayscale source images are expanded to 3 channels so the
            # same RGB-trained transforms/models can consume them.
            dataRGB = data.convert('RGB')
            dataRGB = self.transforms(dataRGB)
            return dataRGB, label

        data = self.transforms(data)
        return data, label

    def __len__(self):
        return len(self.imgs)
90 |
--------------------------------------------------------------------------------
/network.py:
--------------------------------------------------------------------------------
1 |
2 | from findplate.config import opt
3 | import os
4 | import torch as t
5 | from findplate import models
6 | from findplate.data.dataset import MyDataset
7 | from torch.utils.data import DataLoader
8 | from torchnet import meter
9 | from findplate.utils.visualize import Visualizer
10 | from tqdm import tqdm
11 | from torchvision import transforms as T
12 |
13 |
14 |
15 |
16 |
def write_csv(results, file_name, col1_name, col2_name):
    """Write ``results`` rows to ``file_name`` under a two-column header."""
    import csv
    with open(file_name, 'w', newline='') as out:
        csv_writer = csv.writer(out)
        csv_writer.writerow([col1_name, col2_name])
        csv_writer.writerows(results)
23 |
24 |
def train(**kwargs):
    """Train the model named by ``opt.model``.

    Command-line overrides arrive via ``kwargs`` and are merged into the
    global ``opt`` before anything else happens. Progress is plotted to
    visdom; a checkpoint is saved after every epoch.
    """
    opt._parse(kwargs)
    vis = Visualizer(opt.env,port = opt.vis_port)

    # step1: configure model
    model = getattr(models, opt.model)()
    if opt.load_model_path:
        model.load(opt.load_model_path)
    model.to(opt.device)

    # step2: data — train and val are drawn from the same root; the
    # random 7:3 split happens inside MyDataset.
    train_data = MyDataset(opt.train_data_root,train=True)
    val_data = MyDataset(opt.train_data_root,train=False)
    train_dataloader = DataLoader(train_data,opt.batch_size,
                        shuffle=True,num_workers=opt.num_workers)
    val_dataloader = DataLoader(val_data,opt.batch_size,
                        shuffle=False,num_workers=opt.num_workers)
    # write id and classes into csv file (used later at inference time)
    data_id_to_class = []
    label_idx = 0
    for label_name in train_data.data_classes:
        data_id_to_class.append([label_idx, label_name])
        label_idx += 1
    print(data_id_to_class)
    id_file_name = opt.id_file
    write_csv(data_id_to_class,id_file_name,'label_idx','label_name')

    # step3: criterion and optimizer
    criterion = t.nn.CrossEntropyLoss()
    lr = opt.lr
    optimizer = model.get_optimizer(lr, opt.weight_decay)

    # step4: meters
    loss_meter = meter.AverageValueMeter()
    confusion_matrix = meter.ConfusionMeter(opt.classifier_num)
    previous_loss = 1e10

    # train
    for epoch in range(opt.max_epoch):

        loss_meter.reset()
        confusion_matrix.reset()

        for ii,(data,label) in tqdm(enumerate(train_dataloader)):

            # train model
            input = data.to(opt.device)
            target = label.to(opt.device)

            optimizer.zero_grad()
            score = model(input)
            loss = criterion(score,target)
            loss.backward()
            optimizer.step()

            # meters update and visualize
            loss_meter.add(loss.item())
            # detach before feeding the meter, to be safe
            confusion_matrix.add(score.detach(), target.detach())

            if (ii + 1)%opt.print_freq == 0:
                vis.plot('loss', loss_meter.value()[0])

                # enter debug mode when the debug file exists
                if os.path.exists(opt.debug_file):
                    import ipdb;
                    ipdb.set_trace()

        model.save()

        # validate and visualize
        val_cm,val_accuracy = val(model,val_dataloader)

        vis.plot('val_accuracy',val_accuracy)
        vis.log("epoch:{epoch},lr:{lr},loss:{loss},train_cm:{train_cm},val_cm:{val_cm}".format(
                    epoch = epoch,loss = loss_meter.value()[0],val_cm = str(val_cm.value()),train_cm=str(confusion_matrix.value()),lr=lr))

        # update learning rate when the epoch loss stopped improving;
        # mutating param_groups directly keeps optimizer momentum state
        if loss_meter.value()[0] > previous_loss:
            lr = lr * opt.lr_decay
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

        previous_loss = loss_meter.value()[0]
114 |
@t.no_grad()
def val(model,dataloader):
    """
    Compute the confusion matrix and accuracy (in percent) of ``model``
    on the validation set; restores train mode before returning.
    """
    model.eval()
    confusion_matrix = meter.ConfusionMeter(opt.classifier_num)
    for ii, (val_input, label) in tqdm(enumerate(dataloader)):
        val_input = val_input.to(opt.device)
        score = model(val_input)
        confusion_matrix.add(score.detach().squeeze(), label.type(t.LongTensor))

    model.train()
    cm_value = confusion_matrix.value()
    # accuracy = 100 * (sum of the diagonal) / (all matrix entries)
    cm_value_sum = 0
    for i in range(opt.classifier_num):
        cm_value_sum += cm_value[i][i]
    accuracy = 100. * (cm_value_sum) / (cm_value.sum())
    return confusion_matrix, accuracy
134 |
def help():
    """
    Print usage information: python file.py help
    (exposed as a fire CLI command; intentionally shadows the builtin
    inside this module only)
    """

    print("""
    usage : python file.py [--args=value]
    := train | test | help
    example:
            python {0} train --env='env0701' --lr=0.01
            python {0} test --dataset='path/to/dataset/root/'
            python {0} help
    avaiable args:""".format(__file__))

    # Dump the DefaultConfig source so users can see every tunable arg.
    from inspect import getsource
    source = (getsource(opt.__class__))
    print(source)
152 |
if __name__=='__main__':
    # Expose train / val / help as CLI subcommands via python-fire.
    import fire
    fire.Fire()
156 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | from PIL import Image
4 | from findplate.testnetwork import detect
5 | from findplate.testnetwork import identify
6 |
7 |
# Image preprocessing
def preprocess(img):
    """Return a dilated binary mask of the blue (plate-colored) pixels of
    a BGR image."""
    # Work in HSV, where the plate's blue is a narrow hue band.
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    hue = hsv[:, :, 0]
    sat = hsv[:, :, 1]
    val = hsv[:, :, 2]
    # Chinese plates are blue: hue 100-124 with enough saturation/value.
    mask = (hue > 100) & (hue < 124) & (sat > 120) & (val > 60)
    # To an 8UC1 binary image.
    mask = mask.astype('uint8') * 255
    # Dilate so the plate region becomes one connected blob.
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
    return cv2.dilate(mask, kernel, iterations=1)
21 |
# Locate the licence plate position
def findPlate(plate_color_img, im):
    """Search the binary mask for plate-like contours in ``im``.

    Returns (rect_point, plate_img) for the first candidate the network
    classifies as a plate. NOTE(review): if no candidate is accepted the
    LAST examined contour is returned anyway, and if no contour passes
    the area filter at all ``rect_point`` is unbound and this raises
    ``UnboundLocalError`` — confirm this fall-through is intended.
    """
    # Find all external contours of the dilated binary image.
    contours, hierarchy = cv2.findContours(plate_color_img,cv2.RETR_EXTERNAL,cv2.CHAIN_APPROX_SIMPLE)
    regions = []
    # Walk over the candidate contours.
    for contour in contours:
        area = cv2.contourArea(contour)
        # Skip contours whose area is below 1/500 of the whole image.
        if (area < (1/500 * plate_color_img.shape[0] * plate_color_img.shape[1]) ):
            continue

        # Minimum-area bounding rectangle of the contour.
        rect = cv2.minAreaRect(contour)
        rect_point = cv2.boxPoints(rect)
        rect_point = np.int0(rect_point)

        # Reorder the corners clockwise starting from the top-left one
        # (the corner with the smallest x+y sum).
        k = 0
        min_point = rect_point[0][0] + rect_point[0][1]
        for i in range(len(rect_point)):
            if (rect_point[i][0] + rect_point[i][1] < min_point):
                min_point = rect_point[i][0] + rect_point[i][1]
                k = i

        new_rect = [rect_point[k], rect_point[(k+1)%4], rect_point[(k+2)%4], rect_point[(k+3)%4]]

        # Rectify the candidate with a perspective transform into a
        # 440x140 image (standard Chinese plate proportions).
        plate_img = np.zeros((140,440,3), np.uint8)
        pts1 = np.float32(new_rect)
        pts2 = np.float32([[0,0],[440,0],[440,140],[0,140]])
        matrix = cv2.getPerspectiveTransform(pts1, pts2)
        plate_img = cv2.warpPerspective(im, matrix, (440,140))

        # Convert to a PIL image and let the network decide whether this
        # region really is a plate.
        detect_image = Image.fromarray(cv2.cvtColor(plate_img,cv2.COLOR_BGR2RGB))
        result = detect(detect_image)
        print(result)
        if (result[0] == 'has'):
            return rect_point, plate_img

    return rect_point, plate_img
64 |
# Split the plate image into single characters
def getChar(plate_binary):
    """Cut the binarized plate image into (up to) 7 character images and
    run them through the character classifier; returns the plate string.

    Works by projecting white pixels onto the y axis to find the character
    band, then onto the x axis to find the individual character columns.
    Opens cv2 debug windows as a side effect.
    """
    plate_height, plate_width = plate_binary.shape[:2]

    # Project the white pixels of the binary image onto the y axis.
    y_white_pixels = [0 for x in range(plate_height)]
    for i in range(plate_height):
        for j in range(plate_width):
            if (plate_binary[i,j] == 255):
                y_white_pixels[i] += 1

        # Zero rows whose white-pixel ratio marks them as border or noise.
        if (y_white_pixels[i] < 0.1*plate_width or y_white_pixels[i] > 0.8*plate_width ):
            y_white_pixels[i] = 0


    # Take the longest run of non-empty rows as the character band.
    flag = 0
    index = 0
    y_lenth = 0
    y_white_list = []
    for i in range(plate_height):
        if y_white_pixels[i] != 0:
            if flag == 0:
                index = i
                flag = 1
            y_lenth += 1
        elif flag == 1:
            flag = 0
            y_white_list.append([index, y_lenth])
            y_lenth = 0
    y_white_list.sort(key=lambda x:x[1], reverse=True)
    y_top = y_white_list[0][0]
    y_bottom = y_top + y_white_list[0][1] - 1
    y_crop_img = plate_binary[y_top:y_bottom, :]
    cv2.imshow('yci',y_crop_img)
    # cv2.waitKey()

    # Project onto the x axis; the 7 longest runs will be kept below.
    x_white_pixels = [0 for x in range(plate_width)]
    for i in range(plate_width):
        for j in range(y_crop_img.shape[0]):
            if (y_crop_img[j,i] == 255):
                x_white_pixels[i] += 1

    flag = 0
    index = 0
    x_lenth = 0
    x_white_list = []
    for i in range(plate_width):
        if x_white_pixels[i] >= 6:
            if flag == 0:
                index = i
                flag = 1
            x_lenth += 1
            # Also close a run that touches the right image edge.
            if i == plate_width - 1:
                x_white_list.append([index, x_lenth])
        elif flag == 1:
            flag = 0
            x_white_list.append([index, x_lenth])
            x_lenth = 0
    print(x_white_list)

    # Drop the separator dot in the middle of the plate: a narrow run
    # containing no tall column is treated as the dot and zeroed out.
    for x in x_white_list:
        flag = 0
        if x[1] < 20:
            for i in range(x[1]):
                if x_white_pixels[x[0]+i] > 0.5 * y_crop_img.shape[0]:
                    flag = 1
                    break
            if flag == 0:
                x[1] = 0
    print(x_white_list)

    # The leftmost glyph is the province character, normally wider than
    # 30 px — but "川" appears as three narrow strokes and is merged here.
    flag = 0
    for i in range(len(x_white_list)):
        x = x_white_list
        if x[i][1] < 30:
            if flag == 0:
                if x[i+1][1] < 30 and x[i+2][1] < 30 and x[i+2][0]+x[i+2][1]-x[i][0] < 55:
                    x_white_list[i][1] = x[i+2][0]+x[i+2][1]-x[i][0]
                    x_white_list[i+1][1] = 0
                    x_white_list[i+2][1] = 0
                    flag = 1
                else:
                    x_white_list[i][1] = 0
            else:
                x_white_list[i][1] = 0
        else:
            break

    # Keep the 7 longest runs and restore left-to-right order.
    x_white_list.sort(key=lambda x:x[1], reverse=True)
    x_char_list = x_white_list[:7]
    x_char_list.sort()
    print(x_char_list)


    # Cut each character into its own image.
    img_array = []
    for x_char in x_char_list:
        img_array.append(y_crop_img[:,x_char[0]:x_char[0]+x_char[1]])
    for i in range(len(img_array)):
        cv2.imshow(str(i), img_array[i])


    # Square-pad every character so the classifier sees a fixed shape.
    new_img_array = [makeImgSquare(x) for x in img_array]


    # Convert to RGB PIL images, classify, and join into the final string.
    pil_array = [Image.fromarray(cv2.cvtColor(x,cv2.COLOR_GRAY2RGB)) for x in new_img_array]
    result = ''.join(identify(pil_array))
    return result
181 |
def makeImgSquare(img):
    """Center a single-channel character image on a square canvas whose
    side equals the image height, then resize it to 20x20 for the
    classifier.

    The original copied pixel by pixel and, when ``width > height``,
    produced negative column indices that wrapped around to the wrong
    edge; this version center-crops instead and uses slicing.
    """
    height, width = img.shape[:2]
    side = height
    new_img = np.zeros((side, side, 1), np.uint8)
    if width <= side:
        # Narrow glyph: pad equally on both sides.
        offset = (side - width) // 2
        new_img[:, offset:offset + width, 0] = img[:side, :width]
    else:
        # Wide glyph: keep the horizontal center of the image.
        crop_start = (width - side) // 2
        new_img[:, :, 0] = img[:side, crop_start:crop_start + side]
    return cv2.resize(new_img, (20, 20), interpolation=cv2.INTER_LINEAR)
192 |
def recognition(path):
    """End-to-end pipeline: load the image at ``path``, locate the plate,
    binarize it, split the characters and return the recognized string.

    Opens several cv2 debug windows as a side effect.
    """
    im = cv2.imread(path)
    # im = cv2.imread('./imgs/pictures/42.jpg')
    height, width = im.shape[:2]
    plate_color_img = preprocess(im)
    # cv2.imshow('pci',plate_color_img)
    rect, plate = findPlate(plate_color_img, im)
    # Draw the detected plate outline on the original image for debugging.
    cv2.drawContours(im,[rect],-1,(0,255,0),3)
    cv2.imshow('im',im)
    cv2.imshow('plate', plate)



    # Otsu threshold on the grayscale plate to get a binary image.
    plate_binary = cv2.cvtColor(plate,cv2.COLOR_BGR2GRAY)
    ret, plate_binary = cv2.threshold(plate_binary, 0, 255, cv2.THRESH_OTSU)
    cv2.imshow('binary', plate_binary)

    result = getChar(plate_binary)
    print(result)

    # cv2.waitKey()
    return result
215 |
216 | if __name__ == "__main__":
217 | path = input('Please input path:')
218 | recognition(path)
--------------------------------------------------------------------------------