├── db.sqlite3
├── crnn
│   ├── __init__.py
│   ├── .crnn_keras.py.swp
│   ├── config.py
│   ├── utils.py
│   ├── crnn_torch.py
│   ├── crnn_torch_chinese.py
│   ├── util.py
│   ├── network_torch.py
│   ├── dataset.py
│   └── keys.py
├── itypes
│   ├── __init__.py
│   ├── migrations
│   │   └── __init__.py
│   ├── admin.py
│   ├── apps.py
│   ├── urls.py
│   └── views.py
├── utils
│   ├── __init__.py
│   ├── aws
│   │   ├── __init__.py
│   │   ├── mime.sh
│   │   ├── resume.py
│   │   └── userdata.sh
│   ├── wandb_logging
│   │   ├── __init__.py
│   │   ├── log_dataset.py
│   │   └── wandb_utils.py
│   ├── google_app_engine
│   │   ├── additional_requirements.txt
│   │   ├── app.yaml
│   │   └── Dockerfile
│   ├── activations.py
│   ├── google_utils.py
│   ├── autoanchor.py
│   ├── metrics.py
│   ├── loss.py
│   ├── torch_utils.py
│   └── plots.py
├── invoice_ocr
│   ├── __init__.py
│   ├── wsgi.py
│   ├── urls.py
│   └── settings.py
├── images
│   └── a9f89e262f4d10220724004d99f3fc8.png
├── util
│   ├── MyException.py
│   ├── response.py
│   ├── create_img.py
│   ├── ofd_util.py
│   ├── qrcode.py
│   ├── pdf_read.py
│   └── slide.py
├── manage.py
├── obj_det
│   ├── ocr_context.py
│   ├── objd_util.py
│   ├── roll_detect.py
│   ├── evat_detect.py
│   ├── no_tax_detect.py
│   ├── vat_detect.py
│   ├── taxi_detect.py
│   ├── detect.py
│   ├── tra_detect.py
│   └── title_detect.py
├── single_ocr
│   └── opencv_direction.py
├── manage.spec
├── README.md
└── requirements.txt
/db.sqlite3:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/crnn/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/itypes/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/invoice_ocr/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/utils/aws/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/itypes/migrations/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/utils/wandb_logging/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/itypes/admin.py:
--------------------------------------------------------------------------------
1 | from django.contrib import admin
2 | 
3 | # Register your models here.
4 | -------------------------------------------------------------------------------- /crnn/.crnn_keras.py.swp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/384863451/invoice_ocr/HEAD/crnn/.crnn_keras.py.swp -------------------------------------------------------------------------------- /itypes/apps.py: -------------------------------------------------------------------------------- 1 | from django.apps import AppConfig 2 | 3 | 4 | class ItypesConfig(AppConfig): 5 | name = 'itypes' 6 | -------------------------------------------------------------------------------- /images/a9f89e262f4d10220724004d99f3fc8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/384863451/invoice_ocr/HEAD/images/a9f89e262f4d10220724004d99f3fc8.png -------------------------------------------------------------------------------- /util/MyException.py: -------------------------------------------------------------------------------- 1 | class MyException(Exception): # 继承异常类 2 | def __init__(self, code, msg): 3 | self.code = code 4 | self.msg = msg -------------------------------------------------------------------------------- /utils/google_app_engine/additional_requirements.txt: -------------------------------------------------------------------------------- 1 | # add these requirements in your app on top of the existing ones 2 | pip==18.1 3 | Flask==1.0.2 4 | gunicorn==19.9.0 5 | -------------------------------------------------------------------------------- /crnn/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | pwd = os.getcwd() 4 | GPU = False 5 | ocrModel_chinese = os.path.join(pwd,"models", "ocr-lstm.pth") 6 | ocrModel = os.path.join(pwd,"models", "ocr-english.pth") 7 | -------------------------------------------------------------------------------- /utils/google_app_engine/app.yaml: -------------------------------------------------------------------------------- 1 | runtime: custom 2 | env: flex 3 | 4 | service: yolov5app 5 | 6 | liveness_check: 7 | initial_delay_sec: 600 8 | 9 | manual_scaling: 10 | instances: 1 11 | resources: 12 | cpu: 1 13 | memory_gb: 4 14 | disk_size_gb: 20 -------------------------------------------------------------------------------- /itypes/urls.py: -------------------------------------------------------------------------------- 1 | from django.urls import path, include 2 | 3 | import itypes.views 4 | 5 | urlpatterns = [ 6 | path('demo', itypes.views.demo), 7 | path('detection', itypes.views.detection), 8 | path('detection_images', itypes.views.detection_images), 9 | path('batch_img', itypes.views.batch_img), 10 | ] -------------------------------------------------------------------------------- /util/response.py: -------------------------------------------------------------------------------- 1 | from django.http import JsonResponse 2 | 3 | 4 | class response: # 继承异常类 5 | def __init__(self, code, msg): 6 | self.code = code 7 | self.msg = msg 8 | 9 | def result(self): 10 | responseData = {'code': self.code, 'msg': self.msg} 11 | return JsonResponse(responseData, safe=False, json_dumps_params={'ensure_ascii': False}) -------------------------------------------------------------------------------- /util/create_img.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import random 3 | import os 4 | 5 | use = 0 6 | 7 | def no_tax(label, 
img, res): 8 | if res is None: 9 | res = "" 10 | if ":" in res: 11 | res = "" 12 | dir = "F:\\crnn_demo" 13 | ran = round(random.uniform(0, 10000), 2) 14 | if not os.path.exists(dir + "\\" + label): 15 | os.makedirs(dir + "\\" + label) 16 | cv2.imencode('.png', img)[1].tofile(dir + "\\" + label + "\\" + res + "_" + str(ran) + ".png") -------------------------------------------------------------------------------- /invoice_ocr/wsgi.py: -------------------------------------------------------------------------------- 1 | """ 2 | WSGI config for invoice_ocr project. 3 | 4 | It exposes the WSGI callable as a module-level variable named ``application``. 5 | 6 | For more information on this file, see 7 | https://docs.djangoproject.com/en/2.0/howto/deployment/wsgi/ 8 | """ 9 | 10 | import os 11 | 12 | from django.core.wsgi import get_wsgi_application 13 | 14 | os.environ.setdefault("DJANGO_SETTINGS_MODULE", "invoice_ocr.settings") 15 | 16 | application = get_wsgi_application() 17 | -------------------------------------------------------------------------------- /manage.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | import sys 4 | import matplotlib 5 | import seaborn 6 | import utils.autoanchor 7 | import utils 8 | import utils.torch_utils 9 | import util.qrcode 10 | import tqdm 11 | import matplotlib 12 | import models.yolo 13 | 14 | if __name__ == "__main__": 15 | os.environ.setdefault("DJANGO_SETTINGS_MODULE", "invoice_ocr.settings") 16 | try: 17 | from django.core.management import execute_from_command_line 18 | except ImportError as exc: 19 | raise ImportError( 20 | "Couldn't import Django. Are you sure it's installed and " 21 | "available on your PYTHONPATH environment variable? Did you " 22 | "forget to activate a virtual environment?" 23 | ) from exc 24 | execute_from_command_line(sys.argv) 25 | -------------------------------------------------------------------------------- /utils/aws/mime.sh: -------------------------------------------------------------------------------- 1 | # AWS EC2 instance startup 'MIME' script https://aws.amazon.com/premiumsupport/knowledge-center/execute-user-data-ec2/ 2 | # This script will run on every instance restart, not only on first start 3 | # --- DO NOT COPY ABOVE COMMENTS WHEN PASTING INTO USERDATA --- 4 | 5 | Content-Type: multipart/mixed; boundary="//" 6 | MIME-Version: 1.0 7 | 8 | --// 9 | Content-Type: text/cloud-config; charset="us-ascii" 10 | MIME-Version: 1.0 11 | Content-Transfer-Encoding: 7bit 12 | Content-Disposition: attachment; filename="cloud-config.txt" 13 | 14 | #cloud-config 15 | cloud_final_modules: 16 | - [scripts-user, always] 17 | 18 | --// 19 | Content-Type: text/x-shellscript; charset="us-ascii" 20 | MIME-Version: 1.0 21 | Content-Transfer-Encoding: 7bit 22 | Content-Disposition: attachment; filename="userdata.txt" 23 | 24 | #!/bin/bash 25 | # --- paste contents of userdata.sh here --- 26 | --// 27 | -------------------------------------------------------------------------------- /invoice_ocr/urls.py: -------------------------------------------------------------------------------- 1 | """invoice_ocr URL Configuration 2 | 3 | The `urlpatterns` list routes URLs to views. For more information please see: 4 | https://docs.djangoproject.com/en/2.0/topics/http/urls/ 5 | Examples: 6 | Function views 7 | 1. Add an import: from my_app import views 8 | 2. Add a URL to urlpatterns: path('', views.home, name='home') 9 | Class-based views 10 | 1. 
Add an import: from other_app.views import Home 11 | 2. Add a URL to urlpatterns: path('', Home.as_view(), name='home') 12 | Including another URLconf 13 | 1. Import the include() function: from django.urls import include, path 14 | 2. Add a URL to urlpatterns: path('blog/', include('blog.urls')) 15 | """ 16 | from django.contrib import admin 17 | from django.urls import path, include 18 | 19 | urlpatterns = [ 20 | path('admin/', admin.site.urls), 21 | path('type/', include('itypes.urls')) 22 | ] 23 | -------------------------------------------------------------------------------- /utils/wandb_logging/log_dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import yaml 4 | 5 | from wandb_utils import WandbLogger 6 | 7 | WANDB_ARTIFACT_PREFIX = 'wandb-artifact://' 8 | 9 | 10 | def create_dataset_artifact(opt): 11 | with open(opt.data) as f: 12 | data = yaml.load(f, Loader=yaml.SafeLoader) # data dict 13 | logger = WandbLogger(opt, '', None, data, job_type='Dataset Creation') 14 | 15 | 16 | if __name__ == '__main__': 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--data', type=str, default='data/coco128.yaml', help='data.yaml path') 19 | parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset') 20 | parser.add_argument('--project', type=str, default='YOLOv5', help='name of W&B Project') 21 | opt = parser.parse_args() 22 | opt.resume = False # Explicitly disallow resume check for dataset upload job 23 | 24 | create_dataset_artifact(opt) 25 | -------------------------------------------------------------------------------- /utils/google_app_engine/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM gcr.io/google-appengine/python 2 | 3 | # Create a virtualenv for dependencies. This isolates these packages from 4 | # system-level packages. 5 | # Use -p python3 or -p python3.7 to select python version. Default is version 2. 6 | RUN virtualenv /env -p python3 7 | 8 | # Setting these environment variables are the same as running 9 | # source /env/bin/activate. 10 | ENV VIRTUAL_ENV /env 11 | ENV PATH /env/bin:$PATH 12 | 13 | RUN apt-get update && apt-get install -y python-opencv 14 | 15 | # Copy the application's requirements.txt and run pip to install all 16 | # dependencies into the virtualenv. 17 | ADD requirements.txt /app/requirements.txt 18 | RUN pip install -r /app/requirements.txt 19 | 20 | # Add the application source code. 21 | ADD . /app 22 | 23 | # Run a WSGI server to serve the application. gunicorn must be declared as 24 | # a dependency in requirements.txt. 
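# Note: the PORT environment variable is supplied by the App Engine flexible environment at runtime (commonly 8080), so gunicorn binds to whatever port the platform assigns.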
25 | CMD gunicorn -b :$PORT main:app 26 | -------------------------------------------------------------------------------- /obj_det/ocr_context.py: -------------------------------------------------------------------------------- 1 | from obj_det.detect import invoice_detection as det 2 | from obj_det.tra_detect import invoice_detection as tra 3 | from obj_det.vat_detect import invoice_detection as vat 4 | from obj_det.taxi_detect import invoice_detection as taxi 5 | from obj_det.roll_detect import invoice_detection as roll 6 | from obj_det.title_detect import invoice_detection as title 7 | from obj_det.evat_detect import invoice_detection as evat 8 | from obj_det.no_tax_detect import invoice_detection as noVat 9 | from single_ocr.opencv_direction import angle_detect_dnn 10 | from crnn.crnn_torch import crnnOcr as crnnOcr 11 | from crnn.crnn_torch_chinese import crnnOcr as ccrnnOcr 12 | 13 | 14 | class TextOcrModel(object): 15 | def __init__(self): 16 | self.ocrModel = crnnOcr 17 | self.chineseModel = ccrnnOcr 18 | self.textModel = None 19 | self.angleModel = angle_detect_dnn 20 | self.det = det 21 | self.tra = tra 22 | self.vat = vat 23 | self.taxi = taxi 24 | self.roll = roll 25 | self.title = title 26 | self.evat = evat 27 | self.noVat = noVat 28 | 29 | 30 | context = TextOcrModel() 31 | -------------------------------------------------------------------------------- /single_ocr/opencv_direction.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import numpy as np 4 | import threading 5 | 6 | lock = threading.Lock() 7 | 8 | pwd = os.getcwd() 9 | AngleModelPb = os.path.join(pwd, "models", "saved_model.pb") 10 | AngleModelPbtxt = os.path.join(pwd, "models", "saved_model.pbtxt") 11 | angleNet = cv2.dnn.readNetFromTensorflow(AngleModelPb,AngleModelPbtxt)##dnn 文字方向检测 12 | 13 | def angle_detect_dnn(img, adjust=True): 14 | """ 15 | 文字方向检测 16 | """ 17 | h, w = img.shape[:2] 18 | ROTATE = [0, 90, 180, 270] 19 | if adjust: 20 | thesh = 0.05 21 | xmin, ymin, xmax, ymax = int(thesh * w), int(thesh * h), w - int(thesh * w), h - int(thesh * h) 22 | img = img[ymin:ymax, xmin:xmax] ##剪切图片边缘 23 | 24 | inputBlob = cv2.dnn.blobFromImage(img, 25 | scalefactor=1.0, 26 | size=(224, 224), 27 | swapRB=True, 28 | mean=[103.939, 116.779, 123.68], crop=False) 29 | lock.acquire(timeout=30) 30 | angleNet.setInput(inputBlob) 31 | pred = angleNet.forward() 32 | lock.release() 33 | index = np.argmax(pred, axis=1)[0] 34 | return ROTATE[index] -------------------------------------------------------------------------------- /crnn/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # encoding: utf-8 3 | from PIL import Image 4 | import numpy as np 5 | class strLabelConverter(object): 6 | 7 | def __init__(self, alphabet): 8 | self.alphabet = alphabet + 'ç' # for `-1` index 9 | self.dict = {} 10 | for i, char in enumerate(alphabet): 11 | # NOTE: 0 is reserved for 'blank' required by wrap_ctc 12 | self.dict[char] = i + 1 13 | 14 | def decode(self,res): 15 | N = len(res) 16 | raw = [] 17 | for i in range(N): 18 | if res[i] != 0 and (not (i > 0 and res[i - 1] == res[i])): 19 | raw.append(self.alphabet[res[i] - 1]) 20 | return ''.join(raw) 21 | 22 | 23 | class resizeNormalize(object): 24 | 25 | def __init__(self, size, interpolation=Image.BILINEAR): 26 | self.size = size 27 | self.interpolation = interpolation 28 | 29 | def __call__(self, img): 30 | size = self.size 31 | imgW,imgH = size 32 | scale 
= img.size[1]*1.0 / imgH 33 | w = img.size[0] / scale 34 | w = int(w) 35 | img = img.resize((w,imgH),self.interpolation) 36 | w,h = img.size 37 | img = (np.array(img)/255.0-0.5)/0.5 38 | 39 | return img -------------------------------------------------------------------------------- /utils/aws/resume.py: -------------------------------------------------------------------------------- 1 | # Resume all interrupted trainings in yolov5/ dir including DDP trainings 2 | # Usage: $ python utils/aws/resume.py 3 | 4 | import os 5 | import sys 6 | from pathlib import Path 7 | 8 | import torch 9 | import yaml 10 | 11 | sys.path.append('./') # to run '$ python *.py' files in subdirectories 12 | 13 | port = 0 # --master_port 14 | path = Path('').resolve() 15 | for last in path.rglob('*/**/last.pt'): 16 | ckpt = torch.load(last) 17 | if ckpt['optimizer'] is None: 18 | continue 19 | 20 | # Load opt.yaml 21 | with open(last.parent.parent / 'opt.yaml') as f: 22 | opt = yaml.load(f, Loader=yaml.SafeLoader) 23 | 24 | # Get device count 25 | d = opt['device'].split(',') # devices 26 | nd = len(d) # number of devices 27 | ddp = nd > 1 or (nd == 0 and torch.cuda.device_count() > 1) # distributed data parallel 28 | 29 | if ddp: # multi-GPU 30 | port += 1 31 | cmd = f'python -m torch.distributed.launch --nproc_per_node {nd} --master_port {port} train.py --resume {last}' 32 | else: # single-GPU 33 | cmd = f'python train.py --resume {last}' 34 | 35 | cmd += ' > /dev/null 2>&1 &' # redirect output to dev/null and run in daemon thread 36 | print(cmd) 37 | os.system(cmd) 38 | -------------------------------------------------------------------------------- /manage.spec: -------------------------------------------------------------------------------- 1 | # -*- mode: python ; coding: utf-8 -*- 2 | 3 | 4 | block_cipher = None 5 | 6 | 7 | a = Analysis(['manage.py'], 8 | pathex=[], 9 | binaries=[], 10 | datas=[], 11 | hiddenimports=[], 12 | hookspath=[], 13 | hooksconfig={}, 14 | runtime_hooks=[], 15 | excludes=[], 16 | win_no_prefer_redirects=False, 17 | win_private_assemblies=False, 18 | cipher=block_cipher, 19 | noarchive=False) 20 | pyz = PYZ(a.pure, a.zipped_data, 21 | cipher=block_cipher) 22 | 23 | exe = EXE(pyz, 24 | a.scripts, 25 | [], 26 | exclude_binaries=True, 27 | name='manage', 28 | debug=False, 29 | bootloader_ignore_signals=False, 30 | strip=False, 31 | upx=True, 32 | console=True, 33 | disable_windowed_traceback=False, 34 | target_arch=None, 35 | codesign_identity=None, 36 | entitlements_file=None ) 37 | coll = COLLECT(exe, 38 | a.binaries, 39 | a.zipfiles, 40 | a.datas, 41 | strip=False, 42 | upx=True, 43 | upx_exclude=[], 44 | name='manage') 45 | -------------------------------------------------------------------------------- /utils/aws/userdata.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # AWS EC2 instance startup script https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/user-data.html 3 | # This script will run only once on first instance start (for a re-start script see mime.sh) 4 | # /home/ubuntu (ubuntu) or /home/ec2-user (amazon-linux) is working dir 5 | # Use >300 GB SSD 6 | 7 | cd home/ubuntu 8 | if [ ! -d yolov5 ]; then 9 | echo "Running first-time script." # install dependencies, download COCO, pull Docker 10 | git clone https://github.com/ultralytics/yolov5 && sudo chmod -R 777 yolov5 11 | cd yolov5 12 | bash data/scripts/get_coco.sh && echo "Data done." & 13 | sudo docker pull ultralytics/yolov5:latest && echo "Docker done." 
&
14 |   python -m pip install --upgrade pip && pip install -r requirements.txt && python detect.py && echo "Requirements done." &
15 |   wait && echo "All tasks done." # finish background tasks
16 | else
17 |   echo "Running re-start script." # resume interrupted runs
18 |   i=0
19 |   list=$(sudo docker ps -qa) # container list i.e. $'one\ntwo\nthree\nfour'
20 |   while IFS= read -r id; do
21 |     ((i++))
22 |     echo "restarting container $i: $id"
23 |     sudo docker start $id
24 |     # sudo docker exec -it $id python train.py --resume # single-GPU
25 |     sudo docker exec -d $id python utils/aws/resume.py # multi-scenario
26 |   done <<<"$list"
27 | fi
28 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Mixed Reimbursement Invoice Recognition
2 | Supported file types: image, pdf, ofd; input may be rotated by any of the four angles 0, 90, 180 or 270 degrees.
3 | Supported invoice types: VAT special invoice, VAT general invoice, electronic VAT special invoice, electronic VAT general invoice, VAT general invoice (roll type), toll invoice, train ticket, air ticket, coach ticket, taxi receipt, fixed-amount invoice, general machine-printed invoice
4 | ## Environment
5 | 1. python3.5/3.6
6 | 2. Install the dependencies: pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
7 | 3. With a GPU, install the tensorflow-gpu version matching requirements.txt; the GPU switch is controlled in config.py
8 | ## Model architecture
9 | YOLOv5 + CRNN + CTC
10 | 
11 | ## Models
12 | 1. Model download link: https://pan.baidu.com/s/1E_OE9HOjjFh6GZdPWQVbMg  extraction code: voqi
13 | 2. Place the downloaded models folder in the project root directory
14 | ## Starting the service
15 | 1. From the console: python manage.py runserver 127.0.0.1:8080
16 | 2. The port can be changed as needed
17 | 3. Service endpoints: http://*.*.*.*:[port]/detection_images and http://127.0.0.1:8080/detection, e.g. http://127.0.0.1:8080/detection_images
18 | 
19 | ## Test demo
20 | 1. Test tool: postman, which can be downloaded and installed separately
21 | 2. Four VAT invoices photographed in a single shot
22 | 
23 | ![Image text](https://raw.githubusercontent.com/384863451/invoice_ocr/master/images/a9f89e262f4d10220724004d99f3fc8.png)
24 | 
25 | # How the code runs
26 | - Started with the django command
27 | - The input is preprocessed first; accepted parameters are an image file, a base64-encoded image, or an image download URL
28 | - Invoices are located within the image and the detections are collected into a list
29 | - The invoice type is determined and the specific regions of that invoice are detected further
30 | - The detected key regions are passed to CRNN to recognize the concrete field values
31 | - Electronic invoices get special handling and can be recognized from pdf and ofd
32 | 
33 | 
34 | ## Future plans
35 | - Only the five key fields of a VAT invoice are recognized; the plan is to combine this with invoice verification to obtain the full invoice face directly
36 | - Only a few regions of the other invoice types are recognized; to be completed when time allows
37 | - The CRNN model is the one shipped with the chineseocr project; retraining is in progress, but the workload is large and it will be updated when time allows
38 | 
39 | ## Reference
40 | chineseocr https://github.com/chineseocr/chineseocr
41 | 
42 | ## Summary
43 | A beginner project written for fun; the code is quite messy.
44 | 
--------------------------------------------------------------------------------
/util/ofd_util.py:
--------------------------------------------------------------------------------
1 | from xml.dom.minidom import parse
2 | 
3 | import shutil
4 | import zipfile
5 | import os
6 | 
7 | 
8 | def unzip_file(zip_path, unzip_path=None):
9 |     """
10 |     :param zip_path: ofd格式文件路径
11 |     :param unzip_path: 解压后的文件存放目录
12 |     :return: unzip_path
13 |     """
14 |     if not unzip_path:
15 |         unzip_path = zip_path.split('.')[0]
16 |     with zipfile.ZipFile(zip_path, 'r') as f:
17 |         for file in f.namelist():
18 |             f.extract(file, path=unzip_path)
19 | 
20 |     return unzip_path
21 | 
22 | def get_info(dir_path, unzip_file_path=None, removed=True):
23 |     """
24 |     :param dir_path: 压缩文件路径
25 |     :param unzip_file_path: 解压后的文件路径
26 |     :param removed: 是否删除解压后的目录
27 |     :return: ofd_info,字典形式的发票信息
28 |     """
29 |     if not os.path.exists(dir_path):
30 |         os.makedirs(dir_path)
31 |     file_path = unzip_file(dir_path, unzip_file_path)
32 |     io = f"{file_path}/OFD.xml"
33 |     element = parse(io).documentElement
34 |     nodes = element.getElementsByTagName('ofd:CustomDatas')
35 |     ofd_info = {}
36 |     for i in range(len(nodes)):
37 |         sun_node = nodes[i].childNodes
38 |         for j in range(len(sun_node)):
39 |             name = sun_node[j].getAttribute('Name')
40 |             value = sun_node[j].firstChild.data
41 |             ofd_info[name] = value
42 |     if
removed: 43 | shutil.rmtree(unzip_file_path) 44 | return ofd_info -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.0.0 2 | altgraph==0.17.2 3 | astunparse==1.6.3 4 | cachetools==4.2.4 5 | certifi==2020.6.20 6 | charset-normalizer==2.0.9 7 | colorama==0.4.4 8 | cycler==0.11.0 9 | dataclasses==0.8 10 | Django==2.0 11 | future==0.18.2 12 | gast==0.3.3 13 | google-auth==2.3.3 14 | google-auth-oauthlib==0.4.6 15 | google-pasta==0.2.0 16 | grpcio==1.42.0 17 | h5py==2.10.0 18 | idna==3.3 19 | image==1.5.33 20 | importlib-metadata==4.8.2 21 | jupyter-contrib-core==0.3.3 22 | jupyter-highlight-selected-word==0.2.0 23 | jupyter-latex-envs==1.4.6 24 | Keras-Preprocessing==1.1.2 25 | kiwisolver==1.3.1 26 | lmdb==1.2.1 27 | lxml==4.6.3 28 | Markdown==3.3.6 29 | matplotlib==3.3.4 30 | numpy==1.18.5 31 | oauthlib==3.1.1 32 | opencv-python==4.5.4.60 33 | opt-einsum==3.3.0 34 | pandas==1.1.5 35 | pefile==2021.9.3 36 | Pillow==8.4.0 37 | protobuf==3.19.1 38 | pyasn1==0.4.8 39 | pyasn1-modules==0.2.8 40 | pyinstaller==4.7 41 | pyinstaller-hooks-contrib==2021.4 42 | PyMuPDF==1.19.3 43 | pyparsing==3.0.6 44 | python-dateutil==2.8.2 45 | pytz==2021.3 46 | pywin32-ctypes==0.2.0 47 | PyYAML==5.4.1 48 | pyzbar==0.1.8 49 | requests==2.26.0 50 | requests-oauthlib==1.3.0 51 | rsa==4.8 52 | scipy==1.4.1 53 | seaborn==0.11.2 54 | six==1.16.0 55 | tensorboard==2.7.0 56 | tensorboard-data-server==0.6.1 57 | tensorboard-plugin-wit==1.8.0 58 | tensorflow==2.3.0 59 | tensorflow-estimator==2.3.0 60 | termcolor==1.1.0 61 | torch==1.10.0 62 | torchvision==0.11.1 63 | tqdm==4.62.3 64 | typing_extensions==4.0.1 65 | urllib3==1.26.7 66 | Werkzeug==2.0.2 67 | wincertstore==0.2 68 | wrapt==1.13.3 69 | zipp==3.6.0 70 | zxing==0.14 71 | -------------------------------------------------------------------------------- /crnn/crnn_torch.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | import torch 3 | import numpy as np 4 | from torch.autograd import Variable 5 | from crnn.utils import strLabelConverter, resizeNormalize 6 | from crnn.network_torch import CRNN 7 | from crnn import keys 8 | from collections import OrderedDict 9 | from crnn.config import ocrModel, GPU 10 | from PIL import Image 11 | import cv2 12 | import threading 13 | 14 | lock = threading.Lock() 15 | def crnnSource(): 16 | 17 | alphabet = keys.alphabetEnglish##英文模型 18 | 19 | converter = strLabelConverter(alphabet) 20 | model = CRNN(32, 1, len(alphabet)+1, 256, 1,lstmFlag=True).cpu() 21 | trainWeights = torch.load(ocrModel,map_location=lambda storage, loc: storage) 22 | modelWeights = OrderedDict() 23 | for k, v in trainWeights.items(): 24 | name = k.replace('module.','') # remove `module.` 25 | modelWeights[name] = v 26 | # load params 27 | 28 | model.load_state_dict(modelWeights) 29 | 30 | return model,converter 31 | 32 | ##加载模型 33 | model,converter = crnnSource() 34 | model.eval() 35 | def crnnOcr(image): 36 | """ 37 | crnn模型,ocr识别 38 | image:cv2 39 | """ 40 | pil_img = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) 41 | image = pil_img.convert('L') 42 | scale = image.size[1]*1.0 / 32 43 | w = image.size[0] / scale 44 | w = int(w) 45 | transformer = resizeNormalize((w, 32)) 46 | image = transformer(image) 47 | image = image.astype(np.float32) 48 | image = torch.from_numpy(image) 49 | 50 | if torch.cuda.is_available() and GPU: 51 | image = image.cuda() 52 | 
else: 53 | image = image.cpu() 54 | 55 | image = image.view(1,1, *image.size()) 56 | image = Variable(image) 57 | lock.acquire(timeout=3) 58 | preds = model(image) 59 | lock.release() 60 | _, preds = preds.max(2) 61 | preds = preds.transpose(1, 0).contiguous().view(-1) 62 | sim_pred = converter.decode(preds) 63 | return sim_pred 64 | 65 | 66 | -------------------------------------------------------------------------------- /util/qrcode.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import zxing as zxing 3 | import uuid 4 | from obj_det.detect import converter as multype 5 | 6 | 7 | def qrcode(img, invoice): 8 | try: 9 | id = uuid.uuid4() 10 | path = "images/qrcode/" + str(id) + ".jpg" 11 | h, w, _ = img.shape 12 | newimage = cv2.resize(img, (w * 2, h * 2), cv2.INTER_LINEAR) 13 | cv2.imwrite(path, newimage) 14 | reader = zxing.BarCodeReader() 15 | barcode = reader.decode(path) 16 | data = barcode.parsed 17 | datas = data.split(",") 18 | if datas[2] != '' and datas[2] != None: 19 | invoice['invoice_code'] = datas[2] 20 | if datas[3] != '' and datas[3] != None: 21 | invoice['invoice_number'] = datas[3] 22 | if datas[4] != '' and datas[4] != None: 23 | invoice['totalAmount'] = "¥" + datas[4] 24 | if datas[5] != '' and datas[5] != None: 25 | invoice['billingDate'] = datas[5][0:4] + "年" + datas[5][4:6] + "月" + datas[5][6:8] + "日" 26 | if datas[1] == "04" or datas[1] == "10" or datas[1] == "11": 27 | invoice['checkCode'] = datas[6] 28 | if invoice['invoiceType'] != "14": 29 | invoice['invoiceType'] = datas[1] 30 | invoice['invoice_type_name'] = multype[datas[1]] 31 | return True 32 | except Exception as e: 33 | print("二维码未识别" + str(e)) 34 | return False 35 | 36 | 37 | def qrcode_no_tax(img, invoice): 38 | try: 39 | img = cv2.resize(img, (300, 300)) 40 | id = uuid.uuid4() 41 | path = "images/qrcode/" + str(id) + ".jpg" 42 | cv2.imwrite(path, img) 43 | reader = zxing.BarCodeReader() 44 | barcode = reader.decode(path) 45 | data = barcode.parsed 46 | if "http" in data: 47 | paramOrg = data.split("?")[1] 48 | param_2s = paramOrg.split("&") 49 | for param_2 in param_2s: 50 | param = param_2.split("=") 51 | invoice[param[0]] = param[1] 52 | return True 53 | except Exception as e: 54 | print("二维码未识别" + str(e)) 55 | return False 56 | -------------------------------------------------------------------------------- /crnn/crnn_torch_chinese.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | import torch 3 | import numpy as np 4 | from torch.autograd import Variable 5 | from crnn.utils import strLabelConverter, resizeNormalize 6 | from crnn.network_torch import CRNN 7 | from crnn import keys 8 | from collections import OrderedDict 9 | from crnn.config import ocrModel_chinese, GPU 10 | from PIL import Image 11 | import cv2 12 | import threading 13 | 14 | lock = threading.Lock() 15 | def crnnSource(): 16 | """ 17 | 加载模型 18 | """ 19 | alphabet = keys.alphabetChinese##中英文模型 20 | 21 | converter = strLabelConverter(alphabet) 22 | model = CRNN(32, 1, len(alphabet)+1, 256, 1,lstmFlag=True).cpu() 23 | 24 | trainWeights = torch.load(ocrModel_chinese,map_location=lambda storage, loc: storage) 25 | modelWeights = OrderedDict() 26 | for k, v in trainWeights.items(): 27 | name = k.replace('module.','') # remove `module.` 28 | modelWeights[name] = v 29 | # load params 30 | 31 | model.load_state_dict(modelWeights) 32 | 33 | return model,converter 34 | 35 | ##加载模型 36 | model,converter = crnnSource() 37 | 
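# put the loaded CRNN into inference mode so its BatchNorm layers use running statistics rather than per-batch statistics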
model.eval() 38 | def crnnOcr(image): 39 | """ 40 | crnn模型,ocr识别 41 | image:PIL.Image.convert("L") 42 | """ 43 | pil_img = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) 44 | image = pil_img.convert('L') 45 | scale = image.size[1]*1.0 / 32 46 | w = image.size[0] / scale 47 | w = int(w) 48 | transformer = resizeNormalize((w, 32)) 49 | image = transformer(image) 50 | image = image.astype(np.float32) 51 | image = torch.from_numpy(image) 52 | 53 | if torch.cuda.is_available() and GPU: 54 | image = image.cuda() 55 | else: 56 | image = image.cpu() 57 | 58 | image = image.view(1,1, *image.size()) 59 | image = Variable(image) 60 | lock.acquire(timeout=3) 61 | preds = model(image) 62 | lock.release() 63 | _, preds = preds.max(2) 64 | preds = preds.transpose(1, 0).contiguous().view(-1) 65 | sim_pred = converter.decode(preds) 66 | return sim_pred 67 | 68 | 69 | -------------------------------------------------------------------------------- /utils/activations.py: -------------------------------------------------------------------------------- 1 | # Activation functions 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | # SiLU https://arxiv.org/pdf/1606.08415.pdf ---------------------------------------------------------------------------- 9 | class SiLU(nn.Module): # export-friendly version of nn.SiLU() 10 | @staticmethod 11 | def forward(x): 12 | return x * torch.sigmoid(x) 13 | 14 | 15 | class Hardswish(nn.Module): # export-friendly version of nn.Hardswish() 16 | @staticmethod 17 | def forward(x): 18 | # return x * F.hardsigmoid(x) # for torchscript and CoreML 19 | return x * F.hardtanh(x + 3, 0., 6.) / 6. # for torchscript, CoreML and ONNX 20 | 21 | 22 | class MemoryEfficientSwish(nn.Module): 23 | class F(torch.autograd.Function): 24 | @staticmethod 25 | def forward(ctx, x): 26 | ctx.save_for_backward(x) 27 | return x * torch.sigmoid(x) 28 | 29 | @staticmethod 30 | def backward(ctx, grad_output): 31 | x = ctx.saved_tensors[0] 32 | sx = torch.sigmoid(x) 33 | return grad_output * (sx * (1 + x * (1 - sx))) 34 | 35 | def forward(self, x): 36 | return self.F.apply(x) 37 | 38 | 39 | # Mish https://github.com/digantamisra98/Mish -------------------------------------------------------------------------- 40 | class Mish(nn.Module): 41 | @staticmethod 42 | def forward(x): 43 | return x * F.softplus(x).tanh() 44 | 45 | 46 | class MemoryEfficientMish(nn.Module): 47 | class F(torch.autograd.Function): 48 | @staticmethod 49 | def forward(ctx, x): 50 | ctx.save_for_backward(x) 51 | return x.mul(torch.tanh(F.softplus(x))) # x * tanh(ln(1 + exp(x))) 52 | 53 | @staticmethod 54 | def backward(ctx, grad_output): 55 | x = ctx.saved_tensors[0] 56 | sx = torch.sigmoid(x) 57 | fx = F.softplus(x).tanh() 58 | return grad_output * (fx + x * sx * (1 - fx * fx)) 59 | 60 | def forward(self, x): 61 | return self.F.apply(x) 62 | 63 | 64 | # FReLU https://arxiv.org/abs/2007.11824 ------------------------------------------------------------------------------- 65 | class FReLU(nn.Module): 66 | def __init__(self, c1, k=3): # ch_in, kernel 67 | super().__init__() 68 | self.conv = nn.Conv2d(c1, c1, k, 1, 1, groups=c1, bias=False) 69 | self.bn = nn.BatchNorm2d(c1) 70 | 71 | def forward(self, x): 72 | return torch.max(x, self.bn(self.conv(x))) 73 | -------------------------------------------------------------------------------- /crnn/util.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # encoding: utf-8 3 | 4 | import torch 
5 | import torch.nn as nn 6 | import collections 7 | 8 | 9 | class strLabelConverter(object): 10 | 11 | def __init__(self, alphabet): 12 | self.alphabet = alphabet + 'ç' # for `-1` index 13 | self.dict = {} 14 | for i, char in enumerate(alphabet): 15 | # NOTE: 0 is reserved for 'blank' required by wrap_ctc 16 | self.dict[char] = i + 1 17 | def encode(self, text, depth=0): 18 | """Support batch or single str.""" 19 | length = [] 20 | result=[] 21 | for str in text: 22 | length.append(len(str)) 23 | for char in str: 24 | #print(char) 25 | index = self.dict[char] 26 | result.append(index) 27 | text = result 28 | return (torch.IntTensor(text), torch.IntTensor(length)) 29 | 30 | def decode(self, t, length, raw=False): 31 | if length.numel() == 1: 32 | length = length[0] 33 | t = t[:length] 34 | if raw: 35 | return ''.join([self.alphabet[i - 1] for i in t]) 36 | else: 37 | char_list = [] 38 | for i in range(length): 39 | if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])): 40 | char_list.append(self.alphabet[t[i] - 1]) 41 | return ''.join(char_list) 42 | else: 43 | texts = [] 44 | index = 0 45 | for i in range(length.numel()): 46 | l = length[i] 47 | texts.append(self.decode( 48 | t[index:index + l], torch.IntTensor([l]), raw=raw)) 49 | index += l 50 | return texts 51 | 52 | 53 | class averager(object): 54 | 55 | def __init__(self): 56 | self.reset() 57 | 58 | def add(self, v): 59 | self.n_count += v.data.numel() 60 | # NOTE: not `+= v.sum()`, which will add a node in the compute graph, 61 | # which lead to memory leak 62 | self.sum += v.data.sum() 63 | 64 | def reset(self): 65 | self.n_count = 0 66 | self.sum = 0 67 | 68 | def val(self): 69 | res = 0 70 | if self.n_count != 0: 71 | res = self.sum / float(self.n_count) 72 | return res 73 | 74 | 75 | def oneHot(v, v_length, nc): 76 | batchSize = v_length.size(0) 77 | maxLength = v_length.max() 78 | v_onehot = torch.FloatTensor(batchSize, maxLength, nc).fill_(0) 79 | acc = 0 80 | for i in range(batchSize): 81 | length = v_length[i] 82 | label = v[acc:acc + length].view(-1, 1).long() 83 | v_onehot[i, :length].scatter_(1, label, 1.0) 84 | acc += length 85 | return v_onehot 86 | 87 | 88 | def loadData(v, data): 89 | v.data.resize_(data.size()).copy_(data) 90 | 91 | 92 | def prettyPrint(v): 93 | print('Size {0}, Type: {1}'.format(str(v.size()), v.data.type())) 94 | print('| Max: %f | Min: %f | Mean: %f' % (v.max().data[0], v.min().data[0], v.mean().data[0])) 95 | 96 | 97 | def assureRatio(img): 98 | """Ensure imgH <= imgW.""" 99 | b, c, h, w = img.size() 100 | if h > w: 101 | main = nn.UpsamplingBilinear2d(size=(h, h), scale_factor=None) 102 | img = main(img) 103 | return img 104 | -------------------------------------------------------------------------------- /crnn/network_torch.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | class BidirectionalLSTM(nn.Module): 3 | 4 | def __init__(self, nIn, nHidden, nOut): 5 | super(BidirectionalLSTM, self).__init__() 6 | self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True) 7 | self.embedding = nn.Linear(nHidden * 2, nOut) 8 | 9 | def forward(self, input): 10 | recurrent, _ = self.rnn(input) 11 | T, b, h = recurrent.size() 12 | t_rec = recurrent.view(T * b, h) 13 | output = self.embedding(t_rec) # [T * b, nOut] 14 | output = output.view(T, b, -1) 15 | return output 16 | 17 | 18 | 19 | class CRNN(nn.Module): 20 | 21 | def __init__(self, imgH, nc, nclass, nh, n_rnn=2, leakyRelu=False,lstmFlag=True): 22 | """ 23 | 是否加入lstm特征层 24 | """ 25 | super(CRNN, 
self).__init__() 26 | assert imgH % 16 == 0, 'imgH has to be a multiple of 16' 27 | 28 | ks = [3, 3, 3, 3, 3, 3, 2] 29 | ps = [1, 1, 1, 1, 1, 1, 0] 30 | ss = [1, 1, 1, 1, 1, 1, 1] 31 | nm = [64, 128, 256, 256, 512, 512, 512] 32 | self.lstmFlag = lstmFlag 33 | 34 | cnn = nn.Sequential() 35 | 36 | def convRelu(i, batchNormalization=False): 37 | nIn = nc if i == 0 else nm[i - 1] 38 | nOut = nm[i] 39 | cnn.add_module('conv{0}'.format(i), 40 | nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i])) 41 | if batchNormalization: 42 | cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut)) 43 | if leakyRelu: 44 | cnn.add_module('relu{0}'.format(i), 45 | nn.LeakyReLU(0.2, inplace=True)) 46 | else: 47 | cnn.add_module('relu{0}'.format(i), nn.ReLU(True)) 48 | 49 | convRelu(0) 50 | cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2)) # 64x16x64 51 | convRelu(1) 52 | cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2)) # 128x8x32 53 | convRelu(2, True) 54 | convRelu(3) 55 | cnn.add_module('pooling{0}'.format(2), 56 | nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 256x4x16 57 | convRelu(4, True) 58 | convRelu(5) 59 | cnn.add_module('pooling{0}'.format(3), 60 | nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 512x2x16 61 | convRelu(6, True) # 512x1x16 62 | 63 | self.cnn = cnn 64 | if self.lstmFlag: 65 | self.rnn = nn.Sequential( 66 | BidirectionalLSTM(512, nh, nh), 67 | BidirectionalLSTM(nh, nh, nclass)) 68 | else: 69 | self.linear = nn.Linear(nh*2, nclass) 70 | 71 | 72 | def forward(self, input): 73 | # conv features 74 | conv = self.cnn(input) 75 | b, c, h, w = conv.size() 76 | 77 | assert h == 1, "the height of conv must be 1" 78 | conv = conv.squeeze(2) 79 | conv = conv.permute(2, 0, 1) # [w, b, c] 80 | if self.lstmFlag: 81 | # rnn features 82 | output = self.rnn(conv) 83 | else: 84 | T, b, h = conv.size() 85 | 86 | t_rec = conv.contiguous().view(T * b, h) 87 | 88 | output = self.linear(t_rec) # [T * b, nOut] 89 | output = output.view(T, b, -1) 90 | 91 | 92 | return output 93 | -------------------------------------------------------------------------------- /util/pdf_read.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | import time,os.path,requests,re 3 | from pdfminer.pdfinterp import PDFResourceManager, PDFPageInterpreter 4 | from pdfminer.converter import PDFPageAggregator 5 | from pdfminer.layout import LAParams,LTTextBoxHorizontal,LTImage,LTCurve,LTFigure 6 | from pdfminer.pdfpage import PDFTextExtractionNotAllowed,PDFPage 7 | from pdfminer.pdfparser import PDFParser 8 | from pdfminer.pdfdocument import PDFDocument 9 | from docx import Document 10 | import fitz 11 | document = Document() 12 | 13 | ''' 14 | pip install pdfminer3k 15 | pip install pdfminer.six 安装这个引入的内容不会报错 16 | ''' 17 | 18 | class CPdf2TxtManager(): 19 | 20 | def changePdfToText(self, filePath): 21 | res = {"state": 0} 22 | context = "" 23 | # 以二进制读模式打开 24 | file = open(filePath, 'rb') 25 | #用文件对象来创建一个pdf文档分析器 26 | praser = PDFParser(file) 27 | # 创建一个PDF文档对象存储文档结构,提供密码初始化,没有就不用传该参数 28 | doc = PDFDocument(praser, password='') 29 | ##检查文件是否允许文本提取 30 | if not doc.is_extractable: 31 | raise PDFTextExtractionNotAllowed 32 | 33 | # 创建PDf 资源管理器 来管理共享资源,#caching = False不缓存 34 | rsrcmgr = PDFResourceManager(caching = False) 35 | # 创建一个PDF设备对象 36 | laparams = LAParams() 37 | # 创建一个PDF页面聚合对象 38 | device = PDFPageAggregator(rsrcmgr, laparams=laparams) 39 | # 创建一个PDF解析器对象 40 | interpreter = PDFPageInterpreter(rsrcmgr, device) 41 | # 获得文档的目录(纲要),文档没有纲要会报错 42 | #PDF文档没有目录时会报:raise 
PDFNoOutlines pdfminer.pdfdocument.PDFNoOutlines 43 | # print(doc.get_outlines()) 44 | 45 | # 获取page列表 46 | print(PDFPage.get_pages(doc)) 47 | # 用来计数页面,图片,曲线,figure,水平文本框等对象的数量 48 | num_page, num_image, num_curve, num_figure, num_TextBoxHorizontal = 0, 0, 0, 0, 0 49 | # 循环遍历列表,每次处理一个page的内容 50 | for page in PDFPage.create_pages(doc): 51 | num_page += 1 # 页面增一 52 | # 利用解释器的process_page()方法解析读取单独页数 53 | interpreter.process_page(page) 54 | # 接受该页面的LTPage对象 55 | layout = device.get_result() 56 | fileNames = os.path.splitext(filePath) 57 | # 这里layout是一个LTPage对象 里面存放着 这个page解析出的各种对象 58 | # 一般包括LTTextBox, LTFigure, LTImage, LTTextBoxHorizontal 等等 59 | for x in layout: 60 | if hasattr(x, "get_text") or isinstance(x, LTTextBoxHorizontal): 61 | results = x.get_text().replace(u'\xa0', u' ') 62 | if '作废' in results or '作 废' in results or "作 废" in results: 63 | res['state'] = "2" 64 | context = context + results 65 | # 如果x是水平文本对象的话 66 | if isinstance(x, LTTextBoxHorizontal): 67 | num_TextBoxHorizontal += 1 # 水平文本框对象增一 68 | if isinstance(x, LTImage): # 图片对象 69 | num_image += 1 70 | if isinstance(x, LTCurve): # 曲线对象 71 | num_curve += 1 72 | if isinstance(x, LTFigure): # figure对象 73 | num_figure += 1 74 | 75 | print('对象数量:%s,页面数:%s,图片数:%s,曲线数:%s,' 76 | '水平文本框:%s,'%(num_figure,num_page,num_image,num_curve,num_TextBoxHorizontal)) 77 | if '红' in context and '字' in context and '冲' in context and '销' in context: 78 | res['state'] = "3" 79 | return res 80 | 81 | def parse(path): 82 | pdf2TxtManager = CPdf2TxtManager() 83 | return pdf2TxtManager.changePdfToText(path) 84 | -------------------------------------------------------------------------------- /invoice_ocr/settings.py: -------------------------------------------------------------------------------- 1 | """ 2 | Django settings for invoice_ocr project. 3 | 4 | Generated by 'django-admin startproject' using Django 2.0. 5 | 6 | For more information on this file, see 7 | https://docs.djangoproject.com/en/2.0/topics/settings/ 8 | 9 | For the full list of settings and their values, see 10 | https://docs.djangoproject.com/en/2.0/ref/settings/ 11 | """ 12 | 13 | import os 14 | 15 | # Build paths inside the project like this: os.path.join(BASE_DIR, ...) 16 | BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 17 | 18 | 19 | # Quick-start development settings - unsuitable for production 20 | # See https://docs.djangoproject.com/en/2.0/howto/deployment/checklist/ 21 | 22 | # SECURITY WARNING: keep the secret key used in production secret! 23 | SECRET_KEY = '%0bi$p*zo9s@*qsn6^#tz&d4o&^&8gd6a10ah_m2*%#a95m5+^' 24 | 25 | # SECURITY WARNING: don't run with debug turned on in production! 
26 | DEBUG = True 27 | 28 | ALLOWED_HOSTS = ['*'] 29 | 30 | 31 | # Application definition 32 | 33 | INSTALLED_APPS = [ 34 | 'django.contrib.admin', 35 | 'django.contrib.auth', 36 | 'django.contrib.contenttypes', 37 | 'django.contrib.sessions', 38 | 'django.contrib.messages', 39 | 'django.contrib.staticfiles', 40 | 'itypes', 41 | ] 42 | 43 | MIDDLEWARE = [ 44 | 'django.middleware.security.SecurityMiddleware', 45 | 'django.contrib.sessions.middleware.SessionMiddleware', 46 | 'django.middleware.common.CommonMiddleware', 47 | #'django.middleware.csrf.CsrfViewMiddleware', 48 | 'django.contrib.auth.middleware.AuthenticationMiddleware', 49 | 'django.contrib.messages.middleware.MessageMiddleware', 50 | 'django.middleware.clickjacking.XFrameOptionsMiddleware', 51 | ] 52 | 53 | ROOT_URLCONF = 'invoice_ocr.urls' 54 | 55 | TEMPLATES = [ 56 | { 57 | 'BACKEND': 'django.template.backends.django.DjangoTemplates', 58 | 'DIRS': [], 59 | 'APP_DIRS': True, 60 | 'OPTIONS': { 61 | 'context_processors': [ 62 | 'django.template.context_processors.debug', 63 | 'django.template.context_processors.request', 64 | 'django.contrib.auth.context_processors.auth', 65 | 'django.contrib.messages.context_processors.messages', 66 | ], 67 | }, 68 | }, 69 | ] 70 | 71 | WSGI_APPLICATION = 'invoice_ocr.wsgi.application' 72 | 73 | 74 | # Database 75 | # https://docs.djangoproject.com/en/2.0/ref/settings/#databases 76 | 77 | DATABASES = { 78 | 'default': { 79 | 'ENGINE': 'django.db.backends.sqlite3', 80 | 'NAME': os.path.join(BASE_DIR, 'db.sqlite3'), 81 | } 82 | } 83 | 84 | 85 | # Password validation 86 | # https://docs.djangoproject.com/en/2.0/ref/settings/#auth-password-validators 87 | 88 | AUTH_PASSWORD_VALIDATORS = [ 89 | { 90 | 'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator', 91 | }, 92 | { 93 | 'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator', 94 | }, 95 | { 96 | 'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator', 97 | }, 98 | { 99 | 'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator', 100 | }, 101 | ] 102 | 103 | 104 | # Internationalization 105 | # https://docs.djangoproject.com/en/2.0/topics/i18n/ 106 | 107 | LANGUAGE_CODE = 'en-us' 108 | 109 | TIME_ZONE = 'UTC' 110 | 111 | USE_I18N = True 112 | 113 | USE_L10N = True 114 | 115 | USE_TZ = True 116 | 117 | 118 | # Static files (CSS, JavaScript, Images) 119 | # https://docs.djangoproject.com/en/2.0/howto/static-files/ 120 | 121 | STATIC_URL = '/static/' 122 | -------------------------------------------------------------------------------- /itypes/views.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | 3 | import cv2 4 | import requests 5 | 6 | from obj_det.objd_util import detection as det 7 | from util.response import response as re 8 | import os, time, base64 9 | from crnn.crnn_torch import crnnOcr as crnnOcr 10 | from crnn.crnn_torch_chinese import crnnOcr as ccrnnOcr 11 | import json 12 | import uuid 13 | 14 | allowed_extension = ['jpg', 'png', 'JPG', 'pdf', 'ofd'] 15 | 16 | 17 | # 检查文件扩展名 18 | def allowed_file(filename): 19 | return '.' 
in filename and filename.rsplit('.', 1)[1].lower() in allowed_extension 20 | 21 | 22 | def demo(request): 23 | aaa = "F:/result/split_end/trian_split/03p048145.jpg" 24 | img = cv2.imread(aaa) 25 | res = crnnOcr(img) 26 | res2 = ccrnnOcr(img) 27 | return re(0, res).result() 28 | 29 | 30 | a = "20280624" 31 | pub_access_token = "43ac8eaadesc3ab489g8717320drd65" 32 | 33 | 34 | def detection(request): 35 | currentDate = time.strftime('%Y%m%d', time.localtime(time.time())) 36 | if currentDate > a: 37 | return re(1, "请求参数不正确").result() 38 | post_param = request.POST 39 | get_param = request.GET 40 | if 'file' in request.FILES: 41 | return detection_images(request) 42 | if post_param.__contains__('urls'): 43 | return detection_url(request) 44 | if not post_param.__contains__('file') or not post_param.__contains__('name') or not get_param.__contains__( 45 | 'access_token'): 46 | return re(1, "请求参数不正确").result() 47 | file = post_param['file'] 48 | name = post_param['name'] 49 | access_token = get_param['access_token'] 50 | # if pub_access_token != access_token: 51 | # return re(1, "请求参数不正确").result() 52 | file = base64.b64decode(file) 53 | invoice_file_name = name 54 | if not allowed_file(invoice_file_name): 55 | return re(102, "失败,文件格式问题").result() 56 | with open(os.path.join("images", invoice_file_name), 'wb+') as destination: 57 | destination.write(file) 58 | destination.close() 59 | list_invoice = det(invoice_file_name) 60 | return re(0, list_invoice).result() 61 | 62 | 63 | def detection_url(request): 64 | post_param = request.POST 65 | get_param = request.GET 66 | urls = post_param['urls'] 67 | access_token = get_param.get('access_token') 68 | # if pub_access_token != access_token: 69 | # return re(1, "请求参数不正确").result() 70 | 71 | all_invoice = [] 72 | for url in urls.split(","): 73 | name = str(uuid.uuid4()) + ".pdf" 74 | invoice_file_name = name 75 | r = requests.get(url) 76 | with open(os.path.join("images", invoice_file_name), 'wb+') as destination: 77 | destination.write(r.content) 78 | destination.close() 79 | list_invoice = det(invoice_file_name) 80 | for invoice in list_invoice: 81 | invoice['pdfurl'] = url 82 | all_invoice.extend(list_invoice) 83 | return re(0, all_invoice).result() 84 | 85 | 86 | def detection_images(request): 87 | currentDate = time.strftime('%Y%m%d', time.localtime(time.time())) 88 | if currentDate > a: 89 | return re(1, "请求参数不正确").result() 90 | # 校验请求参数 91 | if 'file' not in request.FILES: 92 | return re(1, "请求参数不正确").result() 93 | file = request.FILES['file'] 94 | invoice_file_name = file.name 95 | if not allowed_file(invoice_file_name): 96 | return re(102, "失败,文件格式问题").result() 97 | destination = open(os.path.join("images", invoice_file_name), 'wb+') 98 | for chunk in file.chunks(): # 分块写入文件 99 | destination.write(chunk) 100 | destination.close() 101 | list_invoice = det(invoice_file_name) 102 | return re(0, list_invoice).result() 103 | 104 | 105 | def batch_img(request): 106 | file_dir = "F:\\aaa\\images" 107 | for root, dirs, files in os.walk(file_dir): 108 | for file in files: 109 | file_path = root + "\\" + file 110 | shutil.copy(file_path, "D:\\pycharm\\invoice_ocr\\images\\" + file) 111 | det(file) 112 | return re(0, "a").result() 113 | -------------------------------------------------------------------------------- /util/slide.py: -------------------------------------------------------------------------------- 1 | import zxing as zxing 2 | from selenium import webdriver 3 | from selenium.webdriver.common.by import By 4 | from selenium.webdriver.support 
import expected_conditions as EC 5 | from selenium.webdriver.support.wait import WebDriverWait 6 | from PIL import Image 7 | from six import BytesIO 8 | import time 9 | from selenium.webdriver import ActionChains 10 | 11 | 12 | def get_url(url): 13 | browser = webdriver.Chrome("D:\\pycharm\\invoice_ocr\\driver\\chromedriver.exe") 14 | browser.get(url) 15 | browser.maximize_window() 16 | time.sleep(5) 17 | #aaa = browser.find_element_by_xpath("//*[@class='login-by-item-app hiddenAppScan']").click() 18 | browser.implicitly_wait(10) 19 | wait = WebDriverWait(browser,10) 20 | wait.until(EC.presence_of_element_located((By.CLASS_NAME, 'geetest_radar_btn'))) 21 | btn = browser.find_element_by_css_selector('.geetest_radar_btn') 22 | btn.click() 23 | time.sleep(0.5) 24 | return browser 25 | 26 | def get_position(img_label): 27 | location = img_label.location 28 | size = img_label.size 29 | top, bottom, left, right = location['y'], location['y'] + size['height'], location['x'], location['x'] + size[ 30 | 'width'] 31 | return (left, top, right, bottom) 32 | 33 | def get_screenshot(browser): 34 | screenshot = browser.get_screenshot_as_png() 35 | f = BytesIO() 36 | f.write(screenshot) 37 | return Image.open(f) 38 | 39 | def get_position_scale(browser,screen_shot): 40 | height = browser.execute_script('return document.documentElement.clientHeight') 41 | width = browser.execute_script('return document.documentElement.clientWidth') 42 | x_scale = screen_shot.size[0] / (width+10) 43 | y_scale = screen_shot.size[1] / (height) 44 | return (x_scale,y_scale) 45 | 46 | def get_slideimg_screenshot(screenshot,position,scale): 47 | x_scale,y_scale = scale 48 | position = [position[0] * x_scale, position[1] * y_scale, position[2] * x_scale, position[3] * y_scale] 49 | return screenshot.crop(position) 50 | 51 | def compare_pixel(img1,img2,x,y): 52 | pixel1 = img1.load()[x,y] 53 | pixel2 = img2.load()[x,y] 54 | threshold = 50 55 | if abs(pixel1[0]-pixel2[0])<=threshold: 56 | if abs(pixel1[1]-pixel2[1])<=threshold: 57 | if abs(pixel1[2]-pixel2[2])<=threshold: 58 | return True 59 | return False 60 | 61 | 62 | def compare(full_img,slice_img): 63 | left = 0 64 | for i in range(full_img.size[0]): 65 | for j in range(full_img.size[1]): 66 | if not compare_pixel(full_img,slice_img,i,j): 67 | return i 68 | return left 69 | 70 | def get_track(distance): 71 | """ 72 | 根据偏移量获取移动轨迹 73 | :param distance: 偏移量 74 | :return: 移动轨迹 75 | """ 76 | # 移动轨迹 77 | track = [] 78 | # 当前位移 79 | current = 0 80 | # 减速阈值 81 | mid = distance * 4 / 5 82 | # 计算间隔 83 | t = 0.2 84 | # 初速度 85 | v = 0 86 | 87 | while current < distance: 88 | if current < mid: 89 | # 加速度为正 2 90 | a = 4 91 | else: 92 | # 加速度为负 3 93 | a = -3 94 | # 初速度 v0 95 | v0 = v 96 | # 当前速度 v = v0 + at 97 | v = v0 + a * t 98 | # 移动距离 x = v0t + 1/2 * a * t^2 99 | move = v0 * t + 1 / 2 * a * t * t 100 | # 当前位移 101 | current += move 102 | # 加入轨迹 103 | track.append(round(move)) 104 | return track 105 | 106 | 107 | def move_to_gap(browser,slider, tracks): 108 | """ 109 | 拖动滑块到缺口处 110 | :param slider: 滑块 111 | :param tracks: 轨迹 112 | :return: 113 | """ 114 | ActionChains(browser).click_and_hold(slider).perform() 115 | for x in tracks: 116 | ActionChains(browser).move_by_offset(xoffset=x, yoffset=0).perform() 117 | time.sleep(0.5) 118 | ActionChains(browser).release().perform() 119 | 120 | if __name__ == '__main__': 121 | path = "C:\\Users\\38486\\Desktop\\image\\qr\\big.png" 122 | reader = zxing.BarCodeReader() 123 | barcode = reader.decode(path) 124 | url = barcode.parsed 125 | print(url) 
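    # presumably the decoded URL is then handed to get_url() above to open the page and drive the slider captcha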
126 | 127 | -------------------------------------------------------------------------------- /obj_det/objd_util.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import time 3 | import fitz 4 | import threading 5 | from util.ofd_util import get_info 6 | from obj_det.ocr_context import context 7 | 8 | lock = threading.Lock() 9 | 10 | allowed_extension = ['jpg', 'png', 'JPG', 'pdf', 'ofd'] 11 | image_extension = ['jpg', 'png', 'JPG'] 12 | pdf_extension = ['pdf'] 13 | ofd_extension = ['ofd'] 14 | 15 | vat_names = ['01', '04'] 16 | e_vat_names = ['08', '10', '14'] 17 | tra_names = ['88'] 18 | taxi_names = ['92'] 19 | roll_names = ['11'] 20 | no_tax_names = ['81'] 21 | 22 | 23 | # 检查文件扩展名 24 | def allowed_file(filename, type_extension): 25 | return '.' in filename and filename.rsplit('.', 1)[1].lower() in type_extension 26 | 27 | 28 | def time_synchronized(): 29 | return time.time() 30 | 31 | 32 | def rotate(img, invoice): 33 | angle = context.angleModel(img) 34 | if angle != 0: 35 | index = 3 - angle / 90 36 | img = cv2.rotate(img, int(index)) 37 | cv2.imwrite(invoice['file_path'], img) 38 | 39 | 40 | def process_image(file_name): 41 | result = [] 42 | list_invoice = context.det(file_name) 43 | for invoice in list_invoice: 44 | invoice_type = invoice['invoiceType'] 45 | img = cv2.imread(invoice['file_path']) 46 | rotate(img, invoice) 47 | if str(invoice_type) in vat_names or str(invoice_type) in e_vat_names: 48 | context.title(file_name=invoice['file_path'], invoice=invoice) 49 | invoice_type = invoice['invoiceType'] 50 | if str(invoice_type) in tra_names: 51 | context.tra(file_name=invoice['file_path'], invoice=invoice) 52 | result.append(invoice) 53 | if str(invoice_type) in vat_names: 54 | context.vat(file_name=invoice['file_path'], invoice=invoice, context=context) 55 | result.append(invoice) 56 | if str(invoice_type) in e_vat_names: 57 | context.evat(file_name=invoice['file_path'], invoice=invoice, context=context) 58 | result.append(invoice) 59 | if str(invoice_type) in taxi_names: 60 | context.taxi(file_name=invoice['file_path'], invoice=invoice) 61 | result.append(invoice) 62 | if str(invoice_type) in roll_names: 63 | context.roll(file_name=invoice['file_path'], invoice=invoice) 64 | result.append(invoice) 65 | if str(invoice_type) in no_tax_names: 66 | context.noVat(file_name=invoice['file_path'], invoice=invoice, context=context) 67 | result.append(invoice) 68 | return result 69 | 70 | 71 | def process_pdf(file_name): 72 | # 打开PDF文件,生成一个对象 73 | doc = fitz.open("images/" + file_name) 74 | name = file_name.split(".")[0] 75 | result = [] 76 | for pg in range(doc.pageCount): 77 | file_name = name + str(pg) 78 | page = doc[pg] 79 | rotate = int(0) 80 | # 每个尺寸的缩放系数为2,这将为我们生成分辨率提高四倍的图像。 81 | zoom_x = 2.0 82 | zoom_y = 2.0 83 | trans = fitz.Matrix(zoom_x, zoom_y).preRotate(rotate) 84 | pm = page.getPixmap(matrix=trans, alpha=False) 85 | pm.writePNG('images/%s.png' % file_name) 86 | file_name_curr = '%s.png' % file_name 87 | invoices = process_image(file_name_curr) 88 | result.extend(invoices) 89 | return result 90 | 91 | 92 | def process_ofd(file_name): 93 | result = get_info("images/" + file_name, "images/zip" + file_name.split(".")[0]) 94 | invoices = [] 95 | checkcode = result.get('校验码') 96 | invoiceType = '10' 97 | invoice_type_name = '增值税电子普通发票' 98 | if checkcode == '' or checkcode is None: 99 | invoiceType = '08' 100 | invoice_type_name = '增值税电子专用发票' 101 | invoice = {"invoiceType": invoiceType, "invoice_type_name": 
invoice_type_name, "file_path": 'demo', 102 | "coordinate": [], 'invoice_code': result['发票代码'], 'invoice_number': result['发票号码'], 103 | 'totalAmount': result['合计金额'], 'billingDate': result['开票日期'], 104 | 'checkCode': result['校验码'].replace(" ", "")} 105 | 106 | invoices.append(invoice) 107 | return invoices 108 | 109 | 110 | def detection(file_name): 111 | if allowed_file(file_name, image_extension): 112 | return process_image(file_name) 113 | if allowed_file(file_name, pdf_extension): 114 | return process_pdf(file_name) 115 | if allowed_file(file_name, ofd_extension): 116 | return process_ofd(file_name) 117 | -------------------------------------------------------------------------------- /crnn/dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # encoding: utf-8 3 | 4 | import random 5 | import torch 6 | from torch.utils.data import Dataset 7 | from torch.utils.data import sampler 8 | import torchvision.transforms as transforms 9 | import lmdb 10 | import six 11 | import sys 12 | from PIL import Image 13 | import numpy as np 14 | 15 | 16 | class lmdbDataset(Dataset): 17 | 18 | def __init__(self, root=None, transform=None, target_transform=None): 19 | self.env = lmdb.open( 20 | root, 21 | max_readers=1, 22 | readonly=True, 23 | lock=False, 24 | readahead=False, 25 | meminit=False) 26 | 27 | if not self.env: 28 | print('cannot creat lmdb from %s' % (root)) 29 | sys.exit(0) 30 | 31 | with self.env.begin(write=False) as txn: 32 | nSamples = int(txn.get('num-samples')) 33 | self.nSamples = nSamples 34 | 35 | self.transform = transform 36 | self.target_transform = target_transform 37 | 38 | def __len__(self): 39 | return self.nSamples 40 | 41 | def __getitem__(self, index): 42 | assert index <= len(self), 'index range error' 43 | index += 1 44 | with self.env.begin(write=False) as txn: 45 | img_key = 'image-%09d' % index 46 | imgbuf = txn.get(img_key) 47 | 48 | buf = six.BytesIO() 49 | buf.write(imgbuf) 50 | buf.seek(0) 51 | try: 52 | img = Image.open(buf).convert('L') 53 | except IOError: 54 | print('Corrupted image for %d' % index) 55 | return self[index + 1] 56 | 57 | if self.transform is not None: 58 | img = self.transform(img) 59 | 60 | label_key = 'label-%09d' % index 61 | label = str(txn.get(label_key)) 62 | if self.target_transform is not None: 63 | label = self.target_transform(label) 64 | 65 | return (img, label) 66 | 67 | 68 | class resizeNormalize(object): 69 | 70 | def __init__(self, size, interpolation=Image.BILINEAR): 71 | self.size = size 72 | self.interpolation = interpolation 73 | self.toTensor = transforms.ToTensor() 74 | 75 | def __call__(self, img): 76 | img = img.resize(self.size, self.interpolation) 77 | img = self.toTensor(img) 78 | img.sub_(0.5).div_(0.5) 79 | return img 80 | 81 | 82 | class randomSequentialSampler(sampler.Sampler): 83 | 84 | def __init__(self, data_source, batch_size): 85 | self.num_samples = len(data_source) 86 | self.batch_size = batch_size 87 | 88 | def __iter__(self): 89 | n_batch = len(self) // self.batch_size 90 | tail = len(self) % self.batch_size 91 | index = torch.LongTensor(len(self)).fill_(0) 92 | for i in range(n_batch): 93 | random_start = random.randint(0, len(self) - self.batch_size) 94 | batch_index = random_start + torch.range(0, self.batch_size - 1) 95 | index[i * self.batch_size:(i + 1) * self.batch_size] = batch_index 96 | # deal with tail 97 | if tail: 98 | random_start = random.randint(0, len(self) - self.batch_size) 99 | tail_index = random_start + torch.range(0, 
tail - 1) 100 | index[(i + 1) * self.batch_size:] = tail_index 101 | 102 | return iter(index) 103 | 104 | def __len__(self): 105 | return self.num_samples 106 | 107 | 108 | class alignCollate(object): 109 | 110 | def __init__(self, imgH=32, imgW=128, keep_ratio=False, min_ratio=1): 111 | self.imgH = imgH 112 | self.imgW = imgW 113 | self.keep_ratio = keep_ratio 114 | self.min_ratio = min_ratio 115 | 116 | def __call__(self, batch): 117 | images, labels = zip(*batch) 118 | 119 | imgH = self.imgH 120 | imgW = self.imgW 121 | if self.keep_ratio: 122 | ratios = [] 123 | for image in images: 124 | w, h = image.size 125 | ratios.append(w / float(h)) 126 | ratios.sort() 127 | max_ratio = ratios[-1] 128 | imgW = int(np.floor(max_ratio * imgH)) 129 | imgW = max(imgH * self.min_ratio, imgW) # assure imgH >= imgW 130 | 131 | transform = resizeNormalize((imgW, imgH)) 132 | images = [transform(image) for image in images] 133 | images = torch.cat([t.unsqueeze(0) for t in images], 0) 134 | 135 | return images, labels 136 | -------------------------------------------------------------------------------- /utils/google_utils.py: -------------------------------------------------------------------------------- 1 | # Google utils: https://cloud.google.com/storage/docs/reference/libraries 2 | 3 | import os 4 | import platform 5 | import subprocess 6 | import time 7 | from pathlib import Path 8 | 9 | import requests 10 | import torch 11 | 12 | 13 | def gsutil_getsize(url=''): 14 | # gs://bucket/file size https://cloud.google.com/storage/docs/gsutil/commands/du 15 | s = subprocess.check_output(f'gsutil du {url}', shell=True).decode('utf-8') 16 | return eval(s.split(' ')[0]) if len(s) else 0 # bytes 17 | 18 | 19 | def attempt_download(file, repo='ultralytics/yolov5'): 20 | # Attempt file download if does not exist 21 | file = Path(str(file).strip().replace("'", '').lower()) 22 | 23 | if not file.exists(): 24 | try: 25 | response = requests.get(f'https://api.github.com/repos/{repo}/releases/latest').json() # github api 26 | assets = [x['name'] for x in response['assets']] # release assets, i.e. ['yolov5s.pt', 'yolov5m.pt', ...] 27 | tag = response['tag_name'] # i.e. 'v1.0' 28 | except: # fallback plan 29 | assets = ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt'] 30 | tag = subprocess.check_output('git tag', shell=True).decode().split()[-1] 31 | 32 | name = file.name 33 | if name in assets: 34 | msg = f'{file} missing, try downloading from https://github.com/{repo}/releases/' 35 | redundant = False # second download option 36 | try: # GitHub 37 | url = f'https://github.com/{repo}/releases/download/{tag}/{name}' 38 | print(f'Downloading {url} to {file}...') 39 | torch.hub.download_url_to_file(url, file) 40 | assert file.exists() and file.stat().st_size > 1E6 # check 41 | except Exception as e: # GCP 42 | print(f'Download error: {e}') 43 | assert redundant, 'No secondary mirror' 44 | url = f'https://storage.googleapis.com/{repo}/ckpt/{name}' 45 | print(f'Downloading {url} to {file}...') 46 | os.system(f'curl -L {url} -o {file}') # torch.hub.download_url_to_file(url, weights) 47 | finally: 48 | if not file.exists() or file.stat().st_size < 1E6: # check 49 | file.unlink(missing_ok=True) # remove partial downloads 50 | print(f'ERROR: Download failure: {msg}') 51 | print('') 52 | return 53 | 54 | 55 | def gdrive_download(id='16TiPfZj7htmTyhntwcZyEEAejOUxuT6m', file='tmp.zip'): 56 | # Downloads a file from Google Drive. 
from yolov5.utils.google_utils import *; gdrive_download() 57 | t = time.time() 58 | file = Path(file) 59 | cookie = Path('cookie') # gdrive cookie 60 | print(f'Downloading https://drive.google.com/uc?export=download&id={id} as {file}... ', end='') 61 | file.unlink(missing_ok=True) # remove existing file 62 | cookie.unlink(missing_ok=True) # remove existing cookie 63 | 64 | # Attempt file download 65 | out = "NUL" if platform.system() == "Windows" else "/dev/null" 66 | os.system(f'curl -c ./cookie -s -L "drive.google.com/uc?export=download&id={id}" > {out}') 67 | if os.path.exists('cookie'): # large file 68 | s = f'curl -Lb ./cookie "drive.google.com/uc?export=download&confirm={get_token()}&id={id}" -o {file}' 69 | else: # small file 70 | s = f'curl -s -L -o {file} "drive.google.com/uc?export=download&id={id}"' 71 | r = os.system(s) # execute, capture return 72 | cookie.unlink(missing_ok=True) # remove existing cookie 73 | 74 | # Error check 75 | if r != 0: 76 | file.unlink(missing_ok=True) # remove partial 77 | print('Download error ') # raise Exception('Download error') 78 | return r 79 | 80 | # Unzip if archive 81 | if file.suffix == '.zip': 82 | print('unzipping... ', end='') 83 | os.system(f'unzip -q {file}') # unzip 84 | file.unlink() # remove zip to free space 85 | 86 | print(f'Done ({time.time() - t:.1f}s)') 87 | return r 88 | 89 | 90 | def get_token(cookie="./cookie"): 91 | with open(cookie) as f: 92 | for line in f: 93 | if "download" in line: 94 | return line.split()[-1] 95 | return "" 96 | 97 | # def upload_blob(bucket_name, source_file_name, destination_blob_name): 98 | # # Uploads a file to a bucket 99 | # # https://cloud.google.com/storage/docs/uploading-objects#storage-upload-object-python 100 | # 101 | # storage_client = storage.Client() 102 | # bucket = storage_client.get_bucket(bucket_name) 103 | # blob = bucket.blob(destination_blob_name) 104 | # 105 | # blob.upload_from_filename(source_file_name) 106 | # 107 | # print('File {} uploaded to {}.'.format( 108 | # source_file_name, 109 | # destination_blob_name)) 110 | # 111 | # 112 | # def download_blob(bucket_name, source_blob_name, destination_file_name): 113 | # # Uploads a blob from a bucket 114 | # storage_client = storage.Client() 115 | # bucket = storage_client.get_bucket(bucket_name) 116 | # blob = bucket.blob(source_blob_name) 117 | # 118 | # blob.download_to_filename(destination_file_name) 119 | # 120 | # print('Blob {} downloaded to {}.'.format( 121 | # source_blob_name, 122 | # destination_file_name)) 123 | -------------------------------------------------------------------------------- /crnn/keys.py: -------------------------------------------------------------------------------- 1 | #coding:UTF-8 2 | alphabetChinese = 
u'\'疗绚诚娇溜题贿者廖更纳加奉公一就汴计与路房原妇208-7其>:],,骑刈全消昏傈安久钟嗅不影处驽蜿资关椤地瘸专问忖票嫉炎韵要月田节陂鄙捌备拳伺眼网盎大傍心东愉汇蹿科每业里航晏字平录先13彤鲶产稍督腴有象岳注绍在泺文定核名水过理让偷率等这发”为含肥酉相鄱七编猥锛日镀蒂掰倒辆栾栗综涩州雌滑馀了机块司宰甙兴矽抚保用沧秩如收息滥页疑埠!!姥异橹钇向下跄的椴沫国绥獠报开民蜇何分凇长讥藏掏施羽中讲派嘟人提浼间世而古多倪唇饯控庚首赛蜓味断制觉技替艰溢潮夕钺外摘枋动双单啮户枇确锦曜杜或能效霜盒然侗电晁放步鹃新杖蜂吒濂瞬评总隍对独合也是府青天诲墙组滴级邀帘示已时骸仄泅和遨店雇疫持巍踮境只亨目鉴崤闲体泄杂作般轰化解迂诿蛭璀腾告版服省师小规程线海办引二桧牌砺洄裴修图痫胡许犊事郛基柴呼食研奶律蛋因葆察戏褒戒再李骁工貂油鹅章啄休场给睡纷豆器捎说敏学会浒设诊格廓查来霓室溆¢诡寥焕舜柒狐回戟砾厄实翩尿五入径惭喹股宇篝|;美期云九祺扮靠锝槌系企酰阊暂蚕忻豁本羹执条钦H獒限进季楦于芘玖铋茯未答粘括样精欠矢甥帷嵩扣令仔风皈行支部蓉刮站蜡救钊汗松嫌成可.鹤院从交政怕活调球局验髌第韫谗串到圆年米/*友忿检区看自敢刃个兹弄流留同没齿星聆轼湖什三建蛔儿椋汕震颧鲤跟力情璺铨陪务指族训滦鄣濮扒商箱十召慷辗所莞管护臭横硒嗓接侦六露党馋驾剖高侬妪幂猗绺骐央酐孝筝课徇缰门男西项句谙瞒秃篇教碲罚声呐景前富嘴鳌稀免朋啬睐去赈鱼住肩愕速旁波厅健茼厥鲟谅投攸炔数方击呋谈绩别愫僚躬鹧胪炳招喇膨泵蹦毛结54谱识陕粽婚拟构且搜任潘比郢妨醪陀桔碘扎选哈骷楷亿明缆脯监睫逻婵共赴淝凡惦及达揖谩澹减焰蛹番祁柏员禄怡峤龙白叽生闯起细装谕竟聚钙上导渊按艾辘挡耒盹饪臀记邮蕙受各医搂普滇朗茸带翻酚(光堤墟蔷万幻〓瑙辈昧盏亘蛀吉铰请子假闻税井诩哨嫂好面琐校馊鬣缂营访炖占农缀否经钚棵趟张亟吏茶谨捻论迸堂玉信吧瞠乡姬寺咬溏苄皿意赉宝尔钰艺特唳踉都荣倚登荐丧奇涵批炭近符傩感道着菊虹仲众懈濯颞眺南释北缝标既茗整撼迤贲挎耱拒某妍卫哇英矶藩治他元领膜遮穗蛾飞荒棺劫么市火温拈棚洼转果奕卸迪伸泳斗邡侄涨屯萋胭氡崮枞惧冒彩斜手豚随旭淑妞形菌吲沱争驯歹挟兆柱传至包内响临红功弩衡寂禁老棍耆渍织害氵渑布载靥嗬虽苹咨娄库雉榜帜嘲套瑚亲簸欧边6腿旮抛吹瞳得镓梗厨继漾愣憨士策窑抑躯襟脏参贸言干绸鳄穷藜音折详)举悍甸癌黎谴死罩迁寒驷袖媒蒋掘模纠恣观祖蛆碍位稿主澧跌筏京锏帝贴证糠才黄鲸略炯饱四出园犀牧容汉杆浈汰瑷造虫瘩怪驴济应花沣谔夙旅价矿以考su呦晒巡茅准肟瓴詹仟褂译桌混宁怦郑抿些余鄂饴攒珑群阖岔琨藓预环洮岌宀杲瀵最常囡周踊女鼓袭喉简范薯遐疏粱黜禧法箔斤遥汝奥直贞撑置绱集她馅逗钧橱魉[恙躁唤9旺膘待脾惫购吗依盲度瘿蠖俾之镗拇鲵厝簧续款展啃表剔品钻腭损清锶统涌寸滨贪链吠冈伎迥咏吁览防迅失汾阔逵绀蔑列川凭努熨揪利俱绉抢鸨我即责膦易毓鹊刹玷岿空嘞绊排术估锷违们苟铜播肘件烫审鲂广像铌惰铟巳胍鲍康憧色恢想拷尤疳知SYFDA峄裕帮握搔氐氘难墒沮雨叁缥悴藐湫娟苑稠颛簇后阕闭蕤缚怎佞码嘤蔡痊舱螯帕赫昵升烬岫、疵蜻髁蕨隶烛械丑盂梁强鲛由拘揉劭龟撤钩呕孛费妻漂求阑崖秤甘通深补赃坎床啪承吼量暇钼烨阂擎脱逮称P神属矗华届狍葑汹育患窒蛰佼静槎运鳗庆逝曼疱克代官此麸耧蚌晟例础榛副测唰缢迹灬霁身岁赭扛又菡乜雾板读陷徉贯郁虑变钓菜圾现琢式乐维渔浜左吾脑钡警T啵拴偌漱湿硕止骼魄积燥联踢玛|则窿见振畿送班钽您赵刨印讨踝籍谡舌崧汽蔽沪酥绒怖财帖肱私莎勋羔霸励哼帐将帅渠纪婴娩岭厘滕吻伤坝冠戊隆瘁介涧物黍并姗奢蹑掣垸锴命箍捉病辖琰眭迩艘绌繁寅若毋思诉类诈燮轲酮狂重反职筱县委磕绣奖晋濉志徽肠呈獐坻口片碰几村柿劳料获亩惕晕厌号罢池正鏖煨家棕复尝懋蜥锅岛扰队坠瘾钬@卧疣镇譬冰彷频黯据垄采八缪瘫型熹砰楠襁箐但嘶绳啤拍盥穆傲洗盯塘怔筛丿台恒喂葛永¥烟酒桦书砂蚝缉态瀚袄圳轻蛛超榧遛姒奘铮右荽望偻卡丶氰附做革索戚坨桷唁垅榻岐偎坛莨山殊微骇陈爨推嗝驹澡藁呤卤嘻糅逛侵郓酌德摇※鬃被慨殡羸昌泡戛鞋河宪沿玲鲨翅哽源铅语照邯址荃佬顺鸳町霭睾瓢夸椁晓酿痈咔侏券噎湍签嚷离午尚社锤背孟使浪缦潍鞅军姹驶笑鳟鲁》孽钜绿洱礴焯椰颖囔乌孔巴互性椽哞聘昨早暮胶炀隧低彗昝铁呓氽藉喔癖瑗姨权胱韦堑蜜酋楝砝毁靓歙锲究屋喳骨辨碑武鸠宫辜烊适坡殃培佩供走蜈迟翼况姣凛浔吃飘债犟金促苛崇坂莳畔绂兵蠕斋根砍亢欢恬崔剁餐榫快扶‖濒缠鳜当彭驭浦篮昀锆秸钳弋娣瞑夷龛苫拱致%嵊障隐弑初娓抉汩累蓖"唬助苓昙押毙破城郧逢嚏獭瞻溱婿赊跨恼璧萃姻貉灵炉密氛陶砸谬衔点琛沛枳层岱诺脍榈埂征冷裁打蹴素瘘逞蛐聊激腱萘踵飒蓟吆取咙簋涓矩曝挺揣座你史舵焱尘苏笈脚溉榨诵樊邓焊义庶儋蟋蒲赦呷杞诠豪还试颓茉太除紫逃痴草充鳕珉祗墨渭烩蘸慕璇镶穴嵘恶骂险绋幕碉肺戳刘潞秣纾潜銮洛须罘销瘪汞兮屉r林厕质探划狸殚善煊烹〒锈逯宸辍泱柚袍远蹋嶙绝峥娥缍雀徵认镱谷=贩勉撩鄯斐洋非祚泾诒饿撬威晷搭芍锥笺蓦候琊档礁沼卵荠忑朝凹瑞头仪弧孵畏铆突衲车浩气茂悖厢枕酝戴湾邹飚攘锂写宵翁岷无喜丈挑嗟绛殉议槽具醇淞笃郴阅饼底壕砚弈询缕庹翟零筷暨舟闺甯撞麂茌蔼很珲捕棠角阉媛娲诽剿尉爵睬韩诰匣危糍镯立浏阳少盆舔擘匪申尬铣旯抖赘瓯居ˇ哮游锭茏歌坏甚秒舞沙仗劲潺阿燧郭嗖霏忠材奂耐跺砀输岖媳氟极摆灿今扔腻枝奎药熄吨话q额慑嘌协喀壳埭视著於愧陲翌峁颅佛腹聋侯咎叟秀颇存较罪哄岗扫栏钾羌己璨枭霉煌涸衿键镝益岢奏连夯睿冥均糖狞蹊稻爸刿胥煜丽肿璃掸跚灾垂樾濑乎莲窄犹撮战馄软络显鸢胸宾妲恕埔蝌份遇巧瞟粒恰剥桡博讯凯堇阶滤卖斌骚彬兑磺樱舷两娱福仃差找桁÷净把阴污戬雷碓蕲楚罡焖抽妫咒仑闱尽邑菁爱贷沥鞑牡嗉崴骤塌嗦订拮滓捡锻次坪杩臃箬融珂鹗宗枚降鸬妯阄堰盐毅必杨崃俺甬状莘货耸菱腼铸唏痤孚澳懒溅翘疙杷淼缙骰喊悉砻坷艇赁界谤纣宴晃茹归饭梢铡街抄肼鬟苯颂撷戈炒咆茭瘙负仰客琉铢封卑珥椿镧窨鬲寿御袤铃萎砖餮脒裳肪孕嫣馗嵇恳氯江石褶冢祸阻狈羞银靳透咳叼敷芷啥它瓤兰痘懊逑肌往捺坊甩呻〃沦忘膻祟菅剧崆智坯臧霍墅攻眯倘拢骠铐庭岙瓠′缺泥迢捶??郏喙掷沌纯秘种听绘固螨团香盗妒埚蓝拖旱荞铀血遏汲辰叩拽幅硬惶桀漠措泼唑齐肾念酱虚屁耶旗砦闵婉馆拭绅韧忏窝醋葺顾辞倜堆辋逆玟贱疾董惘倌锕淘嘀莽俭笏绑鲷杈择蟀粥嗯驰逾案谪褓胫哩昕颚鲢绠躺鹄崂儒俨丝尕泌啊萸彰幺吟骄苣弦脊瑰〈诛镁析闪剪侧哟框螃守嬗燕狭铈缮概迳痧鲲俯售笼痣扉挖满咋援邱扇歪便玑绦峡蛇叨〖泽胃斓喋怂坟猪该蚬炕弥赞棣晔娠挲狡创疖铕镭稷挫弭啾翔粉履苘哦楼秕铂土锣瘟挣栉习享桢袅磨桂谦延坚蔚噗署谟猬钎恐嬉雒倦衅亏璩睹刻殿王算雕麻丘柯骆丸塍谚添鲈垓桎蚯芥予飕镦谌窗醚菀亮搪莺蒿羁足J真轶悬衷靛翊掩哒炅掐冼妮l谐稚荆擒犯陵虏浓崽刍陌傻孜千靖演矜钕煽杰酗渗伞栋俗泫戍罕沾疽灏煦芬磴叱阱榉湃蜀叉醒彪租郡篷屎良垢隗弱陨峪砷掴颁胎雯绵贬沐撵隘篙暖曹陡栓填臼彦瓶琪潼哪鸡摩啦俟锋域耻蔫疯纹撇毒绶痛酯忍爪赳歆嘹辕烈册朴钱吮毯癜娃谀邵厮炽璞邃丐追词瓒忆轧芫谯喷弟半冕裙掖墉绮寝苔势顷褥切衮君佳嫒蚩霞佚洙逊镖暹唛&殒顶碗獗轭铺蛊废恹汨崩珍那杵曲纺夏薰傀闳淬姘舀拧卷楂恍讪厩寮篪赓乘灭盅鞣沟慎挂饺鼾杳树缨丛絮娌臻嗳篡侩述衰矛圈蚜匕筹匿濞晨叶骋郝挚蚴滞增侍描瓣吖嫦蟒匾圣赌毡癞恺百曳需篓肮庖帏卿驿遗蹬鬓骡歉芎胳屐禽烦晌寄媾狄翡苒船廉终痞殇々畦饶改拆悻萄£瓿乃訾桅匮溧拥纱铍骗蕃龋缬父佐疚栎醍掳蓄x惆颜鲆榆〔猎敌暴谥鲫贾罗玻缄扦芪癣落徒臾恿猩托邴肄牵春陛耀刊拓蓓邳堕寇枉淌啡湄兽酷萼碚濠萤夹旬戮梭琥椭昔勺蜊绐晚孺僵宣摄冽旨萌忙蚤眉噼蟑付契瓜悼颡壁曾窕颢澎仿俑浑嵌浣乍碌褪乱蔟隙玩剐葫箫纲围伐决伙漩瑟刑肓镳缓蹭氨皓典畲坍铑檐塑洞倬储胴淳戾吐灼惺妙毕珐缈虱盖羰鸿磅谓髅娴苴唷蚣霹抨贤唠犬誓逍庠逼麓籼釉呜碧秧氩摔霄穸纨辟妈映完牛缴嗷炊恩荔茆掉紊慌莓羟阙萁磐另蕹辱鳐湮吡吩唐睦垠舒圜冗瞿溺芾囱匠僳汐菩饬漓黑霰浸濡窥毂蒡兢驻鹉芮诙迫雳厂忐臆猴鸣蚪栈箕羡渐莆捍眈哓趴蹼埕嚣骛宏淄斑噜严瑛垃椎诱压庾绞焘廿抡迄棘夫纬锹眨瞌侠脐竞瀑孳骧遁姜颦荪滚萦伪逸粳爬锁矣役趣洒颔诏逐奸甭惠攀蹄泛尼拼阮鹰亚颈惑勒〉际肛爷刚钨丰养冶鲽辉蔻画覆皴妊麦返醉皂擀〗酶凑粹悟诀硖港卜z杀涕±舍铠抵弛段敝镐奠拂轴跛袱et沉菇俎薪峦秭蟹历盟菠寡液肢喻染裱悱抱氙赤捅猛跑氮谣仁尺辊窍烙衍架擦倏璐瑁币楞胖夔趸邛惴饕虔蝎§哉贝宽辫炮扩饲籽魏菟锰伍猝末琳哚蛎邂呀姿鄞却歧仙恸椐森牒寤袒婆虢雅钉朵贼欲苞寰故龚坭嘘咫礼硷兀睢汶’铲烧绕诃浃钿哺柜讼颊璁腔洽咐脲簌筠镣玮鞠谁兼姆挥梯蝴谘漕刷躏宦弼b垌劈麟莉揭笙渎仕嗤仓配怏抬错泯镊孰猿邪仍秋鼬壹歇吵炼<尧射柬廷胧霾凳隋肚浮梦祥株堵退L鹫跎凶毽荟炫栩玳甜沂鹿顽伯爹赔蛴徐匡欣狰缸雹蟆疤默沤啜痂衣禅wih辽葳黝钗停沽棒馨颌肉吴硫
悯劾娈马啧吊悌镑峭帆瀣涉咸疸滋泣翦拙癸钥蜒+尾庄凝泉婢渴谊乞陆锉糊鸦淮IBN晦弗乔庥葡尻席橡傣渣拿惩麋斛缃矮蛏岘鸽姐膏催奔镒喱蠡摧钯胤柠拐璋鸥卢荡倾^_珀逄萧塾掇贮笆聂圃冲嵬M滔笕值炙偶蜱搐梆汪蔬腑鸯蹇敞绯仨祯谆梧糗鑫啸豺囹猾巢柄瀛筑踌沭暗苁鱿蹉脂蘖牢热木吸溃宠序泞偿拜檩厚朐毗螳吞媚朽担蝗橘畴祈糟盱隼郜惜珠裨铵焙琚唯咚噪骊丫滢勤棉呸咣淀隔蕾窈饨挨煅短匙粕镜赣撕墩酬馁豌颐抗酣氓佑搁哭递耷涡桃贻碣截瘦昭镌蔓氚甲猕蕴蓬散拾纛狼猷铎埋旖矾讳囊糜迈粟蚂紧鲳瘢栽稼羊锄斟睁桥瓮蹙祉醺鼻昱剃跳篱跷蒜翎宅晖嗑壑峻癫屏狠陋袜途憎祀莹滟佶溥臣约盛峰磁慵婪拦莅朕鹦粲裤哎疡嫖琵窟堪谛嘉儡鳝斩郾驸酊妄胜贺徙傅噌钢栅庇恋匝巯邈尸锚粗佟蛟薹纵蚊郅绢锐苗俞篆淆膀鲜煎诶秽寻涮刺怀噶巨褰魅灶灌桉藕谜舸薄搀恽借牯痉渥愿亓耘杠柩锔蚶钣珈喘蹒幽赐稗晤莱泔扯肯菪裆腩豉疆骜腐倭珏唔粮亡润慰伽橄玄誉醐胆龊粼塬陇彼削嗣绾芽妗垭瘴爽薏寨龈泠弹赢漪猫嘧涂恤圭茧烽屑痕巾赖荸凰腮畈亵蹲偃苇澜艮换骺烘苕梓颉肇哗悄氤涠葬屠鹭植竺佯诣鲇瘀鲅邦移滁冯耕癔戌茬沁巩悠湘洪痹锟循谋腕鳃钠捞焉迎碱伫急榷奈邝卯辄皲卟醛畹忧稳雄昼缩阈睑扌耗曦涅捏瞧邕淖漉铝耦禹湛喽莼琅诸苎纂硅始嗨傥燃臂赅嘈呆贵屹壮肋亍蚀卅豹腆邬迭浊}童螂捐圩勐触寞汊壤荫膺渌芳懿遴螈泰蓼蛤茜舅枫朔膝眙避梅判鹜璜牍缅垫藻黔侥惚懂踩腰腈札丞唾慈顿摹荻琬~斧沈滂胁胀幄莜Z匀鄄掌绰茎焚赋萱谑汁铒瞎夺蜗野娆冀弯篁懵灞隽芡脘俐辩芯掺喏膈蝈觐悚踹蔗熠鼠呵抓橼峨畜缔禾崭弃熊摒凸拗穹蒙抒祛劝闫扳阵醌踪喵侣搬仅荧赎蝾琦买婧瞄寓皎冻赝箩莫瞰郊笫姝筒枪遣煸袋舆痱涛母〇启践耙绲盘遂昊搞槿诬纰泓惨檬亻越Co憩熵祷钒暧塔阗胰咄娶魔琶钞邻扬杉殴咽弓〆髻】吭揽霆拄殖脆彻岩芝勃辣剌钝嘎甄佘皖伦授徕憔挪皇庞稔芜踏溴兖卒擢饥鳞煲‰账颗叻斯捧鳍琮讹蛙纽谭酸兔莒睇伟觑羲嗜宜褐旎辛卦诘筋鎏溪挛熔阜晰鳅丢奚灸呱献陉黛鸪甾萨疮拯洲疹辑叙恻谒允柔烂氏逅漆拎惋扈湟纭啕掬擞哥忽涤鸵靡郗瓷扁廊怨雏钮敦E懦憋汀拚啉腌岸f痼瞅尊咀眩飙忌仝迦熬毫胯篑茄腺凄舛碴锵诧羯後漏汤宓仞蚁壶谰皑铄棰罔辅晶苦牟闽\烃饮聿丙蛳朱煤涔鳖犁罐荼砒淦妤黏戎孑婕瑾戢钵枣捋砥衩狙桠稣阎肃梏诫孪昶婊衫嗔侃塞蜃樵峒貌屿欺缫阐栖诟珞荭吝萍嗽恂啻蜴磬峋俸豫谎徊镍韬魇晴U囟猜蛮坐囿伴亭肝佗蝠妃胞滩榴氖垩苋砣扪馏姓轩厉夥侈禀垒岑赏钛辐痔披纸碳“坞蠓挤荥沅悔铧帼蒌蝇apyng哀浆瑶凿桶馈皮奴苜佤伶晗铱炬优弊氢恃甫攥端锌灰稹炝曙邋亥眶碾拉萝绔捷浍腋姑菖凌涞麽锢桨潢绎镰殆锑渝铬困绽觎匈糙暑裹鸟盔肽迷綦『亳佝俘钴觇骥仆疝跪婶郯瀹唉脖踞针晾忒扼瞩叛椒疟嗡邗肆跆玫忡捣咧唆艄蘑潦笛阚沸泻掊菽贫斥髂孢镂赂麝鸾屡衬苷恪叠希粤爻喝茫惬郸绻庸撅碟宄妹膛叮饵崛嗲椅冤搅咕敛尹垦闷蝉霎勰败蓑泸肤鹌幌焦浠鞍刁舰乙竿裔。茵函伊兄丨娜匍謇莪宥似蝽翳酪翠粑薇祢骏赠叫Q噤噻竖芗莠潭俊羿耜O郫趁嗪囚蹶芒洁笋鹑敲硝啶堡渲揩』携宿遒颍扭棱割萜蔸葵琴捂饰衙耿掠募岂窖涟蔺瘤柞瞪怜匹距楔炜哆秦缎幼茁绪痨恨楸娅瓦桩雪嬴伏榔妥铿拌眠雍缇‘卓搓哌觞噩屈哧髓咦巅娑侑淫膳祝勾姊莴胄疃薛蜷胛巷芙芋熙闰勿窃狱剩钏幢陟铛慧靴耍k浙浇飨惟绗祜澈啼咪磷摞诅郦抹跃壬吕肖琏颤尴剡抠凋赚泊津宕殷倔氲漫邺涎怠$垮荬遵俏叹噢饽蜘孙筵疼鞭羧牦箭潴c眸祭髯啖坳愁芩驮倡巽穰沃胚怒凤槛剂趵嫁v邢灯鄢桐睽檗锯槟婷嵋圻诗蕈颠遭痢芸怯馥竭锗徜恭遍籁剑嘱苡龄僧桑潸弘澶楹悲讫愤腥悸谍椹呢桓葭攫阀翰躲敖柑郎笨橇呃魁燎脓葩磋垛玺狮沓砜蕊锺罹蕉翱虐闾巫旦茱嬷枯鹏贡芹汛矫绁拣禺佃讣舫惯乳趋疲挽岚虾衾蠹蹂飓氦铖孩稞瑜壅掀勘妓畅髋W庐牲蓿榕练垣唱邸菲昆婺穿绡麒蚱掂愚泷涪漳妩娉榄讷觅旧藤煮呛柳腓叭庵烷阡罂蜕擂猖咿媲脉【沏貅黠熏哲烁坦酵兜×潇撒剽珩圹乾摸樟帽嗒襄魂轿憬锡〕喃皆咖隅脸残泮袂鹂珊囤捆咤误徨闹淙芊淋怆囗拨梳渤RG绨蚓婀幡狩麾谢唢裸旌伉纶裂驳砼咛澄樨蹈宙澍倍貔操勇蟠摈砧虬够缁悦藿撸艹摁淹豇虎榭ˉ吱d°喧荀踱侮奋偕饷犍惮坑璎徘宛妆袈倩窦昂荏乖K怅撰鳙牙袁酞X痿琼闸雁趾荚虻涝《杏韭偈烤绫鞘卉症遢蓥诋杭荨匆竣簪辙敕虞丹缭咩黟m淤瑕咂铉硼茨嶂痒畸敬涿粪窘熟叔嫔盾忱裘憾梵赡珙咯娘庙溯胺葱痪摊荷卞乒髦寐铭坩胗枷爆溟嚼羚砬轨惊挠罄竽菏氧浅楣盼枢炸阆杯谏噬淇渺俪秆墓泪跻砌痰垡渡耽釜讶鳎煞呗韶舶绷鹳缜旷铊皱龌檀霖奄槐艳蝶旋哝赶骞蚧腊盈丁`蜚矸蝙睨嚓僻鬼醴夜彝磊笔拔栀糕厦邰纫逭纤眦膊馍躇烯蘼冬诤暄骶哑瘠」臊丕愈咱螺擅跋搏硪谄笠淡嘿骅谧鼎皋姚歼蠢驼耳胬挝涯狗蒽孓犷凉芦箴铤孤嘛坤V茴朦挞尖橙诞搴碇洵浚帚蜍漯柘嚎讽芭荤咻祠秉跖埃吓糯眷馒惹娼鲑嫩讴轮瞥靶褚乏缤宋帧删驱碎扑俩俄偏涣竹噱皙佰渚唧斡#镉刀崎筐佣夭贰肴峙哔艿匐牺镛缘仡嫡劣枸堀梨簿鸭蒸亦稽浴{衢束槲j阁揍疥棋潋聪窜乓睛插冉阪苍搽「蟾螟幸仇樽撂慢跤幔俚淅覃觊溶妖帛侨曰妾泗' 3 | alphabetEnglish='01234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ*' 4 | -------------------------------------------------------------------------------- /obj_det/roll_detect.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | from pathlib import Path 4 | import os 5 | 6 | import cv2 7 | import torch 8 | import torch.backends.cudnn as cudnn 9 | from numpy import random 10 | 11 | from models.experimental import attempt_load 12 | from utils.datasets import LoadStreams, LoadImages 13 | from utils.general import check_img_size, check_imshow, non_max_suppression, \ 14 | scale_coords, xyxy2xywh, strip_optimizer, set_logging 15 | from utils.torch_utils import select_device, time_synchronized 16 | import hashlib 17 | from crnn.crnn_torch_chinese import crnnOcr as ccrnnOcr 18 | from util.qrcode import qrcode 19 | import threading 20 | from util.create_img import no_tax 21 | from util.create_img import use 22 | 23 | lock = threading.Lock() 24 | converter = {'invoice_code': 'invoice_code', 'invoice_number': 'invoice_number', 'totalAmount': 'totalAmount', 25 | 'billingDate': 'billingDate', 'checkcode': 'checkCode', 26 | 'QRCode': 'QRCode', 'title': 'title'} 27 | pub_weights = "models/roll/best.pt" 28 | pub_view_img = False 29 | pub_save_txt = False 30 | pub_img_size = 640 31 | pub_nosave = False 32 | pub_project = "images" 33 | pub_device = "cpu" 34 | pub_augment = False 35 | pub_conf_thres = 0.25 36 | pub_iou_thres = 0.24 37 | pub_classes = None 38 | pub_agnostic_nms 
= False 39 | pub_save_conf = False 40 | 41 | # Initialize 42 | set_logging() 43 | device = select_device(pub_device) 44 | half = device.type != 'cpu' # half precision only supported on CUDA 45 | 46 | # Load model 47 | model = attempt_load(pub_weights, map_location=pub_device) # load FP32 model 48 | stride = int(model.stride.max()) # model stride 49 | imgsz = check_img_size(pub_img_size, s=stride) # check img_size 50 | if half: 51 | model.half() # to FP16 52 | 53 | 54 | def invoice_detection(file_name=None, invoice=None): 55 | uid = hashlib.md5(file_name.encode("utf8")).hexdigest() 56 | i = 0 57 | qrData = "" 58 | pub_source = file_name 59 | source, weights, view_img, save_txt, imgsz = pub_source, pub_weights, pub_view_img, pub_save_txt, pub_img_size 60 | save_img = True 61 | webcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith( 62 | ('rtsp://', 'rtmp://', 'http://', 'https://')) 63 | 64 | # Directories 65 | save_dir = Path(pub_project) # increment run 66 | 67 | # Set Dataloader 68 | vid_path, vid_writer = None, None 69 | if webcam: 70 | view_img = check_imshow() 71 | cudnn.benchmark = True # set True to speed up constant image size inference 72 | dataset = LoadStreams(source, img_size=imgsz, stride=stride) 73 | else: 74 | dataset = LoadImages(source, img_size=imgsz, stride=stride) 75 | 76 | # Get names and colors 77 | names = model.module.names if hasattr(model, 'module') else model.names 78 | colors = [[random.randint(0, 255) for _ in range(3)] for _ in names] 79 | 80 | # Run inference 81 | if device.type != 'cpu': 82 | model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters()))) # run once 83 | t0 = time.time() 84 | for path, img, im0s, vid_cap in dataset: 85 | img = torch.from_numpy(img).to(device) 86 | img = img.half() if half else img.float() # uint8 to fp16/32 87 | img /= 255.0 # 0 - 255 to 0.0 - 1.0 88 | if img.ndimension() == 3: 89 | img = img.unsqueeze(0) 90 | 91 | lock.acquire(timeout=3) 92 | # Inference 93 | pred = model(img, augment=pub_augment)[0] 94 | lock.release() 95 | 96 | # Apply NMS 97 | pred = non_max_suppression(pred, pub_conf_thres, pub_iou_thres, classes=pub_classes, agnostic=pub_agnostic_nms) 98 | 99 | # Process detections 100 | for i, det in enumerate(pred): # detections per image 101 | if webcam: # batch_size >= 1 102 | p, s, im0, frame = path[i], '%g: ' % i, im0s[i].copy(), dataset.count 103 | else: 104 | p, s, im0, frame = path, '', im0s, getattr(dataset, 'frame', 0) 105 | 106 | p = Path(p) # to Path 107 | save_path = str(save_dir / p.name) # img.jpg 108 | txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}') # img.txt 109 | s += '%gx%g ' % img.shape[2:] # print string 110 | gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh 111 | if len(det): 112 | # Rescale boxes from img_size to im0 size 113 | det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round() 114 | 115 | # Print results 116 | for c in det[:, -1].unique(): 117 | n = (det[:, -1] == c).sum() # detections per class 118 | s += f"{n} {names[int(c)]}{'s' * (n > 1)}, " # add to string 119 | 120 | # Write results 121 | for *xyxy, conf, cls in reversed(det): 122 | if save_txt: # Write to file 123 | xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh 124 | line = (cls, *xywh, conf) if pub_save_conf else (cls, *xywh) # label format 125 | with open(txt_path + '.txt', 'a') as f: 126 | f.write(('%g ' * len(line)).rstrip() % line + '\n') 127 | 128 | 
orgfilename = file_name.rsplit('.', 1)[0].lower() 129 | orgfilename = orgfilename.rsplit('/', 1)[1] 130 | if save_img or view_img: # Add bbox to image 131 | label = names[int(cls)] 132 | newimg = im0[int(xyxy[1]):int(xyxy[3]), int(xyxy[0]): int(xyxy[2])] 133 | i = i + 1 134 | if "QRCode" == label: 135 | qrData = newimg 136 | invoice[converter[label]] = ccrnnOcr(newimg) 137 | if use == 1: 138 | no_tax(label, newimg, invoice[converter[label]]) 139 | for val in converter.values(): 140 | if invoice.get(val) is None: 141 | invoice[val] = "0" 142 | qrcode(qrData, invoice) 143 | return invoice 144 | -------------------------------------------------------------------------------- /obj_det/evat_detect.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | from pathlib import Path 4 | import os 5 | 6 | import cv2 7 | import torch 8 | import torch.backends.cudnn as cudnn 9 | from numpy import random 10 | 11 | from models.experimental import attempt_load 12 | from utils.datasets import LoadStreams, LoadImages 13 | from utils.general import check_img_size, check_imshow, non_max_suppression, \ 14 | scale_coords, xyxy2xywh, strip_optimizer, set_logging 15 | from utils.torch_utils import select_device, time_synchronized 16 | from util.qrcode import qrcode 17 | import threading 18 | import hashlib 19 | from util.create_img import no_tax 20 | from util.create_img import use 21 | 22 | lock = threading.Lock() 23 | 24 | converter = {'invoice_code': 'invoice_code', 'invoice_number': 'invoice_number', 'totalAmount': 'totalAmount', 25 | 'billingDate': 'billingDate', 'checkcode': 'checkCode', 26 | 'QRCode': 'QRCode', 'title': 'title'} 27 | pub_weights = "models/evat/best.pt" 28 | pub_view_img = False 29 | pub_save_txt = False 30 | pub_img_size = 640 31 | pub_nosave = False 32 | pub_project = "images" 33 | pub_device = "cpu" 34 | pub_augment = False 35 | pub_conf_thres = 0.25 36 | pub_iou_thres = 0.24 37 | pub_classes = None 38 | pub_agnostic_nms = False 39 | pub_save_conf = False 40 | 41 | # Initialize 42 | set_logging() 43 | device = select_device(pub_device) 44 | half = device.type != 'cpu' # half precision only supported on CUDA 45 | 46 | # Load model 47 | model = attempt_load(pub_weights, map_location=pub_device) # load FP32 model 48 | stride = int(model.stride.max()) # model stride 49 | imgsz = check_img_size(pub_img_size, s=stride) # check img_size 50 | if half: 51 | model.half() # to FP16 52 | 53 | 54 | def invoice_detection(file_name=None, invoice=None, context=None): 55 | uid = hashlib.md5(file_name.encode("utf8")).hexdigest() 56 | i = 0 57 | 58 | pub_source = file_name 59 | source, weights, view_img, save_txt, imgsz = pub_source, pub_weights, pub_view_img, pub_save_txt, pub_img_size 60 | save_img = True 61 | webcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith( 62 | ('rtsp://', 'rtmp://', 'http://', 'https://')) 63 | 64 | # Directories 65 | save_dir = Path(pub_project) # increment run 66 | 67 | # Set Dataloader 68 | vid_path, vid_writer = None, None 69 | if webcam: 70 | view_img = check_imshow() 71 | cudnn.benchmark = True # set True to speed up constant image size inference 72 | dataset = LoadStreams(source, img_size=imgsz, stride=stride) 73 | else: 74 | dataset = LoadImages(source, img_size=imgsz, stride=stride) 75 | 76 | # Get names and colors 77 | names = model.module.names if hasattr(model, 'module') else model.names 78 | colors = [[random.randint(0, 255) for _ in range(3)] for _ in names] 79 | 80 
| # Run inference 81 | if device.type != 'cpu': 82 | model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters()))) # run once 83 | t0 = time.time() 84 | for path, img, im0s, vid_cap in dataset: 85 | img = torch.from_numpy(img).to(device) 86 | img = img.half() if half else img.float() # uint8 to fp16/32 87 | img /= 255.0 # 0 - 255 to 0.0 - 1.0 88 | if img.ndimension() == 3: 89 | img = img.unsqueeze(0) 90 | 91 | lock.acquire(timeout=3) 92 | # Inference 93 | pred = model(img, augment=pub_augment)[0] 94 | lock.release() 95 | 96 | # Apply NMS 97 | pred = non_max_suppression(pred, pub_conf_thres, pub_iou_thres, classes=pub_classes, agnostic=pub_agnostic_nms) 98 | t2 = time_synchronized() 99 | 100 | # Process detections 101 | for i, det in enumerate(pred): # detections per image 102 | if webcam: # batch_size >= 1 103 | p, s, im0, frame = path[i], '%g: ' % i, im0s[i].copy(), dataset.count 104 | else: 105 | p, s, im0, frame = path, '', im0s, getattr(dataset, 'frame', 0) 106 | 107 | p = Path(p) # to Path 108 | save_path = str(save_dir / p.name) # img.jpg 109 | txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}') # img.txt 110 | s += '%gx%g ' % img.shape[2:] # print string 111 | gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh 112 | if len(det): 113 | # Rescale boxes from img_size to im0 size 114 | det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round() 115 | 116 | # Print results 117 | for c in det[:, -1].unique(): 118 | n = (det[:, -1] == c).sum() # detections per class 119 | s += f"{n} {names[int(c)]}{'s' * (n > 1)}, " # add to string 120 | 121 | # Write results 122 | for *xyxy, conf, cls in reversed(det): 123 | if save_txt: # Write to file 124 | xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh 125 | line = (cls, *xywh, conf) if pub_save_conf else (cls, *xywh) # label format 126 | with open(txt_path + '.txt', 'a') as f: 127 | f.write(('%g ' * len(line)).rstrip() % line + '\n') 128 | 129 | orgfilename = file_name.rsplit('.', 1)[0].lower() 130 | orgfilename = orgfilename.rsplit('/', 1)[1] 131 | if save_img or view_img: # Add bbox to image 132 | label = names[int(cls)] 133 | if "title" == label: 134 | continue 135 | newimg = im0[int(xyxy[1]):int(xyxy[3]), int(xyxy[0]): int(xyxy[2])] 136 | if "QRCode" == label and qrcode(newimg, invoice): 137 | break 138 | invoice[converter[label]] = context.chineseModel(newimg) 139 | if use == 1: 140 | no_tax(label, newimg, invoice[converter[label]]) 141 | for val in converter.values(): 142 | if invoice.get(val) is None: 143 | invoice[val] = "0" 144 | return invoice 145 | -------------------------------------------------------------------------------- /obj_det/no_tax_detect.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | from pathlib import Path 4 | import os 5 | 6 | import cv2 7 | import torch 8 | import torch.backends.cudnn as cudnn 9 | from numpy import random 10 | 11 | from models.experimental import attempt_load 12 | from utils.datasets import LoadStreams, LoadImages 13 | from utils.general import check_img_size, check_imshow, non_max_suppression, \ 14 | scale_coords, xyxy2xywh, set_logging 15 | from utils.torch_utils import select_device, time_synchronized 16 | import hashlib 17 | from util.qrcode import qrcode_no_tax 18 | from util.create_img import no_tax 19 | from util.create_img import use 20 | 21 | import threading 22 | 23 | lock = 
threading.Lock() 24 | 25 | pub_weights = "models/no_tax/best.pt" 26 | pub_view_img = False 27 | pub_save_txt = False 28 | pub_img_size = 640 29 | pub_nosave = False 30 | pub_project = "images" 31 | pub_device = "cpu" 32 | pub_augment = False 33 | pub_conf_thres = 0.25 34 | pub_iou_thres = 0.24 35 | pub_classes = None 36 | pub_agnostic_nms = False 37 | pub_save_conf = False 38 | 39 | # Initialize 40 | set_logging() 41 | device = select_device(pub_device) 42 | half = device.type != 'cpu' # half precision only supported on CUDA 43 | 44 | # Load model 45 | model = attempt_load(pub_weights, map_location=pub_device) # load FP32 model 46 | stride = int(model.stride.max()) # model stride 47 | imgsz = check_img_size(pub_img_size, s=stride) # check img_size 48 | if half: 49 | model.half() # to FP16 50 | 51 | 52 | def invoice_number_process(img): 53 | height, width, _ = img.shape 54 | for i in range(height): 55 | for j in range(width): 56 | dot = img[i, j] 57 | dot0 = dot[0] 58 | dot1 = dot[1] 59 | dot2 = dot[2] 60 | if dot2 < dot1 or dot2 < dot0: 61 | img[i, j] = [0, 0, 0] 62 | continue 63 | return img 64 | 65 | 66 | def invoice_detection(file_name=None, invoice=None, context=None): 67 | uid = hashlib.md5(file_name.encode("utf8")).hexdigest() 68 | i = 0 69 | 70 | pub_source = file_name 71 | source, weights, view_img, save_txt, imgsz = pub_source, pub_weights, pub_view_img, pub_save_txt, pub_img_size 72 | save_img = True 73 | webcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith( 74 | ('rtsp://', 'rtmp://', 'http://', 'https://')) 75 | 76 | # Directories 77 | save_dir = Path(pub_project) # increment run 78 | 79 | # Set Dataloader 80 | vid_path, vid_writer = None, None 81 | if webcam: 82 | view_img = check_imshow() 83 | cudnn.benchmark = True # set True to speed up constant image size inference 84 | dataset = LoadStreams(source, img_size=imgsz, stride=stride) 85 | else: 86 | dataset = LoadImages(source, img_size=imgsz, stride=stride) 87 | 88 | # Get names and colors 89 | names = model.module.names if hasattr(model, 'module') else model.names 90 | colors = [[random.randint(0, 255) for _ in range(3)] for _ in names] 91 | 92 | # Run inference 93 | if device.type != 'cpu': 94 | model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters()))) # run once 95 | t0 = time.time() 96 | for path, img, im0s, vid_cap in dataset: 97 | img = torch.from_numpy(img).to(device) 98 | img = img.half() if half else img.float() # uint8 to fp16/32 99 | img /= 255.0 # 0 - 255 to 0.0 - 1.0 100 | if img.ndimension() == 3: 101 | img = img.unsqueeze(0) 102 | 103 | lock.acquire(timeout=3) 104 | # Inference 105 | pred = model(img, augment=pub_augment)[0] 106 | lock.release() 107 | 108 | # Apply NMS 109 | pred = non_max_suppression(pred, pub_conf_thres, pub_iou_thres, classes=pub_classes, agnostic=pub_agnostic_nms) 110 | t2 = time_synchronized() 111 | 112 | # Process detections 113 | for i, det in enumerate(pred): # detections per image 114 | if webcam: # batch_size >= 1 115 | p, s, im0, frame = path[i], '%g: ' % i, im0s[i].copy(), dataset.count 116 | else: 117 | p, s, im0, frame = path, '', im0s, getattr(dataset, 'frame', 0) 118 | 119 | p = Path(p) # to Path 120 | save_path = str(save_dir / p.name) # img.jpg 121 | txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}') # img.txt 122 | s += '%gx%g ' % img.shape[2:] # print string 123 | gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh 124 | if len(det): 125 | # Rescale boxes 
from img_size to im0 size 126 | det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round() 127 | 128 | # Print results 129 | for c in det[:, -1].unique(): 130 | n = (det[:, -1] == c).sum() # detections per class 131 | s += f"{n} {names[int(c)]}{'s' * (n > 1)}, " # add to string 132 | 133 | # Write results 134 | for *xyxy, conf, cls in reversed(det): 135 | if save_txt: # Write to file 136 | xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh 137 | line = (cls, *xywh, conf) if pub_save_conf else (cls, *xywh) # label format 138 | with open(txt_path + '.txt', 'a') as f: 139 | f.write(('%g ' * len(line)).rstrip() % line + '\n') 140 | 141 | orgfilename = file_name.rsplit('.', 1)[0].lower() 142 | orgfilename = orgfilename.rsplit('/', 1)[1] 143 | if save_img or view_img: # Add bbox to image 144 | label = names[int(cls)] 145 | newimg = im0[int(xyxy[1]):int(xyxy[3]), int(xyxy[0]): int(xyxy[2])] 146 | i = i + 1 147 | if "qrcode" == label and qrcode_no_tax(newimg, invoice): 148 | break 149 | invoice[label] = context.chineseModel(newimg) 150 | if use == 1: 151 | no_tax(label, newimg, invoice[label]) 152 | 153 | return invoice 154 | -------------------------------------------------------------------------------- /obj_det/vat_detect.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | from pathlib import Path 4 | import os 5 | 6 | import cv2 7 | import torch 8 | import torch.backends.cudnn as cudnn 9 | from numpy import random 10 | 11 | from models.experimental import attempt_load 12 | from utils.datasets import LoadStreams, LoadImages 13 | from utils.general import check_img_size, check_imshow, non_max_suppression, \ 14 | scale_coords, xyxy2xywh, set_logging 15 | from utils.torch_utils import select_device, time_synchronized 16 | import hashlib 17 | from util.qrcode import qrcode 18 | from util.create_img import no_tax 19 | from util.create_img import use 20 | 21 | import threading 22 | 23 | lock = threading.Lock() 24 | 25 | converter = {'invoice_code': 'invoice_code', 'invoice_number': 'invoice_number', 'totalAmount': 'totalAmount', 26 | 'billingDate': 'billingDate', 'checkcode': 'checkCode', 27 | 'QRCode': 'QRCode', 'title': 'title'} 28 | pub_weights = "models/vat/best.pt" 29 | pub_view_img = False 30 | pub_save_txt = False 31 | pub_img_size = 640 32 | pub_nosave = False 33 | pub_project = "images" 34 | pub_device = "cpu" 35 | pub_augment = False 36 | pub_conf_thres = 0.25 37 | pub_iou_thres = 0.24 38 | pub_classes = None 39 | pub_agnostic_nms = False 40 | pub_save_conf = False 41 | 42 | # Initialize 43 | set_logging() 44 | device = select_device(pub_device) 45 | half = device.type != 'cpu' # half precision only supported on CUDA 46 | 47 | # Load model 48 | model = attempt_load(pub_weights, map_location=pub_device) # load FP32 model 49 | stride = int(model.stride.max()) # model stride 50 | imgsz = check_img_size(pub_img_size, s=stride) # check img_size 51 | if half: 52 | model.half() # to FP16 53 | 54 | 55 | def invoice_detection(file_name=None, invoice=None, context=None): 56 | uid = hashlib.md5(file_name.encode("utf8")).hexdigest() 57 | i = 0 58 | 59 | pub_source = file_name 60 | source, weights, view_img, save_txt, imgsz = pub_source, pub_weights, pub_view_img, pub_save_txt, pub_img_size 61 | save_img = True 62 | webcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith( 63 | ('rtsp://', 'rtmp://', 'http://', 'https://')) 64 | 65 | # Directories 66 
| save_dir = Path(pub_project) # increment run 67 | 68 | # Set Dataloader 69 | vid_path, vid_writer = None, None 70 | if webcam: 71 | view_img = check_imshow() 72 | cudnn.benchmark = True # set True to speed up constant image size inference 73 | dataset = LoadStreams(source, img_size=imgsz, stride=stride) 74 | else: 75 | dataset = LoadImages(source, img_size=imgsz, stride=stride) 76 | 77 | # Get names and colors 78 | names = model.module.names if hasattr(model, 'module') else model.names 79 | colors = [[random.randint(0, 255) for _ in range(3)] for _ in names] 80 | 81 | # Run inference 82 | if device.type != 'cpu': 83 | model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters()))) # run once 84 | t0 = time.time() 85 | for path, img, im0s, vid_cap in dataset: 86 | img = torch.from_numpy(img).to(device) 87 | img = img.half() if half else img.float() # uint8 to fp16/32 88 | img /= 255.0 # 0 - 255 to 0.0 - 1.0 89 | if img.ndimension() == 3: 90 | img = img.unsqueeze(0) 91 | 92 | lock.acquire(timeout=3) 93 | # Inference 94 | pred = model(img, augment=pub_augment)[0] 95 | lock.release() 96 | 97 | # Apply NMS 98 | pred = non_max_suppression(pred, pub_conf_thres, pub_iou_thres, classes=pub_classes, agnostic=pub_agnostic_nms) 99 | t2 = time_synchronized() 100 | 101 | # Process detections 102 | for i, det in enumerate(pred): # detections per image 103 | if webcam: # batch_size >= 1 104 | p, s, im0, frame = path[i], '%g: ' % i, im0s[i].copy(), dataset.count 105 | else: 106 | p, s, im0, frame = path, '', im0s, getattr(dataset, 'frame', 0) 107 | 108 | p = Path(p) # to Path 109 | save_path = str(save_dir / p.name) # img.jpg 110 | txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}') # img.txt 111 | s += '%gx%g ' % img.shape[2:] # print string 112 | gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh 113 | if len(det): 114 | # Rescale boxes from img_size to im0 size 115 | det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round() 116 | 117 | # Print results 118 | for c in det[:, -1].unique(): 119 | n = (det[:, -1] == c).sum() # detections per class 120 | s += f"{n} {names[int(c)]}{'s' * (n > 1)}, " # add to string 121 | 122 | # Write results 123 | for *xyxy, conf, cls in reversed(det): 124 | if save_txt: # Write to file 125 | xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh 126 | line = (cls, *xywh, conf) if pub_save_conf else (cls, *xywh) # label format 127 | with open(txt_path + '.txt', 'a') as f: 128 | f.write(('%g ' * len(line)).rstrip() % line + '\n') 129 | 130 | orgfilename = file_name.rsplit('.', 1)[0].lower() 131 | orgfilename = orgfilename.rsplit('/', 1)[1] 132 | if save_img or view_img: # Add bbox to image 133 | label = names[int(cls)] 134 | if "title" == label: 135 | continue 136 | newimg = im0[int(xyxy[1]):int(xyxy[3]), int(xyxy[0]): int(xyxy[2])] 137 | i = i + 1 138 | if "QRCode" == label and qrcode(newimg, invoice): 139 | break 140 | invoice[converter[label]] = context.chineseModel(newimg) 141 | if use == 1: 142 | no_tax(label, newimg, invoice[converter[label]]) 143 | for val in converter.values(): 144 | if invoice.get(val) is None: 145 | invoice[val] = "0" 146 | return invoice 147 | -------------------------------------------------------------------------------- /obj_det/taxi_detect.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | from pathlib import Path 4 | import os 5 | 6 | 
import cv2 7 | import torch 8 | import torch.backends.cudnn as cudnn 9 | from numpy import random 10 | 11 | from models.experimental import attempt_load 12 | from utils.datasets import LoadStreams, LoadImages 13 | from utils.general import check_img_size, check_imshow, non_max_suppression, \ 14 | scale_coords, xyxy2xywh, strip_optimizer, set_logging 15 | from utils.torch_utils import select_device, time_synchronized 16 | import hashlib 17 | from crnn.crnn_torch_chinese import crnnOcr as ccrnnOcr 18 | from util.create_img import no_tax 19 | from util.create_img import use 20 | 21 | import threading 22 | 23 | lock = threading.Lock() 24 | 25 | converter = {'invoice_code': 'invoice_code', 'invoice_number': 'invoice_number', 'totalAmount': 'amountTax'} 26 | pub_weights = "models/taxi/best.pt" 27 | pub_view_img = False 28 | pub_save_txt = False 29 | pub_img_size = 640 30 | pub_nosave = False 31 | pub_project = "images" 32 | pub_device = "cpu" 33 | pub_augment = False 34 | pub_conf_thres = 0.25 35 | pub_iou_thres = 0.24 36 | pub_classes = None 37 | pub_agnostic_nms = False 38 | pub_save_conf = False 39 | 40 | # Initialize 41 | set_logging() 42 | device = select_device(pub_device) 43 | half = device.type != 'cpu' # half precision only supported on CUDA 44 | 45 | # Load model 46 | model = attempt_load(pub_weights, map_location=pub_device) # load FP32 model 47 | stride = int(model.stride.max()) # model stride 48 | imgsz = check_img_size(pub_img_size, s=stride) # check img_size 49 | if half: 50 | model.half() # to FP16 51 | 52 | 53 | def invoice_number_process(img): 54 | height, width, _ = img.shape 55 | for i in range(height): 56 | for j in range(width): 57 | dot = img[i, j] 58 | dot0 = dot[0] 59 | dot1 = dot[1] 60 | dot2 = dot[2] 61 | if dot2 < dot1 or dot2 < dot0: 62 | img[i, j] = [0, 0, 0] 63 | continue 64 | return img 65 | 66 | 67 | def invoice_detection(file_name=None, invoice=None): 68 | uid = hashlib.md5(file_name.encode("utf8")).hexdigest() 69 | i = 0 70 | 71 | pub_source = file_name 72 | source, weights, view_img, save_txt, imgsz = pub_source, pub_weights, pub_view_img, pub_save_txt, pub_img_size 73 | save_img = True 74 | webcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith( 75 | ('rtsp://', 'rtmp://', 'http://', 'https://')) 76 | 77 | # Directories 78 | save_dir = Path(pub_project) # increment run 79 | 80 | # Set Dataloader 81 | vid_path, vid_writer = None, None 82 | if webcam: 83 | view_img = check_imshow() 84 | cudnn.benchmark = True # set True to speed up constant image size inference 85 | dataset = LoadStreams(source, img_size=imgsz, stride=stride) 86 | else: 87 | dataset = LoadImages(source, img_size=imgsz, stride=stride) 88 | 89 | # Get names and colors 90 | names = model.module.names if hasattr(model, 'module') else model.names 91 | colors = [[random.randint(0, 255) for _ in range(3)] for _ in names] 92 | 93 | # Run inference 94 | if device.type != 'cpu': 95 | model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters()))) # run once 96 | t0 = time.time() 97 | for path, img, im0s, vid_cap in dataset: 98 | img = torch.from_numpy(img).to(device) 99 | img = img.half() if half else img.float() # uint8 to fp16/32 100 | img /= 255.0 # 0 - 255 to 0.0 - 1.0 101 | if img.ndimension() == 3: 102 | img = img.unsqueeze(0) 103 | 104 | lock.acquire(timeout=3) 105 | # Inference 106 | pred = model(img, augment=pub_augment)[0] 107 | lock.release() 108 | 109 | # Apply NMS 110 | pred = non_max_suppression(pred, pub_conf_thres, pub_iou_thres, 
classes=pub_classes, agnostic=pub_agnostic_nms) 111 | t2 = time_synchronized() 112 | 113 | # Process detections 114 | for i, det in enumerate(pred): # detections per image 115 | if webcam: # batch_size >= 1 116 | p, s, im0, frame = path[i], '%g: ' % i, im0s[i].copy(), dataset.count 117 | else: 118 | p, s, im0, frame = path, '', im0s, getattr(dataset, 'frame', 0) 119 | 120 | p = Path(p) # to Path 121 | save_path = str(save_dir / p.name) # img.jpg 122 | txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}') # img.txt 123 | s += '%gx%g ' % img.shape[2:] # print string 124 | gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh 125 | if len(det): 126 | # Rescale boxes from img_size to im0 size 127 | det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round() 128 | 129 | # Print results 130 | for c in det[:, -1].unique(): 131 | n = (det[:, -1] == c).sum() # detections per class 132 | s += f"{n} {names[int(c)]}{'s' * (n > 1)}, " # add to string 133 | 134 | # Write results 135 | for *xyxy, conf, cls in reversed(det): 136 | if save_txt: # Write to file 137 | xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh 138 | line = (cls, *xywh, conf) if pub_save_conf else (cls, *xywh) # label format 139 | with open(txt_path + '.txt', 'a') as f: 140 | f.write(('%g ' * len(line)).rstrip() % line + '\n') 141 | 142 | orgfilename = file_name.rsplit('.', 1)[0].lower() 143 | orgfilename = orgfilename.rsplit('/', 1)[1] 144 | if save_img or view_img: # Add bbox to image 145 | label = names[int(cls)] 146 | newimg = im0[int(xyxy[1]):int(xyxy[3]), int(xyxy[0]): int(xyxy[2])] 147 | i = i + 1 148 | invoice[converter[label]] = ccrnnOcr(newimg) 149 | if use == 1: 150 | no_tax(label, newimg, invoice[converter[label]]) 151 | for val in converter.values(): 152 | if invoice.get(val) is None: 153 | invoice[val] = "0" 154 | return invoice 155 | -------------------------------------------------------------------------------- /obj_det/detect.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | from pathlib import Path 4 | import os 5 | 6 | import cv2 7 | import torch 8 | import torch.backends.cudnn as cudnn 9 | from numpy import random 10 | 11 | from models.experimental import attempt_load 12 | from utils.datasets import LoadStreams, LoadImages 13 | from utils.general import check_img_size, check_imshow, non_max_suppression, \ 14 | scale_coords, xyxy2xywh, strip_optimizer, set_logging 15 | from utils.torch_utils import select_device, time_synchronized 16 | import threading 17 | 18 | lock = threading.Lock() 19 | 20 | converter = {'01':'增值税专用发票', '04': '增值税普通发票', '08': '增值税电子专用发票', '10': '增值税电子普通发票', '11': '增值税普通发票(卷式)', '81': '非税财政电子票据', '86': '过路费发票', '88': '火车票', '89': '飞机票', '90':'客运票', '92':'出租车票', '93':'定额', '95':'通用机打发票'} 21 | pub_weights = "models/mutipart/best.pt" 22 | pub_view_img = False 23 | pub_save_txt = False 24 | pub_img_size = 640 25 | pub_nosave = False 26 | pub_project = "images" 27 | pub_device = "cpu" 28 | pub_augment = False 29 | pub_conf_thres = 0.25 30 | pub_iou_thres = 0.24 31 | pub_classes = None 32 | pub_agnostic_nms = False 33 | pub_save_conf = False 34 | 35 | # Initialize 36 | set_logging() 37 | device = select_device(pub_device) 38 | half = device.type != 'cpu' # half precision only supported on CUDA 39 | 40 | # Load model 41 | model = attempt_load(pub_weights, map_location=pub_device) # load FP32 model 42 | stride = 
int(model.stride.max()) # model stride 43 | imgsz = check_img_size(pub_img_size, s=stride) # check img_size 44 | if half: 45 | model.half() # to FP16 46 | 47 | 48 | def invoice_number_process(img): 49 | height, width, _ = img.shape 50 | for i in range(height): 51 | for j in range(width): 52 | dot = img[i, j] 53 | dot0 = dot[0] 54 | dot1 = dot[1] 55 | dot2 = dot[2] 56 | if dot2 < dot1 or dot2 < dot0: 57 | img[i, j] = [0, 0, 0] 58 | continue 59 | return img 60 | 61 | 62 | def invoice_detection(file_name=None): 63 | i = 0 64 | invoices = [] 65 | pub_source = "images/" + file_name 66 | source, weights, view_img, save_txt, imgsz = pub_source, pub_weights, pub_view_img, pub_save_txt, pub_img_size 67 | save_img = True 68 | webcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith( 69 | ('rtsp://', 'rtmp://', 'http://', 'https://')) 70 | 71 | # Directories 72 | save_dir = Path(pub_project) # increment run 73 | 74 | # Set Dataloader 75 | vid_path, vid_writer = None, None 76 | if webcam: 77 | view_img = check_imshow() 78 | cudnn.benchmark = True # set True to speed up constant image size inference 79 | dataset = LoadStreams(source, img_size=imgsz, stride=stride) 80 | else: 81 | dataset = LoadImages(source, img_size=imgsz, stride=stride) 82 | 83 | # Get names and colors 84 | names = model.module.names if hasattr(model, 'module') else model.names 85 | colors = [[random.randint(0, 255) for _ in range(3)] for _ in names] 86 | 87 | # Run inference 88 | if device.type != 'cpu': 89 | model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters()))) # run once 90 | t0 = time.time() 91 | for path, img, im0s, vid_cap in dataset: 92 | img = torch.from_numpy(img).to(device) 93 | img = img.half() if half else img.float() # uint8 to fp16/32 94 | img /= 255.0 # 0 - 255 to 0.0 - 1.0 95 | if img.ndimension() == 3: 96 | img = img.unsqueeze(0) 97 | 98 | # Inference 99 | lock.acquire(timeout=3) 100 | pred = model(img, augment=pub_augment)[0] 101 | lock.release() 102 | 103 | # Apply NMS 104 | pred = non_max_suppression(pred, pub_conf_thres, pub_iou_thres, classes=pub_classes, agnostic=pub_agnostic_nms) 105 | 106 | # Process detections 107 | for i, det in enumerate(pred): # detections per image 108 | if webcam: # batch_size >= 1 109 | p, s, im0, frame = path[i], '%g: ' % i, im0s[i].copy(), dataset.count 110 | else: 111 | p, s, im0, frame = path, '', im0s, getattr(dataset, 'frame', 0) 112 | 113 | p = Path(p) # to Path 114 | save_path = str(save_dir / p.name) # img.jpg 115 | txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}') # img.txt 116 | s += '%gx%g ' % img.shape[2:] # print string 117 | gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh 118 | if len(det): 119 | # Rescale boxes from img_size to im0 size 120 | det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round() 121 | 122 | # Print results 123 | for c in det[:, -1].unique(): 124 | n = (det[:, -1] == c).sum() # detections per class 125 | s += f"{n} {names[int(c)]}{'s' * (n > 1)}, " # add to string 126 | 127 | # Write results 128 | for *xyxy, conf, cls in reversed(det): 129 | if save_txt: # Write to file 130 | xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh 131 | line = (cls, *xywh, conf) if pub_save_conf else (cls, *xywh) # label format 132 | with open(txt_path + '.txt', 'a') as f: 133 | f.write(('%g ' * len(line)).rstrip() % line + '\n') 134 | 135 | orgfilename = file_name.rsplit('.', 1)[0].lower() 136 | 
type = file_name.rsplit('.', 1)[1].lower() 137 | if save_img or view_img: # Add bbox to image 138 | label = names[int(cls)] 139 | newimg = im0[int(xyxy[1]):int(xyxy[3]), int(xyxy[0]): int(xyxy[2])] 140 | dir = "images/" + orgfilename 141 | if not os.path.exists(dir): 142 | os.makedirs(dir) 143 | newPath = dir + "/" + str(i) + "_" + orgfilename + "." + type 144 | cv2.imwrite(newPath, newimg) 145 | i = i + 1 146 | coordinate = [int(xyxy[0]), int(xyxy[1]), int(xyxy[2]), int(xyxy[3]), int(xyxy[0]), int(xyxy[1]), int(xyxy[2]), int(xyxy[3])] 147 | json = {"invoiceType": label, "invoice_type_name": converter[label], "file_path": newPath, "coordinate": coordinate} 148 | invoices.append(json) 149 | return invoices 150 | -------------------------------------------------------------------------------- /obj_det/tra_detect.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | from pathlib import Path 4 | import os 5 | 6 | import cv2 7 | import torch 8 | import torch.backends.cudnn as cudnn 9 | from numpy import random 10 | 11 | from models.experimental import attempt_load 12 | from utils.datasets import LoadStreams, LoadImages 13 | from utils.general import check_img_size, check_imshow, non_max_suppression, \ 14 | scale_coords, xyxy2xywh, strip_optimizer, set_logging 15 | from utils.torch_utils import select_device, time_synchronized 16 | import hashlib 17 | from crnn.crnn_torch import crnnOcr as crnnOcr 18 | from crnn.crnn_torch_chinese import crnnOcr as ccrnnOcr 19 | from util.create_img import no_tax 20 | from util.create_img import use 21 | 22 | import threading 23 | 24 | lock = threading.Lock() 25 | 26 | converter = {'invoice_number': 'invoice_number', 'totalAmount': 'amountTax', 'ids': 'ids', 'date': 'billingDate', 'checi': 'checi', 27 | 'start': 'starStation', 'end': 'endStation'} 28 | pub_weights = "models/huochepiao/best.pt" 29 | pub_view_img = False 30 | pub_save_txt = False 31 | pub_img_size = 640 32 | pub_nosave = False 33 | pub_project = "images" 34 | pub_device = "cpu" 35 | pub_augment = False 36 | pub_conf_thres = 0.25 37 | pub_iou_thres = 0.24 38 | pub_classes = None 39 | pub_agnostic_nms = False 40 | pub_save_conf = False 41 | 42 | # Initialize 43 | set_logging() 44 | device = select_device(pub_device) 45 | half = device.type != 'cpu' # half precision only supported on CUDA 46 | 47 | # Load model 48 | model = attempt_load(pub_weights, map_location=pub_device) # load FP32 model 49 | stride = int(model.stride.max()) # model stride 50 | imgsz = check_img_size(pub_img_size, s=stride) # check img_size 51 | if half: 52 | model.half() # to FP16 53 | 54 | 55 | def invoice_number_process(img): 56 | height, width, _ = img.shape 57 | for i in range(height): 58 | for j in range(width): 59 | dot = img[i, j] 60 | dot0 = dot[0] 61 | dot1 = dot[1] 62 | dot2 = dot[2] 63 | if dot2 < dot1 or dot2 < dot0: 64 | img[i, j] = [0, 0, 0] 65 | continue 66 | return img 67 | 68 | 69 | def invoice_detection(file_name=None, invoice=None): 70 | uid = hashlib.md5(file_name.encode("utf8")).hexdigest() 71 | i = 0 72 | 73 | pub_source = file_name 74 | source, weights, view_img, save_txt, imgsz = pub_source, pub_weights, pub_view_img, pub_save_txt, pub_img_size 75 | save_img = True 76 | webcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith( 77 | ('rtsp://', 'rtmp://', 'http://', 'https://')) 78 | 79 | # Directories 80 | save_dir = Path(pub_project) # increment run 81 | 82 | # Set Dataloader 83 | vid_path, vid_writer = 
None, None 84 | if webcam: 85 | view_img = check_imshow() 86 | cudnn.benchmark = True # set True to speed up constant image size inference 87 | dataset = LoadStreams(source, img_size=imgsz, stride=stride) 88 | else: 89 | dataset = LoadImages(source, img_size=imgsz, stride=stride) 90 | 91 | # Get names and colors 92 | names = model.module.names if hasattr(model, 'module') else model.names 93 | colors = [[random.randint(0, 255) for _ in range(3)] for _ in names] 94 | 95 | # Run inference 96 | if device.type != 'cpu': 97 | model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters()))) # run once 98 | t0 = time.time() 99 | for path, img, im0s, vid_cap in dataset: 100 | img = torch.from_numpy(img).to(device) 101 | img = img.half() if half else img.float() # uint8 to fp16/32 102 | img /= 255.0 # 0 - 255 to 0.0 - 1.0 103 | if img.ndimension() == 3: 104 | img = img.unsqueeze(0) 105 | 106 | lock.acquire(timeout=3) 107 | # Inference 108 | pred = model(img, augment=pub_augment)[0] 109 | lock.release() 110 | 111 | # Apply NMS 112 | pred = non_max_suppression(pred, pub_conf_thres, pub_iou_thres, classes=pub_classes, agnostic=pub_agnostic_nms) 113 | t2 = time_synchronized() 114 | 115 | # Process detections 116 | for i, det in enumerate(pred): # detections per image 117 | if webcam: # batch_size >= 1 118 | p, s, im0, frame = path[i], '%g: ' % i, im0s[i].copy(), dataset.count 119 | else: 120 | p, s, im0, frame = path, '', im0s, getattr(dataset, 'frame', 0) 121 | 122 | p = Path(p) # to Path 123 | save_path = str(save_dir / p.name) # img.jpg 124 | txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}') # img.txt 125 | s += '%gx%g ' % img.shape[2:] # print string 126 | gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh 127 | if len(det): 128 | # Rescale boxes from img_size to im0 size 129 | det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round() 130 | 131 | # Print results 132 | for c in det[:, -1].unique(): 133 | n = (det[:, -1] == c).sum() # detections per class 134 | s += f"{n} {names[int(c)]}{'s' * (n > 1)}, " # add to string 135 | 136 | # Write results 137 | for *xyxy, conf, cls in reversed(det): 138 | if save_txt: # Write to file 139 | xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh 140 | line = (cls, *xywh, conf) if pub_save_conf else (cls, *xywh) # label format 141 | with open(txt_path + '.txt', 'a') as f: 142 | f.write(('%g ' * len(line)).rstrip() % line + '\n') 143 | 144 | orgfilename = file_name.rsplit('.', 1)[0].lower() 145 | orgfilename = orgfilename.rsplit('/', 1)[1] 146 | if save_img or view_img: # Add bbox to image 147 | label = names[int(cls)] 148 | newimg = im0[int(xyxy[1]):int(xyxy[3]), int(xyxy[0]): int(xyxy[2])] 149 | i = i + 1 150 | if "checi" == label or "invoice_number" == label: 151 | invoice[converter[label]] = crnnOcr(newimg) 152 | else: 153 | invoice[converter[label]] = ccrnnOcr(newimg) 154 | if use == 1: 155 | no_tax(label, newimg, invoice[converter[label]]) 156 | 157 | for val in converter.values(): 158 | if invoice.get(val) is None: 159 | invoice[val] = "0" 160 | return invoice 161 | -------------------------------------------------------------------------------- /obj_det/title_detect.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | from pathlib import Path 4 | import os 5 | 6 | import cv2 7 | import torch 8 | import torch.backends.cudnn as cudnn 9 | from numpy 
import random 10 | 11 | from models.experimental import attempt_load 12 | from utils.datasets import LoadStreams, LoadImages 13 | from utils.general import check_img_size, check_imshow, non_max_suppression, \ 14 | scale_coords, xyxy2xywh, strip_optimizer, set_logging 15 | from utils.torch_utils import select_device, time_synchronized 16 | import hashlib 17 | from crnn.crnn_torch_chinese import crnnOcr as ccrnnOcr 18 | 19 | import threading 20 | 21 | lock = threading.Lock() 22 | 23 | converter_name = {'01':'增值税专用发票', '04': '增值税普通发票', '08': '增值税电子专用发票', '10': '增值税电子普通发票', '11': '增值税普通发票(卷式)', '14': '增值税电子普通发票(通行费)'} 24 | converter = {'增值税专用发票':'01', '增值税普通发票': '04', '增值税电子专用发票': '08', '增值税电子普通发票': '10', '增值税普通发票(卷式)': '11', '增值税电子普通发票(通行费)': '14'} 25 | pub_weights = "models/title/best.pt" 26 | pub_view_img = False 27 | pub_save_txt = False 28 | pub_img_size = 640 29 | pub_nosave = False 30 | pub_project = "images" 31 | pub_device = "cpu" 32 | pub_augment = False 33 | pub_conf_thres = 0.25 34 | pub_iou_thres = 0.24 35 | pub_classes = None 36 | pub_agnostic_nms = False 37 | pub_save_conf = False 38 | 39 | # Initialize 40 | set_logging() 41 | device = select_device(pub_device) 42 | half = device.type != 'cpu' # half precision only supported on CUDA 43 | 44 | # Load model 45 | model = attempt_load(pub_weights, map_location=pub_device) # load FP32 model 46 | stride = int(model.stride.max()) # model stride 47 | imgsz = check_img_size(pub_img_size, s=stride) # check img_size 48 | if half: 49 | model.half() # to FP16 50 | 51 | 52 | def invoice_number_process(img): 53 | height, width, _ = img.shape 54 | for i in range(height): 55 | for j in range(width): 56 | dot = img[i, j] 57 | dot0 = dot[0] 58 | dot1 = dot[1] 59 | dot2 = dot[2] 60 | if dot2 < dot1 or dot2 < dot0: 61 | img[i, j] = [0, 0, 0] 62 | continue 63 | return img 64 | 65 | 66 | def invoice_detection(file_name=None, invoice=None): 67 | title = "" 68 | pub_source = file_name 69 | source, weights, view_img, save_txt, imgsz = pub_source, pub_weights, pub_view_img, pub_save_txt, pub_img_size 70 | save_img = True 71 | webcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith( 72 | ('rtsp://', 'rtmp://', 'http://', 'https://')) 73 | 74 | # Directories 75 | save_dir = Path(pub_project) # increment run 76 | 77 | # Set Dataloader 78 | vid_path, vid_writer = None, None 79 | if webcam: 80 | view_img = check_imshow() 81 | cudnn.benchmark = True # set True to speed up constant image size inference 82 | dataset = LoadStreams(source, img_size=imgsz, stride=stride) 83 | else: 84 | dataset = LoadImages(source, img_size=imgsz, stride=stride) 85 | 86 | # Get names and colors 87 | names = model.module.names if hasattr(model, 'module') else model.names 88 | colors = [[random.randint(0, 255) for _ in range(3)] for _ in names] 89 | 90 | # Run inference 91 | if device.type != 'cpu': 92 | model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters()))) # run once 93 | t0 = time.time() 94 | for path, img, im0s, vid_cap in dataset: 95 | img = torch.from_numpy(img).to(device) 96 | img = img.half() if half else img.float() # uint8 to fp16/32 97 | img /= 255.0 # 0 - 255 to 0.0 - 1.0 98 | if img.ndimension() == 3: 99 | img = img.unsqueeze(0) 100 | 101 | lock.acquire(timeout=3) 102 | # Inference 103 | pred = model(img, augment=pub_augment)[0] 104 | lock.release() 105 | 106 | # Apply NMS 107 | pred = non_max_suppression(pred, pub_conf_thres, pub_iou_thres, classes=pub_classes, agnostic=pub_agnostic_nms) 108 | t2 = 
time_synchronized() 109 | 110 | # Process detections 111 | for i, det in enumerate(pred): # detections per image 112 | if webcam: # batch_size >= 1 113 | p, s, im0, frame = path[i], '%g: ' % i, im0s[i].copy(), dataset.count 114 | else: 115 | p, s, im0, frame = path, '', im0s, getattr(dataset, 'frame', 0) 116 | 117 | p = Path(p) # to Path 118 | save_path = str(save_dir / p.name) # img.jpg 119 | txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}') # img.txt 120 | s += '%gx%g ' % img.shape[2:] # print string 121 | gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh 122 | if len(det): 123 | # Rescale boxes from img_size to im0 size 124 | det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round() 125 | 126 | # Print results 127 | for c in det[:, -1].unique(): 128 | n = (det[:, -1] == c).sum() # detections per class 129 | s += f"{n} {names[int(c)]}{'s' * (n > 1)}, " # add to string 130 | 131 | # Write results 132 | for *xyxy, conf, cls in reversed(det): 133 | if save_txt: # Write to file 134 | xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh 135 | line = (cls, *xywh, conf) if pub_save_conf else (cls, *xywh) # label format 136 | with open(txt_path + '.txt', 'a') as f: 137 | f.write(('%g ' * len(line)).rstrip() % line + '\n') 138 | 139 | orgfilename = file_name.rsplit('.', 1)[0].lower() 140 | orgfilename = orgfilename.rsplit('/', 1)[1] 141 | if save_img or view_img: # Add bbox to image 142 | label = names[int(cls)] 143 | newimg = im0[int(xyxy[1]):int(xyxy[3]), int(xyxy[0]): int(xyxy[2])] 144 | if "title" == label: 145 | title = ccrnnOcr(newimg) 146 | if "go" == label: 147 | go = ccrnnOcr(newimg) 148 | title = "增值税电子普通发票(通行费)" 149 | if "专用发票" in title: 150 | if "电子" in title: 151 | invoice['invoiceType'] = "08" 152 | else: 153 | invoice['invoiceType'] = "01" 154 | if "普通发票" in title: 155 | if "电子" in title: 156 | invoice['invoiceType'] = "10" 157 | else: 158 | invoice['invoiceType'] = "04" 159 | if "增值税电子普通发票(通行费)" in title: 160 | invoice['invoiceType'] = "14" 161 | invoice['invoice_type_name'] = converter_name[invoice['invoiceType']] 162 | return invoice 163 | -------------------------------------------------------------------------------- /utils/autoanchor.py: -------------------------------------------------------------------------------- 1 | # Auto-anchor utils 2 | 3 | import numpy as np 4 | import torch 5 | import yaml 6 | from scipy.cluster.vq import kmeans 7 | from tqdm import tqdm 8 | 9 | from utils.general import colorstr 10 | 11 | 12 | def check_anchor_order(m): 13 | # Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary 14 | a = m.anchor_grid.prod(-1).view(-1) # anchor area 15 | da = a[-1] - a[0] # delta a 16 | ds = m.stride[-1] - m.stride[0] # delta s 17 | if da.sign() != ds.sign(): # same order 18 | print('Reversing anchor order') 19 | m.anchors[:] = m.anchors.flip(0) 20 | m.anchor_grid[:] = m.anchor_grid.flip(0) 21 | 22 | 23 | def check_anchors(dataset, model, thr=4.0, imgsz=640): 24 | # Check anchor fit to data, recompute if necessary 25 | prefix = colorstr('autoanchor: ') 26 | print(f'\n{prefix}Analyzing anchors... 
', end='') 27 | m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1] # Detect() 28 | shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True) 29 | scale = np.random.uniform(0.9, 1.1, size=(shapes.shape[0], 1)) # augment scale 30 | wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes * scale, dataset.labels)])).float() # wh 31 | 32 | def metric(k): # compute metric 33 | r = wh[:, None] / k[None] 34 | x = torch.min(r, 1. / r).min(2)[0] # ratio metric 35 | best = x.max(1)[0] # best_x 36 | aat = (x > 1. / thr).float().sum(1).mean() # anchors above threshold 37 | bpr = (best > 1. / thr).float().mean() # best possible recall 38 | return bpr, aat 39 | 40 | anchors = m.anchor_grid.clone().cpu().view(-1, 2) # current anchors 41 | bpr, aat = metric(anchors) 42 | print(f'anchors/target = {aat:.2f}, Best Possible Recall (BPR) = {bpr:.4f}', end='') 43 | if bpr < 0.98: # threshold to recompute 44 | print('. Attempting to improve anchors, please wait...') 45 | na = m.anchor_grid.numel() // 2 # number of anchors 46 | try: 47 | anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False) 48 | except Exception as e: 49 | print(f'{prefix}ERROR: {e}') 50 | new_bpr = metric(anchors)[0] 51 | if new_bpr > bpr: # replace anchors 52 | anchors = torch.tensor(anchors, device=m.anchors.device).type_as(m.anchors) 53 | m.anchor_grid[:] = anchors.clone().view_as(m.anchor_grid) # for inference 54 | m.anchors[:] = anchors.clone().view_as(m.anchors) / m.stride.to(m.anchors.device).view(-1, 1, 1) # loss 55 | check_anchor_order(m) 56 | print(f'{prefix}New anchors saved to model. Update model *.yaml to use these anchors in the future.') 57 | else: 58 | print(f'{prefix}Original anchors better than new anchors. Proceeding with original anchors.') 59 | print('') # newline 60 | 61 | 62 | def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=1000, verbose=True): 63 | """ Creates kmeans-evolved anchors from training dataset 64 | 65 | Arguments: 66 | path: path to dataset *.yaml, or a loaded dataset 67 | n: number of anchors 68 | img_size: image size used for training 69 | thr: anchor-label wh ratio threshold hyperparameter hyp['anchor_t'] used for training, default=4.0 70 | gen: generations to evolve anchors using genetic algorithm 71 | verbose: print all results 72 | 73 | Return: 74 | k: kmeans evolved anchors 75 | 76 | Usage: 77 | from utils.autoanchor import *; _ = kmean_anchors() 78 | """ 79 | thr = 1. / thr 80 | prefix = colorstr('autoanchor: ') 81 | 82 | def metric(k, wh): # compute metrics 83 | r = wh[:, None] / k[None] 84 | x = torch.min(r, 1. 
/ r).min(2)[0] # ratio metric 85 | # x = wh_iou(wh, torch.tensor(k)) # iou metric 86 | return x, x.max(1)[0] # x, best_x 87 | 88 | def anchor_fitness(k): # mutation fitness 89 | _, best = metric(torch.tensor(k, dtype=torch.float32), wh) 90 | return (best * (best > thr).float()).mean() # fitness 91 | 92 | def print_results(k): 93 | k = k[np.argsort(k.prod(1))] # sort small to large 94 | x, best = metric(k, wh0) 95 | bpr, aat = (best > thr).float().mean(), (x > thr).float().mean() * n # best possible recall, anch > thr 96 | print(f'{prefix}thr={thr:.2f}: {bpr:.4f} best possible recall, {aat:.2f} anchors past thr') 97 | print(f'{prefix}n={n}, img_size={img_size}, metric_all={x.mean():.3f}/{best.mean():.3f}-mean/best, ' 98 | f'past_thr={x[x > thr].mean():.3f}-mean: ', end='') 99 | for i, x in enumerate(k): 100 | print('%i,%i' % (round(x[0]), round(x[1])), end=', ' if i < len(k) - 1 else '\n') # use in *.cfg 101 | return k 102 | 103 | if isinstance(path, str): # *.yaml file 104 | with open(path) as f: 105 | data_dict = yaml.load(f, Loader=yaml.SafeLoader) # model dict 106 | from utils.datasets import LoadImagesAndLabels 107 | dataset = LoadImagesAndLabels(data_dict['train'], augment=True, rect=True) 108 | else: 109 | dataset = path # dataset 110 | 111 | # Get label wh 112 | shapes = img_size * dataset.shapes / dataset.shapes.max(1, keepdims=True) 113 | wh0 = np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)]) # wh 114 | 115 | # Filter 116 | i = (wh0 < 3.0).any(1).sum() 117 | if i: 118 | print(f'{prefix}WARNING: Extremely small objects found. {i} of {len(wh0)} labels are < 3 pixels in size.') 119 | wh = wh0[(wh0 >= 2.0).any(1)] # filter > 2 pixels 120 | # wh = wh * (np.random.rand(wh.shape[0], 1) * 0.9 + 0.1) # multiply by random scale 0-1 121 | 122 | # Kmeans calculation 123 | print(f'{prefix}Running kmeans for {n} anchors on {len(wh)} points...') 124 | s = wh.std(0) # sigmas for whitening 125 | k, dist = kmeans(wh / s, n, iter=30) # points, mean distance 126 | assert len(k) == n, print(f'{prefix}ERROR: scipy.cluster.vq.kmeans requested {n} points but returned only {len(k)}') 127 | k *= s 128 | wh = torch.tensor(wh, dtype=torch.float32) # filtered 129 | wh0 = torch.tensor(wh0, dtype=torch.float32) # unfiltered 130 | k = print_results(k) 131 | 132 | # Plot 133 | # k, d = [None] * 20, [None] * 20 134 | # for i in tqdm(range(1, 21)): 135 | # k[i-1], d[i-1] = kmeans(wh / s, i) # points, mean distance 136 | # fig, ax = plt.subplots(1, 2, figsize=(14, 7), tight_layout=True) 137 | # ax = ax.ravel() 138 | # ax[0].plot(np.arange(1, 21), np.array(d) ** 2, marker='.') 139 | # fig, ax = plt.subplots(1, 2, figsize=(14, 7)) # plot wh 140 | # ax[0].hist(wh[wh[:, 0]<100, 0],400) 141 | # ax[1].hist(wh[wh[:, 1]<100, 1],400) 142 | # fig.savefig('wh.png', dpi=200) 143 | 144 | # Evolve 145 | npr = np.random 146 | f, sh, mp, s = anchor_fitness(k), k.shape, 0.9, 0.1 # fitness, generations, mutation prob, sigma 147 | pbar = tqdm(range(gen), desc=f'{prefix}Evolving anchors with Genetic Algorithm:') # progress bar 148 | for _ in pbar: 149 | v = np.ones(sh) 150 | while (v == 1).all(): # mutate until a change occurs (prevent duplicates) 151 | v = ((npr.random(sh) < mp) * npr.random() * npr.randn(*sh) * s + 1).clip(0.3, 3.0) 152 | kg = (k.copy() * v).clip(min=2.0) 153 | fg = anchor_fitness(kg) 154 | if fg > f: 155 | f, k = fg, kg.copy() 156 | pbar.desc = f'{prefix}Evolving anchors with Genetic Algorithm: fitness = {f:.4f}' 157 | if verbose: 158 | print_results(k) 159 | 160 | return print_results(k) 161 
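# ---------------------------------------------------------------------------
# Editor's note: illustrative usage sketch, not part of the upstream
# autoanchor.py. It simply drives kmean_anchors() with its documented defaults;
# the './data/coco128.yaml' path comes from the function signature and is
# assumed to exist, so point `path` at your own dataset yaml in practice.
if __name__ == '__main__':
    # Evolve n=9 anchors for 640-px training with wh-ratio threshold thr=4.0
    evolved = kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=1000, verbose=False)
    print(evolved)  # (9, 2) array of anchor widths/heights, sorted small to large
# During training, check_anchors(dataset, model, thr, imgsz) triggers the same
# evolution automatically whenever the measured best possible recall (BPR)
# drops below 0.98.
# ---------------------------------------------------------------------------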
| -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | # Model validation metrics 2 | 3 | from pathlib import Path 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import torch 8 | 9 | from . import general 10 | 11 | 12 | def fitness(x): 13 | # Model fitness as a weighted combination of metrics 14 | w = [0.0, 0.0, 0.1, 0.9] # weights for [P, R, mAP@0.5, mAP@0.5:0.95] 15 | return (x[:, :4] * w).sum(1) 16 | 17 | 18 | def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names=()): 19 | """ Compute the average precision, given the recall and precision curves. 20 | Source: https://github.com/rafaelpadilla/Object-Detection-Metrics. 21 | # Arguments 22 | tp: True positives (nparray, nx1 or nx10). 23 | conf: Objectness value from 0-1 (nparray). 24 | pred_cls: Predicted object classes (nparray). 25 | target_cls: True object classes (nparray). 26 | plot: Plot precision-recall curve at mAP@0.5 27 | save_dir: Plot save directory 28 | # Returns 29 | The average precision as computed in py-faster-rcnn. 30 | """ 31 | 32 | # Sort by objectness 33 | i = np.argsort(-conf) 34 | tp, conf, pred_cls = tp[i], conf[i], pred_cls[i] 35 | 36 | # Find unique classes 37 | unique_classes = np.unique(target_cls) 38 | nc = unique_classes.shape[0] # number of classes, number of detections 39 | 40 | # Create Precision-Recall curve and compute AP for each class 41 | px, py = np.linspace(0, 1, 1000), [] # for plotting 42 | ap, p, r = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000)) 43 | for ci, c in enumerate(unique_classes): 44 | i = pred_cls == c 45 | n_l = (target_cls == c).sum() # number of labels 46 | n_p = i.sum() # number of predictions 47 | 48 | if n_p == 0 or n_l == 0: 49 | continue 50 | else: 51 | # Accumulate FPs and TPs 52 | fpc = (1 - tp[i]).cumsum(0) 53 | tpc = tp[i].cumsum(0) 54 | 55 | # Recall 56 | recall = tpc / (n_l + 1e-16) # recall curve 57 | r[ci] = np.interp(-px, -conf[i], recall[:, 0], left=0) # negative x, xp because xp decreases 58 | 59 | # Precision 60 | precision = tpc / (tpc + fpc) # precision curve 61 | p[ci] = np.interp(-px, -conf[i], precision[:, 0], left=1) # p at pr_score 62 | 63 | # AP from recall-precision curve 64 | for j in range(tp.shape[1]): 65 | ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j]) 66 | if plot and j == 0: 67 | py.append(np.interp(px, mrec, mpre)) # precision at mAP@0.5 68 | 69 | # Compute F1 (harmonic mean of precision and recall) 70 | f1 = 2 * p * r / (p + r + 1e-16) 71 | if plot: 72 | plot_pr_curve(px, py, ap, Path(save_dir) / 'PR_curve.png', names) 73 | plot_mc_curve(px, f1, Path(save_dir) / 'F1_curve.png', names, ylabel='F1') 74 | plot_mc_curve(px, p, Path(save_dir) / 'P_curve.png', names, ylabel='Precision') 75 | plot_mc_curve(px, r, Path(save_dir) / 'R_curve.png', names, ylabel='Recall') 76 | 77 | i = f1.mean(0).argmax() # max F1 index 78 | return p[:, i], r[:, i], ap, f1[:, i], unique_classes.astype('int32') 79 | 80 | 81 | def compute_ap(recall, precision): 82 | """ Compute the average precision, given the recall and precision curves 83 | # Arguments 84 | recall: The recall curve (list) 85 | precision: The precision curve (list) 86 | # Returns 87 | Average precision, precision curve, recall curve 88 | """ 89 | 90 | # Append sentinel values to beginning and end 91 | mrec = np.concatenate(([0.], recall, [recall[-1] + 0.01])) 92 | mpre = np.concatenate(([1.], 
precision, [0.])) 93 | 94 | # Compute the precision envelope 95 | mpre = np.flip(np.maximum.accumulate(np.flip(mpre))) 96 | 97 | # Integrate area under curve 98 | method = 'interp' # methods: 'continuous', 'interp' 99 | if method == 'interp': 100 | x = np.linspace(0, 1, 101) # 101-point interp (COCO) 101 | ap = np.trapz(np.interp(x, mrec, mpre), x) # integrate 102 | else: # 'continuous' 103 | i = np.where(mrec[1:] != mrec[:-1])[0] # points where x axis (recall) changes 104 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) # area under curve 105 | 106 | return ap, mpre, mrec 107 | 108 | 109 | class ConfusionMatrix: 110 | # Updated version of https://github.com/kaanakan/object_detection_confusion_matrix 111 | def __init__(self, nc, conf=0.25, iou_thres=0.45): 112 | self.matrix = np.zeros((nc + 1, nc + 1)) 113 | self.nc = nc # number of classes 114 | self.conf = conf 115 | self.iou_thres = iou_thres 116 | 117 | def process_batch(self, detections, labels): 118 | """ 119 | Return intersection-over-union (Jaccard index) of boxes. 120 | Both sets of boxes are expected to be in (x1, y1, x2, y2) format. 121 | Arguments: 122 | detections (Array[N, 6]), x1, y1, x2, y2, conf, class 123 | labels (Array[M, 5]), class, x1, y1, x2, y2 124 | Returns: 125 | None, updates confusion matrix accordingly 126 | """ 127 | detections = detections[detections[:, 4] > self.conf] 128 | gt_classes = labels[:, 0].int() 129 | detection_classes = detections[:, 5].int() 130 | iou = general.box_iou(labels[:, 1:], detections[:, :4]) 131 | 132 | x = torch.where(iou > self.iou_thres) 133 | if x[0].shape[0]: 134 | matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy() 135 | if x[0].shape[0] > 1: 136 | matches = matches[matches[:, 2].argsort()[::-1]] 137 | matches = matches[np.unique(matches[:, 1], return_index=True)[1]] 138 | matches = matches[matches[:, 2].argsort()[::-1]] 139 | matches = matches[np.unique(matches[:, 0], return_index=True)[1]] 140 | else: 141 | matches = np.zeros((0, 3)) 142 | 143 | n = matches.shape[0] > 0 144 | m0, m1, _ = matches.transpose().astype(np.int16) 145 | for i, gc in enumerate(gt_classes): 146 | j = m0 == i 147 | if n and sum(j) == 1: 148 | self.matrix[gc, detection_classes[m1[j]]] += 1 # correct 149 | else: 150 | self.matrix[self.nc, gc] += 1 # background FP 151 | 152 | if n: 153 | for i, dc in enumerate(detection_classes): 154 | if not any(m1 == i): 155 | self.matrix[dc, self.nc] += 1 # background FN 156 | 157 | def matrix(self): 158 | return self.matrix 159 | 160 | def plot(self, save_dir='', names=()): 161 | try: 162 | import seaborn as sn 163 | 164 | array = self.matrix / (self.matrix.sum(0).reshape(1, self.nc + 1) + 1E-6) # normalize 165 | array[array < 0.005] = np.nan # don't annotate (would appear as 0.00) 166 | 167 | fig = plt.figure(figsize=(12, 9), tight_layout=True) 168 | sn.set(font_scale=1.0 if self.nc < 50 else 0.8) # for label size 169 | labels = (0 < len(names) < 99) and len(names) == self.nc # apply names to ticklabels 170 | sn.heatmap(array, annot=self.nc < 30, annot_kws={"size": 8}, cmap='Blues', fmt='.2f', square=True, 171 | xticklabels=names + ['background FP'] if labels else "auto", 172 | yticklabels=names + ['background FN'] if labels else "auto").set_facecolor((1, 1, 1)) 173 | fig.axes[0].set_xlabel('True') 174 | fig.axes[0].set_ylabel('Predicted') 175 | fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250) 176 | except Exception as e: 177 | pass 178 | 179 | def print(self): 180 | for i in range(self.nc + 1): 181 | print(' 
'.join(map(str, self.matrix[i]))) 182 | 183 | 184 | # Plots ---------------------------------------------------------------------------------------------------------------- 185 | 186 | def plot_pr_curve(px, py, ap, save_dir='pr_curve.png', names=()): 187 | # Precision-recall curve 188 | fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True) 189 | py = np.stack(py, axis=1) 190 | 191 | if 0 < len(names) < 21: # display per-class legend if < 21 classes 192 | for i, y in enumerate(py.T): 193 | ax.plot(px, y, linewidth=1, label=f'{names[i]} {ap[i, 0]:.3f}') # plot(recall, precision) 194 | else: 195 | ax.plot(px, py, linewidth=1, color='grey') # plot(recall, precision) 196 | 197 | ax.plot(px, py.mean(1), linewidth=3, color='blue', label='all classes %.3f mAP@0.5' % ap[:, 0].mean()) 198 | ax.set_xlabel('Recall') 199 | ax.set_ylabel('Precision') 200 | ax.set_xlim(0, 1) 201 | ax.set_ylim(0, 1) 202 | plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left") 203 | fig.savefig(Path(save_dir), dpi=250) 204 | 205 | 206 | def plot_mc_curve(px, py, save_dir='mc_curve.png', names=(), xlabel='Confidence', ylabel='Metric'): 207 | # Metric-confidence curve 208 | fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True) 209 | 210 | if 0 < len(names) < 21: # display per-class legend if < 21 classes 211 | for i, y in enumerate(py): 212 | ax.plot(px, y, linewidth=1, label=f'{names[i]}') # plot(confidence, metric) 213 | else: 214 | ax.plot(px, py.T, linewidth=1, color='grey') # plot(confidence, metric) 215 | 216 | y = py.mean(0) 217 | ax.plot(px, y, linewidth=3, color='blue', label=f'all classes {y.max():.2f} at {px[y.argmax()]:.3f}') 218 | ax.set_xlabel(xlabel) 219 | ax.set_ylabel(ylabel) 220 | ax.set_xlim(0, 1) 221 | ax.set_ylim(0, 1) 222 | plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left") 223 | fig.savefig(Path(save_dir), dpi=250) 224 | -------------------------------------------------------------------------------- /utils/loss.py: -------------------------------------------------------------------------------- 1 | # Loss functions 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from utils.general import bbox_iou 7 | from utils.torch_utils import is_parallel 8 | 9 | 10 | def smooth_BCE(eps=0.1): # https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441 11 | # return positive, negative label smoothing BCE targets 12 | return 1.0 - 0.5 * eps, 0.5 * eps 13 | 14 | 15 | class BCEBlurWithLogitsLoss(nn.Module): 16 | # BCEwithLogitLoss() with reduced missing label effects. 17 | def __init__(self, alpha=0.05): 18 | super(BCEBlurWithLogitsLoss, self).__init__() 19 | self.loss_fcn = nn.BCEWithLogitsLoss(reduction='none') # must be nn.BCEWithLogitsLoss() 20 | self.alpha = alpha 21 | 22 | def forward(self, pred, true): 23 | loss = self.loss_fcn(pred, true) 24 | pred = torch.sigmoid(pred) # prob from logits 25 | dx = pred - true # reduce only missing label effects 26 | # dx = (pred - true).abs() # reduce missing label and false label effects 27 | alpha_factor = 1 - torch.exp((dx - 1) / (self.alpha + 1e-4)) 28 | loss *= alpha_factor 29 | return loss.mean() 30 | 31 | 32 | class FocalLoss(nn.Module): 33 | # Wraps focal loss around existing loss_fcn(), i.e. 
criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5) 34 | def __init__(self, loss_fcn, gamma=1.5, alpha=0.25): 35 | super(FocalLoss, self).__init__() 36 | self.loss_fcn = loss_fcn # must be nn.BCEWithLogitsLoss() 37 | self.gamma = gamma 38 | self.alpha = alpha 39 | self.reduction = loss_fcn.reduction 40 | self.loss_fcn.reduction = 'none' # required to apply FL to each element 41 | 42 | def forward(self, pred, true): 43 | loss = self.loss_fcn(pred, true) 44 | # p_t = torch.exp(-loss) 45 | # loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability 46 | 47 | # TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py 48 | pred_prob = torch.sigmoid(pred) # prob from logits 49 | p_t = true * pred_prob + (1 - true) * (1 - pred_prob) 50 | alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha) 51 | modulating_factor = (1.0 - p_t) ** self.gamma 52 | loss *= alpha_factor * modulating_factor 53 | 54 | if self.reduction == 'mean': 55 | return loss.mean() 56 | elif self.reduction == 'sum': 57 | return loss.sum() 58 | else: # 'none' 59 | return loss 60 | 61 | 62 | class QFocalLoss(nn.Module): 63 | # Wraps Quality focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5) 64 | def __init__(self, loss_fcn, gamma=1.5, alpha=0.25): 65 | super(QFocalLoss, self).__init__() 66 | self.loss_fcn = loss_fcn # must be nn.BCEWithLogitsLoss() 67 | self.gamma = gamma 68 | self.alpha = alpha 69 | self.reduction = loss_fcn.reduction 70 | self.loss_fcn.reduction = 'none' # required to apply FL to each element 71 | 72 | def forward(self, pred, true): 73 | loss = self.loss_fcn(pred, true) 74 | 75 | pred_prob = torch.sigmoid(pred) # prob from logits 76 | alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha) 77 | modulating_factor = torch.abs(true - pred_prob) ** self.gamma 78 | loss *= alpha_factor * modulating_factor 79 | 80 | if self.reduction == 'mean': 81 | return loss.mean() 82 | elif self.reduction == 'sum': 83 | return loss.sum() 84 | else: # 'none' 85 | return loss 86 | 87 | 88 | class ComputeLoss: 89 | # Compute losses 90 | def __init__(self, model, autobalance=False): 91 | super(ComputeLoss, self).__init__() 92 | device = next(model.parameters()).device # get model device 93 | h = model.hyp # hyperparameters 94 | 95 | # Define criteria 96 | BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['cls_pw']], device=device)) 97 | BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['obj_pw']], device=device)) 98 | 99 | # Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3 100 | self.cp, self.cn = smooth_BCE(eps=h.get('label_smoothing', 0.0)) # positive, negative BCE targets 101 | 102 | # Focal loss 103 | g = h['fl_gamma'] # focal loss gamma 104 | if g > 0: 105 | BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g) 106 | 107 | det = model.module.model[-1] if is_parallel(model) else model.model[-1] # Detect() module 108 | self.balance = {3: [4.0, 1.0, 0.4]}.get(det.nl, [4.0, 1.0, 0.25, 0.06, .02]) # P3-P7 109 | self.ssi = list(det.stride).index(16) if autobalance else 0 # stride 16 index 110 | self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, model.gr, h, autobalance 111 | for k in 'na', 'nc', 'nl', 'anchors': 112 | setattr(self, k, getattr(det, k)) 113 | 114 | def __call__(self, p, targets): # predictions, targets, model 115 | device = targets.device 116 | lcls, lbox, lobj = torch.zeros(1, device=device), 
torch.zeros(1, device=device), torch.zeros(1, device=device) 117 | tcls, tbox, indices, anchors = self.build_targets(p, targets) # targets 118 | 119 | # Losses 120 | for i, pi in enumerate(p): # layer index, layer predictions 121 | b, a, gj, gi = indices[i] # image, anchor, gridy, gridx 122 | tobj = torch.zeros_like(pi[..., 0], device=device) # target obj 123 | 124 | n = b.shape[0] # number of targets 125 | if n: 126 | ps = pi[b, a, gj, gi] # prediction subset corresponding to targets 127 | 128 | # Regression 129 | pxy = ps[:, :2].sigmoid() * 2. - 0.5 130 | pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i] 131 | pbox = torch.cat((pxy, pwh), 1) # predicted box 132 | iou = bbox_iou(pbox.T, tbox[i], x1y1x2y2=False, CIoU=True) # iou(prediction, target) 133 | lbox += (1.0 - iou).mean() # iou loss 134 | 135 | # Objectness 136 | tobj[b, a, gj, gi] = (1.0 - self.gr) + self.gr * iou.detach().clamp(0).type(tobj.dtype) # iou ratio 137 | 138 | # Classification 139 | if self.nc > 1: # cls loss (only if multiple classes) 140 | t = torch.full_like(ps[:, 5:], self.cn, device=device) # targets 141 | t[range(n), tcls[i]] = self.cp 142 | lcls += self.BCEcls(ps[:, 5:], t) # BCE 143 | 144 | # Append targets to text file 145 | # with open('targets.txt', 'a') as file: 146 | # [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)] 147 | 148 | obji = self.BCEobj(pi[..., 4], tobj) 149 | lobj += obji * self.balance[i] # obj loss 150 | if self.autobalance: 151 | self.balance[i] = self.balance[i] * 0.9999 + 0.0001 / obji.detach().item() 152 | 153 | if self.autobalance: 154 | self.balance = [x / self.balance[self.ssi] for x in self.balance] 155 | lbox *= self.hyp['box'] 156 | lobj *= self.hyp['obj'] 157 | lcls *= self.hyp['cls'] 158 | bs = tobj.shape[0] # batch size 159 | 160 | loss = lbox + lobj + lcls 161 | return loss * bs, torch.cat((lbox, lobj, lcls, loss)).detach() 162 | 163 | def build_targets(self, p, targets): 164 | # Build targets for compute_loss(), input targets(image,class,x,y,w,h) 165 | na, nt = self.na, targets.shape[0] # number of anchors, targets 166 | tcls, tbox, indices, anch = [], [], [], [] 167 | gain = torch.ones(7, device=targets.device) # normalized to gridspace gain 168 | ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt) # same as .repeat_interleave(nt) 169 | targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2) # append anchor indices 170 | 171 | g = 0.5 # bias 172 | off = torch.tensor([[0, 0], 173 | [1, 0], [0, 1], [-1, 0], [0, -1], # j,k,l,m 174 | # [1, 1], [1, -1], [-1, 1], [-1, -1], # jk,jm,lk,lm 175 | ], device=targets.device).float() * g # offsets 176 | 177 | for i in range(self.nl): 178 | anchors = self.anchors[i] 179 | gain[2:6] = torch.tensor(p[i].shape)[[3, 2, 3, 2]] # xyxy gain 180 | 181 | # Match targets to anchors 182 | t = targets * gain 183 | if nt: 184 | # Matches 185 | r = t[:, :, 4:6] / anchors[:, None] # wh ratio 186 | j = torch.max(r, 1. / r).max(2)[0] < self.hyp['anchor_t'] # compare 187 | # j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t'] # iou(3,n)=wh_iou(anchors(3,2), gwh(n,2)) 188 | t = t[j] # filter 189 | 190 | # Offsets 191 | gxy = t[:, 2:4] # grid xy 192 | gxi = gain[[2, 3]] - gxy # inverse 193 | j, k = ((gxy % 1. < g) & (gxy > 1.)).T 194 | l, m = ((gxi % 1. 
< g) & (gxi > 1.)).T 195 | j = torch.stack((torch.ones_like(j), j, k, l, m)) 196 | t = t.repeat((5, 1, 1))[j] 197 | offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j] 198 | else: 199 | t = targets[0] 200 | offsets = 0 201 | 202 | # Define 203 | b, c = t[:, :2].long().T # image, class 204 | gxy = t[:, 2:4] # grid xy 205 | gwh = t[:, 4:6] # grid wh 206 | gij = (gxy - offsets).long() 207 | gi, gj = gij.T # grid xy indices 208 | 209 | # Append 210 | a = t[:, 6].long() # anchor indices 211 | indices.append((b, a, gj.clamp_(0, gain[3] - 1), gi.clamp_(0, gain[2] - 1))) # image, anchor, grid indices 212 | tbox.append(torch.cat((gxy - gij, gwh), 1)) # box 213 | anch.append(anchors[a]) # anchors 214 | tcls.append(c) # class 215 | 216 | return tcls, tbox, indices, anch 217 | -------------------------------------------------------------------------------- /utils/torch_utils.py: -------------------------------------------------------------------------------- 1 | # YOLOv5 PyTorch utils 2 | 3 | import datetime 4 | import logging 5 | import math 6 | import os 7 | import platform 8 | import subprocess 9 | import time 10 | from contextlib import contextmanager 11 | from copy import deepcopy 12 | from pathlib import Path 13 | 14 | import torch 15 | import torch.backends.cudnn as cudnn 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | import torchvision 19 | 20 | try: 21 | import thop # for FLOPS computation 22 | except ImportError: 23 | thop = None 24 | logger = logging.getLogger(__name__) 25 | 26 | 27 | @contextmanager 28 | def torch_distributed_zero_first(local_rank: int): 29 | """ 30 | Decorator to make all processes in distributed training wait for each local_master to do something. 31 | """ 32 | if local_rank not in [-1, 0]: 33 | torch.distributed.barrier() 34 | yield 35 | if local_rank == 0: 36 | torch.distributed.barrier() 37 | 38 | 39 | def init_torch_seeds(seed=0): 40 | # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html 41 | torch.manual_seed(seed) 42 | if seed == 0: # slower, more reproducible 43 | cudnn.benchmark, cudnn.deterministic = False, True 44 | else: # faster, less reproducible 45 | cudnn.benchmark, cudnn.deterministic = True, False 46 | 47 | 48 | def date_modified(path=__file__): 49 | # return human-readable file modification date, i.e. '2021-3-26' 50 | t = datetime.datetime.fromtimestamp(Path(path).stat().st_mtime) 51 | return f'{t.year}-{t.month}-{t.day}' 52 | 53 | 54 | def git_describe(path=Path(__file__).parent): # path must be a directory 55 | # return human-readable git description, i.e. 
v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe 56 | s = f'git -C {path} describe --tags --long --always' 57 | try: 58 | return subprocess.check_output(s, shell=True, stderr=subprocess.STDOUT).decode()[:-1] 59 | except subprocess.CalledProcessError as e: 60 | return '' # not a git repository 61 | 62 | 63 | def select_device(device='', batch_size=None): 64 | # device = 'cpu' or '0' or '0,1,2,3' 65 | s = f'YOLOv5 🚀 {git_describe() or date_modified()} torch {torch.__version__} ' # string 66 | cpu = device.lower() == 'cpu' 67 | if cpu: 68 | os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False 69 | elif device: # non-cpu device requested 70 | os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable 71 | assert torch.cuda.is_available(), f'CUDA unavailable, invalid device {device} requested' # check availability 72 | 73 | cuda = not cpu and torch.cuda.is_available() 74 | if cuda: 75 | n = torch.cuda.device_count() 76 | if n > 1 and batch_size: # check that batch_size is compatible with device_count 77 | assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}' 78 | space = ' ' * len(s) 79 | for i, d in enumerate(device.split(',') if device else range(n)): 80 | p = torch.cuda.get_device_properties(i) 81 | s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / 1024 ** 2}MB)\n" # bytes to MB 82 | else: 83 | s += 'CPU\n' 84 | 85 | logger.info(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s) # emoji-safe 86 | return torch.device('cuda:0' if cuda else 'cpu') 87 | 88 | 89 | def time_synchronized(): 90 | # pytorch-accurate time 91 | if torch.cuda.is_available(): 92 | torch.cuda.synchronize() 93 | return time.time() 94 | 95 | 96 | def profile(x, ops, n=100, device=None): 97 | # profile a pytorch module or list of modules. Example usage: 98 | # x = torch.randn(16, 3, 640, 640) # input 99 | # m1 = lambda x: x * torch.sigmoid(x) 100 | # m2 = nn.SiLU() 101 | # profile(x, [m1, m2], n=100) # profile speed over 100 iterations 102 | 103 | device = device or torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 104 | x = x.to(device) 105 | x.requires_grad = True 106 | print(torch.__version__, device.type, torch.cuda.get_device_properties(0) if device.type == 'cuda' else '') 107 | print(f"\n{'Params':>12s}{'GFLOPS':>12s}{'forward (ms)':>16s}{'backward (ms)':>16s}{'input':>24s}{'output':>24s}") 108 | for m in ops if isinstance(ops, list) else [ops]: 109 | m = m.to(device) if hasattr(m, 'to') else m # device 110 | m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m # type 111 | dtf, dtb, t = 0., 0., [0., 0., 0.] 
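# ---------------------------------------------------------------------------
# Editor's note (illustration only, not part of the original torch_utils.py):
# how select_device() above is typically called. select_device('') returns
# cuda:0 when a GPU is visible, select_device('cpu') forces CPU inference
# (this is what obj_det/title_detect.py does via pub_device = "cpu"), and
# select_device('0,1', batch_size=16) additionally asserts that the batch size
# divides evenly across the listed GPUs.
# ---------------------------------------------------------------------------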
# dt forward, backward 112 | try: 113 | flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 # GFLOPS 114 | except: 115 | flops = 0 116 | 117 | for _ in range(n): 118 | t[0] = time_synchronized() 119 | y = m(x) 120 | t[1] = time_synchronized() 121 | try: 122 | _ = y.sum().backward() 123 | t[2] = time_synchronized() 124 | except: # no backward method 125 | t[2] = float('nan') 126 | dtf += (t[1] - t[0]) * 1000 / n # ms per op forward 127 | dtb += (t[2] - t[1]) * 1000 / n # ms per op backward 128 | 129 | s_in = tuple(x.shape) if isinstance(x, torch.Tensor) else 'list' 130 | s_out = tuple(y.shape) if isinstance(y, torch.Tensor) else 'list' 131 | p = sum(list(x.numel() for x in m.parameters())) if isinstance(m, nn.Module) else 0 # parameters 132 | print(f'{p:12}{flops:12.4g}{dtf:16.4g}{dtb:16.4g}{str(s_in):>24s}{str(s_out):>24s}') 133 | 134 | 135 | def is_parallel(model): 136 | return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel) 137 | 138 | 139 | def intersect_dicts(da, db, exclude=()): 140 | # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values 141 | return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape} 142 | 143 | 144 | def initialize_weights(model): 145 | for m in model.modules(): 146 | t = type(m) 147 | if t is nn.Conv2d: 148 | pass # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 149 | elif t is nn.BatchNorm2d: 150 | m.eps = 1e-3 151 | m.momentum = 0.03 152 | elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6]: 153 | m.inplace = True 154 | 155 | 156 | def find_modules(model, mclass=nn.Conv2d): 157 | # Finds layer indices matching module class 'mclass' 158 | return [i for i, m in enumerate(model.module_list) if isinstance(m, mclass)] 159 | 160 | 161 | def sparsity(model): 162 | # Return global model sparsity 163 | a, b = 0., 0. 164 | for p in model.parameters(): 165 | a += p.numel() 166 | b += (p == 0).sum() 167 | return b / a 168 | 169 | 170 | def prune(model, amount=0.3): 171 | # Prune model to requested global sparsity 172 | import torch.nn.utils.prune as prune 173 | print('Pruning model... 
', end='') 174 | for name, m in model.named_modules(): 175 | if isinstance(m, nn.Conv2d): 176 | prune.l1_unstructured(m, name='weight', amount=amount) # prune 177 | prune.remove(m, 'weight') # make permanent 178 | print(' %.3g global sparsity' % sparsity(model)) 179 | 180 | 181 | def fuse_conv_and_bn(conv, bn): 182 | # Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/ 183 | fusedconv = nn.Conv2d(conv.in_channels, 184 | conv.out_channels, 185 | kernel_size=conv.kernel_size, 186 | stride=conv.stride, 187 | padding=conv.padding, 188 | groups=conv.groups, 189 | bias=True).requires_grad_(False).to(conv.weight.device) 190 | 191 | # prepare filters 192 | w_conv = conv.weight.clone().view(conv.out_channels, -1) 193 | w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var))) 194 | fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape)) 195 | 196 | # prepare spatial bias 197 | b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias 198 | b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps)) 199 | fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn) 200 | 201 | return fusedconv 202 | 203 | 204 | def model_info(model, verbose=False, img_size=640): 205 | # Model information. img_size may be int or list, i.e. img_size=640 or img_size=[640, 320] 206 | n_p = sum(x.numel() for x in model.parameters()) # number parameters 207 | n_g = sum(x.numel() for x in model.parameters() if x.requires_grad) # number gradients 208 | if verbose: 209 | print('%5s %40s %9s %12s %20s %10s %10s' % ('layer', 'name', 'gradient', 'parameters', 'shape', 'mu', 'sigma')) 210 | for i, (name, p) in enumerate(model.named_parameters()): 211 | name = name.replace('module_list.', '') 212 | print('%5g %40s %9s %12g %20s %10.3g %10.3g' % 213 | (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std())) 214 | 215 | try: # FLOPS 216 | from thop import profile 217 | stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32 218 | img = torch.zeros((1, model.yaml.get('ch', 3), stride, stride), device=next(model.parameters()).device) # input 219 | flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2 # stride GFLOPS 220 | img_size = img_size if isinstance(img_size, list) else [img_size, img_size] # expand if int/float 221 | fs = ', %.1f GFLOPS' % (flops * img_size[0] / stride * img_size[1] / stride) # 640x640 GFLOPS 222 | except (ImportError, Exception): 223 | fs = '' 224 | 225 | logger.info(f"Model Summary: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}") 226 | 227 | 228 | def load_classifier(name='resnet101', n=2): 229 | # Loads a pretrained model reshaped to n-class output 230 | model = torchvision.models.__dict__[name](pretrained=True) 231 | 232 | # ResNet model properties 233 | # input_size = [3, 224, 224] 234 | # input_space = 'RGB' 235 | # input_range = [0, 1] 236 | # mean = [0.485, 0.456, 0.406] 237 | # std = [0.229, 0.224, 0.225] 238 | 239 | # Reshape output to n classes 240 | filters = model.fc.weight.shape[1] 241 | model.fc.bias = nn.Parameter(torch.zeros(n), requires_grad=True) 242 | model.fc.weight = nn.Parameter(torch.zeros(n, filters), requires_grad=True) 243 | model.fc.out_features = n 244 | return model 245 | 246 | 247 | def scale_img(img, ratio=1.0, same_shape=False, gs=32): # img(16,3,256,416) 248 | # scales img(bs,3,y,x) by ratio constrained to gs-multiple 249 | if 
ratio == 1.0: 250 | return img 251 | else: 252 | h, w = img.shape[2:] 253 | s = (int(h * ratio), int(w * ratio)) # new size 254 | img = F.interpolate(img, size=s, mode='bilinear', align_corners=False) # resize 255 | if not same_shape: # pad/crop img 256 | h, w = [math.ceil(x * ratio / gs) * gs for x in (h, w)] 257 | return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean 258 | 259 | 260 | def copy_attr(a, b, include=(), exclude=()): 261 | # Copy attributes from b to a, options to only include [...] and to exclude [...] 262 | for k, v in b.__dict__.items(): 263 | if (len(include) and k not in include) or k.startswith('_') or k in exclude: 264 | continue 265 | else: 266 | setattr(a, k, v) 267 | 268 | 269 | class ModelEMA: 270 | """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models 271 | Keep a moving average of everything in the model state_dict (parameters and buffers). 272 | This is intended to allow functionality like 273 | https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage 274 | A smoothed version of the weights is necessary for some training schemes to perform well. 275 | This class is sensitive where it is initialized in the sequence of model init, 276 | GPU assignment and distributed training wrappers. 277 | """ 278 | 279 | def __init__(self, model, decay=0.9999, updates=0): 280 | # Create EMA 281 | self.ema = deepcopy(model.module if is_parallel(model) else model).eval() # FP32 EMA 282 | # if next(model.parameters()).device.type != 'cpu': 283 | # self.ema.half() # FP16 EMA 284 | self.updates = updates # number of EMA updates 285 | self.decay = lambda x: decay * (1 - math.exp(-x / 2000)) # decay exponential ramp (to help early epochs) 286 | for p in self.ema.parameters(): 287 | p.requires_grad_(False) 288 | 289 | def update(self, model): 290 | # Update EMA parameters 291 | with torch.no_grad(): 292 | self.updates += 1 293 | d = self.decay(self.updates) 294 | 295 | msd = model.module.state_dict() if is_parallel(model) else model.state_dict() # model state_dict 296 | for k, v in self.ema.state_dict().items(): 297 | if v.dtype.is_floating_point: 298 | v *= d 299 | v += (1. 
- d) * msd[k].detach() 300 | 301 | def update_attr(self, model, include=(), exclude=('process_group', 'reducer')): 302 | # Update EMA attributes 303 | copy_attr(self.ema, model, include, exclude) 304 | -------------------------------------------------------------------------------- /utils/wandb_logging/wandb_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | from pathlib import Path 4 | 5 | import torch 6 | import yaml 7 | from tqdm import tqdm 8 | 9 | sys.path.append(str(Path(__file__).parent.parent.parent)) # add utils/ to path 10 | from utils.datasets import LoadImagesAndLabels 11 | from utils.datasets import img2label_paths 12 | from utils.general import colorstr, xywh2xyxy, check_dataset 13 | 14 | try: 15 | import wandb 16 | from wandb import init, finish 17 | except ImportError: 18 | wandb = None 19 | 20 | WANDB_ARTIFACT_PREFIX = 'wandb-artifact://' 21 | 22 | 23 | def remove_prefix(from_string, prefix=WANDB_ARTIFACT_PREFIX): 24 | return from_string[len(prefix):] 25 | 26 | 27 | def check_wandb_config_file(data_config_file): 28 | wandb_config = '_wandb.'.join(data_config_file.rsplit('.', 1)) # updated data.yaml path 29 | if Path(wandb_config).is_file(): 30 | return wandb_config 31 | return data_config_file 32 | 33 | 34 | def get_run_info(run_path): 35 | run_path = Path(remove_prefix(run_path, WANDB_ARTIFACT_PREFIX)) 36 | run_id = run_path.stem 37 | project = run_path.parent.stem 38 | model_artifact_name = 'run_' + run_id + '_model' 39 | return run_id, project, model_artifact_name 40 | 41 | 42 | def check_wandb_resume(opt): 43 | process_wandb_config_ddp_mode(opt) if opt.global_rank not in [-1, 0] else None 44 | if isinstance(opt.resume, str): 45 | if opt.resume.startswith(WANDB_ARTIFACT_PREFIX): 46 | if opt.global_rank not in [-1, 0]: # For resuming DDP runs 47 | run_id, project, model_artifact_name = get_run_info(opt.resume) 48 | api = wandb.Api() 49 | artifact = api.artifact(project + '/' + model_artifact_name + ':latest') 50 | modeldir = artifact.download() 51 | opt.weights = str(Path(modeldir) / "last.pt") 52 | return True 53 | return None 54 | 55 | 56 | def process_wandb_config_ddp_mode(opt): 57 | with open(opt.data) as f: 58 | data_dict = yaml.load(f, Loader=yaml.SafeLoader) # data dict 59 | train_dir, val_dir = None, None 60 | if isinstance(data_dict['train'], str) and data_dict['train'].startswith(WANDB_ARTIFACT_PREFIX): 61 | api = wandb.Api() 62 | train_artifact = api.artifact(remove_prefix(data_dict['train']) + ':' + opt.artifact_alias) 63 | train_dir = train_artifact.download() 64 | train_path = Path(train_dir) / 'data/images/' 65 | data_dict['train'] = str(train_path) 66 | 67 | if isinstance(data_dict['val'], str) and data_dict['val'].startswith(WANDB_ARTIFACT_PREFIX): 68 | api = wandb.Api() 69 | val_artifact = api.artifact(remove_prefix(data_dict['val']) + ':' + opt.artifact_alias) 70 | val_dir = val_artifact.download() 71 | val_path = Path(val_dir) / 'data/images/' 72 | data_dict['val'] = str(val_path) 73 | if train_dir or val_dir: 74 | ddp_data_path = str(Path(val_dir) / 'wandb_local_data.yaml') 75 | with open(ddp_data_path, 'w') as f: 76 | yaml.dump(data_dict, f) 77 | opt.data = ddp_data_path 78 | 79 | 80 | class WandbLogger(): 81 | def __init__(self, opt, name, run_id, data_dict, job_type='Training'): 82 | # Pre-training routine -- 83 | self.job_type = job_type 84 | self.wandb, self.wandb_run, self.data_dict = wandb, None if not wandb else wandb.run, data_dict 85 | # It's more elegant to stick to 
1 wandb.init call, but useful config data is overwritten in the WandbLogger's wandb.init call 86 | if isinstance(opt.resume, str): # checks resume from artifact 87 | if opt.resume.startswith(WANDB_ARTIFACT_PREFIX): 88 | run_id, project, model_artifact_name = get_run_info(opt.resume) 89 | model_artifact_name = WANDB_ARTIFACT_PREFIX + model_artifact_name 90 | assert wandb, 'install wandb to resume wandb runs' 91 | # Resume wandb-artifact:// runs here| workaround for not overwriting wandb.config 92 | self.wandb_run = wandb.init(id=run_id, project=project, resume='allow') 93 | opt.resume = model_artifact_name 94 | elif self.wandb: 95 | self.wandb_run = wandb.init(config=opt, 96 | resume="allow", 97 | project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem, 98 | name=name, 99 | job_type=job_type, 100 | id=run_id) if not wandb.run else wandb.run 101 | if self.wandb_run: 102 | if self.job_type == 'Training': 103 | if not opt.resume: 104 | wandb_data_dict = self.check_and_upload_dataset(opt) if opt.upload_dataset else data_dict 105 | # Info useful for resuming from artifacts 106 | self.wandb_run.config.opt = vars(opt) 107 | self.wandb_run.config.data_dict = wandb_data_dict 108 | self.data_dict = self.setup_training(opt, data_dict) 109 | if self.job_type == 'Dataset Creation': 110 | self.data_dict = self.check_and_upload_dataset(opt) 111 | else: 112 | prefix = colorstr('wandb: ') 113 | print(f"{prefix}Install Weights & Biases for YOLOv5 logging with 'pip install wandb' (recommended)") 114 | 115 | def check_and_upload_dataset(self, opt): 116 | assert wandb, 'Install wandb to upload dataset' 117 | check_dataset(self.data_dict) 118 | config_path = self.log_dataset_artifact(opt.data, 119 | opt.single_cls, 120 | 'YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem) 121 | print("Created dataset config file ", config_path) 122 | with open(config_path) as f: 123 | wandb_data_dict = yaml.load(f, Loader=yaml.SafeLoader) 124 | return wandb_data_dict 125 | 126 | def setup_training(self, opt, data_dict): 127 | self.log_dict, self.current_epoch, self.log_imgs = {}, 0, 16 # Logging Constants 128 | self.bbox_interval = opt.bbox_interval 129 | if isinstance(opt.resume, str): 130 | modeldir, _ = self.download_model_artifact(opt) 131 | if modeldir: 132 | self.weights = Path(modeldir) / "last.pt" 133 | config = self.wandb_run.config 134 | opt.weights, opt.save_period, opt.batch_size, opt.bbox_interval, opt.epochs, opt.hyp = str( 135 | self.weights), config.save_period, config.total_batch_size, config.bbox_interval, config.epochs, \ 136 | config.opt['hyp'] 137 | data_dict = dict(self.wandb_run.config.data_dict) # eliminates the need for config file to resume 138 | if 'val_artifact' not in self.__dict__: # If --upload_dataset is set, use the existing artifact, don't download 139 | self.train_artifact_path, self.train_artifact = self.download_dataset_artifact(data_dict.get('train'), 140 | opt.artifact_alias) 141 | self.val_artifact_path, self.val_artifact = self.download_dataset_artifact(data_dict.get('val'), 142 | opt.artifact_alias) 143 | self.result_artifact, self.result_table, self.val_table, self.weights = None, None, None, None 144 | if self.train_artifact_path is not None: 145 | train_path = Path(self.train_artifact_path) / 'data/images/' 146 | data_dict['train'] = str(train_path) 147 | if self.val_artifact_path is not None: 148 | val_path = Path(self.val_artifact_path) / 'data/images/' 149 | data_dict['val'] = str(val_path) 150 | self.val_table = self.val_artifact.get("val") 151 
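# ---------------------------------------------------------------------------
# Editor's note (illustration only, not part of the original wandb_utils.py):
# how a 'wandb-artifact://' resume path is decomposed by get_run_info() above;
# the project and run id below are hypothetical placeholders.
#
#   opt.resume = 'wandb-artifact://YOLOv5/3abc9xyz'
#   run_id, project, model_artifact_name = get_run_info(opt.resume)
#   # run_id              -> '3abc9xyz'            (Path(...).stem)
#   # project             -> 'YOLOv5'              (Path(...).parent.stem)
#   # model_artifact_name -> 'run_3abc9xyz_model'
#
# The logger then fetches 'run_3abc9xyz_model:latest' from W&B and points
# opt.weights at the downloaded last.pt checkpoint.
# ---------------------------------------------------------------------------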
| self.map_val_table_path() 152 | if self.val_artifact is not None: 153 | self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation") 154 | self.result_table = wandb.Table(["epoch", "id", "prediction", "avg_confidence"]) 155 | if opt.bbox_interval == -1: 156 | self.bbox_interval = opt.bbox_interval = (opt.epochs // 10) if opt.epochs > 10 else 1 157 | return data_dict 158 | 159 | def download_dataset_artifact(self, path, alias): 160 | if isinstance(path, str) and path.startswith(WANDB_ARTIFACT_PREFIX): 161 | dataset_artifact = wandb.use_artifact(remove_prefix(path, WANDB_ARTIFACT_PREFIX) + ":" + alias) 162 | assert dataset_artifact is not None, "'Error: W&B dataset artifact doesn\'t exist'" 163 | datadir = dataset_artifact.download() 164 | return datadir, dataset_artifact 165 | return None, None 166 | 167 | def download_model_artifact(self, opt): 168 | if opt.resume.startswith(WANDB_ARTIFACT_PREFIX): 169 | model_artifact = wandb.use_artifact(remove_prefix(opt.resume, WANDB_ARTIFACT_PREFIX) + ":latest") 170 | assert model_artifact is not None, 'Error: W&B model artifact doesn\'t exist' 171 | modeldir = model_artifact.download() 172 | epochs_trained = model_artifact.metadata.get('epochs_trained') 173 | total_epochs = model_artifact.metadata.get('total_epochs') 174 | assert epochs_trained < total_epochs, 'training to %g epochs is finished, nothing to resume.' % ( 175 | total_epochs) 176 | return modeldir, model_artifact 177 | return None, None 178 | 179 | def log_model(self, path, opt, epoch, fitness_score, best_model=False): 180 | model_artifact = wandb.Artifact('run_' + wandb.run.id + '_model', type='model', metadata={ 181 | 'original_url': str(path), 182 | 'epochs_trained': epoch + 1, 183 | 'save period': opt.save_period, 184 | 'project': opt.project, 185 | 'total_epochs': opt.epochs, 186 | 'fitness_score': fitness_score 187 | }) 188 | model_artifact.add_file(str(path / 'last.pt'), name='last.pt') 189 | wandb.log_artifact(model_artifact, 190 | aliases=['latest', 'epoch ' + str(self.current_epoch), 'best' if best_model else '']) 191 | print("Saving model artifact on epoch ", epoch + 1) 192 | 193 | def log_dataset_artifact(self, data_file, single_cls, project, overwrite_config=False): 194 | with open(data_file) as f: 195 | data = yaml.load(f, Loader=yaml.SafeLoader) # data dict 196 | nc, names = (1, ['item']) if single_cls else (int(data['nc']), data['names']) 197 | names = {k: v for k, v in enumerate(names)} # to index dictionary 198 | self.train_artifact = self.create_dataset_table(LoadImagesAndLabels( 199 | data['train']), names, name='train') if data.get('train') else None 200 | self.val_artifact = self.create_dataset_table(LoadImagesAndLabels( 201 | data['val']), names, name='val') if data.get('val') else None 202 | if data.get('train'): 203 | data['train'] = WANDB_ARTIFACT_PREFIX + str(Path(project) / 'train') 204 | if data.get('val'): 205 | data['val'] = WANDB_ARTIFACT_PREFIX + str(Path(project) / 'val') 206 | path = data_file if overwrite_config else '_wandb.'.join(data_file.rsplit('.', 1)) # updated data.yaml path 207 | data.pop('download', None) 208 | with open(path, 'w') as f: 209 | yaml.dump(data, f) 210 | 211 | if self.job_type == 'Training': # builds correct artifact pipeline graph 212 | self.wandb_run.use_artifact(self.val_artifact) 213 | self.wandb_run.use_artifact(self.train_artifact) 214 | self.val_artifact.wait() 215 | self.val_table = self.val_artifact.get('val') 216 | self.map_val_table_path() 217 | else: 218 | 
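# ---------------------------------------------------------------------------
# Editor's note (illustration only, not part of the original wandb_utils.py):
# a worked example of the data.yaml rewrite in log_dataset_artifact() above,
# using a hypothetical data_file 'data/invoice.yaml' and project 'invoice_ocr':
#
#   '_wandb.'.join('data/invoice.yaml'.rsplit('.', 1))  ->  'data/invoice_wandb.yaml'
#
# Inside that new yaml, the train/val entries become artifact references such
# as 'wandb-artifact://invoice_ocr/train', so later runs can pull the dataset
# from W&B instead of local paths.
# ---------------------------------------------------------------------------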
self.wandb_run.log_artifact(self.train_artifact) 219 | self.wandb_run.log_artifact(self.val_artifact) 220 | return path 221 | 222 | def map_val_table_path(self): 223 | self.val_table_map = {} 224 | print("Mapping dataset") 225 | for i, data in enumerate(tqdm(self.val_table.data)): 226 | self.val_table_map[data[3]] = data[0] 227 | 228 | def create_dataset_table(self, dataset, class_to_id, name='dataset'): 229 | # TODO: Explore multiprocessing to split this loop in parallel | This is essential for speeding up the logging 230 | artifact = wandb.Artifact(name=name, type="dataset") 231 | img_files = tqdm([dataset.path]) if isinstance(dataset.path, str) and Path(dataset.path).is_dir() else None 232 | img_files = tqdm(dataset.img_files) if not img_files else img_files 233 | for img_file in img_files: 234 | if Path(img_file).is_dir(): 235 | artifact.add_dir(img_file, name='data/images') 236 | labels_path = 'labels'.join(dataset.path.rsplit('images', 1)) 237 | artifact.add_dir(labels_path, name='data/labels') 238 | else: 239 | artifact.add_file(img_file, name='data/images/' + Path(img_file).name) 240 | label_file = Path(img2label_paths([img_file])[0]) 241 | artifact.add_file(str(label_file), 242 | name='data/labels/' + label_file.name) if label_file.exists() else None 243 | table = wandb.Table(columns=["id", "train_image", "Classes", "name"]) 244 | class_set = wandb.Classes([{'id': id, 'name': name} for id, name in class_to_id.items()]) 245 | for si, (img, labels, paths, shapes) in enumerate(tqdm(dataset)): 246 | height, width = shapes[0] 247 | labels[:, 2:] = (xywh2xyxy(labels[:, 2:].view(-1, 4))) * torch.Tensor([width, height, width, height]) 248 | box_data, img_classes = [], {} 249 | for cls, *xyxy in labels[:, 1:].tolist(): 250 | cls = int(cls) 251 | box_data.append({"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]}, 252 | "class_id": cls, 253 | "box_caption": "%s" % (class_to_id[cls]), 254 | "scores": {"acc": 1}, 255 | "domain": "pixel"}) 256 | img_classes[cls] = class_to_id[cls] 257 | boxes = {"ground_truth": {"box_data": box_data, "class_labels": class_to_id}} # inference-space 258 | table.add_data(si, wandb.Image(paths, classes=class_set, boxes=boxes), json.dumps(img_classes), 259 | Path(paths).name) 260 | artifact.add(table, name) 261 | return artifact 262 | 263 | def log_training_progress(self, predn, path, names): 264 | if self.val_table and self.result_table: 265 | class_set = wandb.Classes([{'id': id, 'name': name} for id, name in names.items()]) 266 | box_data = [] 267 | total_conf = 0 268 | for *xyxy, conf, cls in predn.tolist(): 269 | if conf >= 0.25: 270 | box_data.append( 271 | {"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]}, 272 | "class_id": int(cls), 273 | "box_caption": "%s %.3f" % (names[cls], conf), 274 | "scores": {"class_score": conf}, 275 | "domain": "pixel"}) 276 | total_conf = total_conf + conf 277 | boxes = {"predictions": {"box_data": box_data, "class_labels": names}} # inference-space 278 | id = self.val_table_map[Path(path).name] 279 | self.result_table.add_data(self.current_epoch, 280 | id, 281 | wandb.Image(self.val_table.data[id][1], boxes=boxes, classes=class_set), 282 | total_conf / max(1, len(box_data)) 283 | ) 284 | 285 | def log(self, log_dict): 286 | if self.wandb_run: 287 | for key, value in log_dict.items(): 288 | self.log_dict[key] = value 289 | 290 | def end_epoch(self, best_result=False): 291 | if self.wandb_run: 292 | wandb.log(self.log_dict) 293 | self.log_dict = {} 294 | if
self.result_artifact: 295 | train_results = wandb.JoinedTable(self.val_table, self.result_table, "id") 296 | self.result_artifact.add(train_results, 'result') 297 | wandb.log_artifact(self.result_artifact, aliases=['latest', 'epoch ' + str(self.current_epoch), 298 | ('best' if best_result else '')]) 299 | self.result_table = wandb.Table(["epoch", "id", "prediction", "avg_confidence"]) 300 | self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation") 301 | 302 | def finish_run(self): 303 | if self.wandb_run: 304 | if self.log_dict: 305 | wandb.log(self.log_dict) 306 | wandb.run.finish() 307 | -------------------------------------------------------------------------------- /utils/plots.py: -------------------------------------------------------------------------------- 1 | # Plotting utils 2 | 3 | import glob 4 | import math 5 | import os 6 | import random 7 | from copy import copy 8 | from pathlib import Path 9 | 10 | import cv2 11 | import matplotlib 12 | import matplotlib.pyplot as plt 13 | import numpy as np 14 | import pandas as pd 15 | import seaborn as sns 16 | import torch 17 | import yaml 18 | from PIL import Image, ImageDraw, ImageFont 19 | from scipy.signal import butter, filtfilt 20 | 21 | from utils.general import xywh2xyxy, xyxy2xywh 22 | from utils.metrics import fitness 23 | 24 | # Settings 25 | matplotlib.rc('font', **{'size': 11}) 26 | matplotlib.use('Agg') # for writing to files only 27 | 28 | 29 | def color_list(): 30 | # Return first 10 plt colors as (r,g,b) https://stackoverflow.com/questions/51350872/python-from-color-name-to-rgb 31 | def hex2rgb(h): 32 | return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4)) 33 | 34 | return [hex2rgb(h) for h in matplotlib.colors.TABLEAU_COLORS.values()] # or BASE_ (8), CSS4_ (148), XKCD_ (949) 35 | 36 | 37 | def hist2d(x, y, n=100): 38 | # 2d histogram used in labels.png and evolve.png 39 | xedges, yedges = np.linspace(x.min(), x.max(), n), np.linspace(y.min(), y.max(), n) 40 | hist, xedges, yedges = np.histogram2d(x, y, (xedges, yedges)) 41 | xidx = np.clip(np.digitize(x, xedges) - 1, 0, hist.shape[0] - 1) 42 | yidx = np.clip(np.digitize(y, yedges) - 1, 0, hist.shape[1] - 1) 43 | return np.log(hist[xidx, yidx]) 44 | 45 | 46 | def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5): 47 | # https://stackoverflow.com/questions/28536191/how-to-filter-smooth-with-scipy-numpy 48 | def butter_lowpass(cutoff, fs, order): 49 | nyq = 0.5 * fs 50 | normal_cutoff = cutoff / nyq 51 | return butter(order, normal_cutoff, btype='low', analog=False) 52 | 53 | b, a = butter_lowpass(cutoff, fs, order=order) 54 | return filtfilt(b, a, data) # forward-backward filter 55 | 56 | 57 | def plot_one_box(x, img, color=None, label=None, line_thickness=3): 58 | # Plots one bounding box on image img 59 | tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1 # line/font thickness 60 | color = color or [random.randint(0, 255) for _ in range(3)] 61 | c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3])) 62 | cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA) 63 | if label: 64 | tf = max(tl - 1, 1) # font thickness 65 | t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0] 66 | c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3 67 | cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA) # filled 68 | cv2.putText(img, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA) 69 | 70 | 71 | def plot_one_box_PIL(box, img, color=None, 
label=None, line_thickness=None): 72 | img = Image.fromarray(img) 73 | draw = ImageDraw.Draw(img) 74 | line_thickness = line_thickness or max(int(min(img.size) / 200), 2) 75 | draw.rectangle(box, width=line_thickness, outline=tuple(color)) # plot 76 | if label: 77 | fontsize = max(round(max(img.size) / 40), 12) 78 | font = ImageFont.truetype("Arial.ttf", fontsize) 79 | txt_width, txt_height = font.getsize(label) 80 | draw.rectangle([box[0], box[1] - txt_height + 4, box[0] + txt_width, box[1]], fill=tuple(color)) 81 | draw.text((box[0], box[1] - txt_height + 1), label, fill=(255, 255, 255), font=font) 82 | return np.asarray(img) 83 | 84 | 85 | def plot_wh_methods(): # from utils.plots import *; plot_wh_methods() 86 | # Compares the two methods for width-height anchor multiplication 87 | # https://github.com/ultralytics/yolov3/issues/168 88 | x = np.arange(-4.0, 4.0, .1) 89 | ya = np.exp(x) 90 | yb = torch.sigmoid(torch.from_numpy(x)).numpy() * 2 91 | 92 | fig = plt.figure(figsize=(6, 3), tight_layout=True) 93 | plt.plot(x, ya, '.-', label='YOLOv3') 94 | plt.plot(x, yb ** 2, '.-', label='YOLOv5 ^2') 95 | plt.plot(x, yb ** 1.6, '.-', label='YOLOv5 ^1.6') 96 | plt.xlim(left=-4, right=4) 97 | plt.ylim(bottom=0, top=6) 98 | plt.xlabel('input') 99 | plt.ylabel('output') 100 | plt.grid() 101 | plt.legend() 102 | fig.savefig('comparison.png', dpi=200) 103 | 104 | 105 | def output_to_target(output): 106 | # Convert model output to target format [batch_id, class_id, x, y, w, h, conf] 107 | targets = [] 108 | for i, o in enumerate(output): 109 | for *box, conf, cls in o.cpu().numpy(): 110 | targets.append([i, cls, *list(*xyxy2xywh(np.array(box)[None])), conf]) 111 | return np.array(targets) 112 | 113 | 114 | def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max_size=640, max_subplots=16): 115 | # Plot image grid with labels 116 | 117 | if isinstance(images, torch.Tensor): 118 | images = images.cpu().float().numpy() 119 | if isinstance(targets, torch.Tensor): 120 | targets = targets.cpu().numpy() 121 | 122 | # un-normalise 123 | if np.max(images[0]) <= 1: 124 | images *= 255 125 | 126 | tl = 3 # line thickness 127 | tf = max(tl - 1, 1) # font thickness 128 | bs, _, h, w = images.shape # batch size, _, height, width 129 | bs = min(bs, max_subplots) # limit plot images 130 | ns = np.ceil(bs ** 0.5) # number of subplots (square) 131 | 132 | # Check if we should resize 133 | scale_factor = max_size / max(h, w) 134 | if scale_factor < 1: 135 | h = math.ceil(scale_factor * h) 136 | w = math.ceil(scale_factor * w) 137 | 138 | colors = color_list() # list of colors 139 | mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # init 140 | for i, img in enumerate(images): 141 | if i == max_subplots: # if last batch has fewer images than we expect 142 | break 143 | 144 | block_x = int(w * (i // ns)) 145 | block_y = int(h * (i % ns)) 146 | 147 | img = img.transpose(1, 2, 0) 148 | if scale_factor < 1: 149 | img = cv2.resize(img, (w, h)) 150 | 151 | mosaic[block_y:block_y + h, block_x:block_x + w, :] = img 152 | if len(targets) > 0: 153 | image_targets = targets[targets[:, 0] == i] 154 | boxes = xywh2xyxy(image_targets[:, 2:6]).T 155 | classes = image_targets[:, 1].astype('int') 156 | labels = image_targets.shape[1] == 6 # labels if no conf column 157 | conf = None if labels else image_targets[:, 6] # check for confidence presence (label vs pred) 158 | 159 | if boxes.shape[1]: 160 | if boxes.max() <= 1.01: # if normalized with tolerance 0.01 161 | boxes[[0, 2]] *= w # scale 
to pixels 162 | boxes[[1, 3]] *= h 163 | elif scale_factor < 1: # absolute coords need scale if image scales 164 | boxes *= scale_factor 165 | boxes[[0, 2]] += block_x 166 | boxes[[1, 3]] += block_y 167 | for j, box in enumerate(boxes.T): 168 | cls = int(classes[j]) 169 | color = colors[cls % len(colors)] 170 | cls = names[cls] if names else cls 171 | if labels or conf[j] > 0.25: # 0.25 conf thresh 172 | label = '%s' % cls if labels else '%s %.1f' % (cls, conf[j]) 173 | plot_one_box(box, mosaic, label=label, color=color, line_thickness=tl) 174 | 175 | # Draw image filename labels 176 | if paths: 177 | label = Path(paths[i]).name[:40] # trim to 40 char 178 | t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0] 179 | cv2.putText(mosaic, label, (block_x + 5, block_y + t_size[1] + 5), 0, tl / 3, [220, 220, 220], thickness=tf, 180 | lineType=cv2.LINE_AA) 181 | 182 | # Image border 183 | cv2.rectangle(mosaic, (block_x, block_y), (block_x + w, block_y + h), (255, 255, 255), thickness=3) 184 | 185 | if fname: 186 | r = min(1280. / max(h, w) / ns, 1.0) # ratio to limit image size 187 | mosaic = cv2.resize(mosaic, (int(ns * w * r), int(ns * h * r)), interpolation=cv2.INTER_AREA) 188 | # cv2.imwrite(fname, cv2.cvtColor(mosaic, cv2.COLOR_BGR2RGB)) # cv2 save 189 | Image.fromarray(mosaic).save(fname) # PIL save 190 | return mosaic 191 | 192 | 193 | def plot_lr_scheduler(optimizer, scheduler, epochs=300, save_dir=''): 194 | # Plot LR simulating training for full epochs 195 | optimizer, scheduler = copy(optimizer), copy(scheduler) # do not modify originals 196 | y = [] 197 | for _ in range(epochs): 198 | scheduler.step() 199 | y.append(optimizer.param_groups[0]['lr']) 200 | plt.plot(y, '.-', label='LR') 201 | plt.xlabel('epoch') 202 | plt.ylabel('LR') 203 | plt.grid() 204 | plt.xlim(0, epochs) 205 | plt.ylim(0) 206 | plt.savefig(Path(save_dir) / 'LR.png', dpi=200) 207 | plt.close() 208 | 209 | 210 | def plot_test_txt(): # from utils.plots import *; plot_test() 211 | # Plot test.txt histograms 212 | x = np.loadtxt('test.txt', dtype=np.float32) 213 | box = xyxy2xywh(x[:, :4]) 214 | cx, cy = box[:, 0], box[:, 1] 215 | 216 | fig, ax = plt.subplots(1, 1, figsize=(6, 6), tight_layout=True) 217 | ax.hist2d(cx, cy, bins=600, cmax=10, cmin=0) 218 | ax.set_aspect('equal') 219 | plt.savefig('hist2d.png', dpi=300) 220 | 221 | fig, ax = plt.subplots(1, 2, figsize=(12, 6), tight_layout=True) 222 | ax[0].hist(cx, bins=600) 223 | ax[1].hist(cy, bins=600) 224 | plt.savefig('hist1d.png', dpi=200) 225 | 226 | 227 | def plot_targets_txt(): # from utils.plots import *; plot_targets_txt() 228 | # Plot targets.txt histograms 229 | x = np.loadtxt('targets.txt', dtype=np.float32).T 230 | s = ['x targets', 'y targets', 'width targets', 'height targets'] 231 | fig, ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True) 232 | ax = ax.ravel() 233 | for i in range(4): 234 | ax[i].hist(x[i], bins=100, label='%.3g +/- %.3g' % (x[i].mean(), x[i].std())) 235 | ax[i].legend() 236 | ax[i].set_title(s[i]) 237 | plt.savefig('targets.jpg', dpi=200) 238 | 239 | 240 | def plot_study_txt(path='', x=None): # from utils.plots import *; plot_study_txt() 241 | # Plot study.txt generated by test.py 242 | fig, ax = plt.subplots(2, 4, figsize=(10, 6), tight_layout=True) 243 | # ax = ax.ravel() 244 | 245 | fig2, ax2 = plt.subplots(1, 1, figsize=(8, 4), tight_layout=True) 246 | # for f in [Path(path) / f'study_coco_{x}.txt' for x in ['yolov5s6', 'yolov5m6', 'yolov5l6', 'yolov5x6']]: 247 | for f in 
sorted(Path(path).glob('study*.txt')): 248 | y = np.loadtxt(f, dtype=np.float32, usecols=[0, 1, 2, 3, 7, 8, 9], ndmin=2).T 249 | x = np.arange(y.shape[1]) if x is None else np.array(x) 250 | s = ['P', 'R', 'mAP@.5', 'mAP@.5:.95', 't_inference (ms/img)', 't_NMS (ms/img)', 't_total (ms/img)'] 251 | # for i in range(7): 252 | # ax[i].plot(x, y[i], '.-', linewidth=2, markersize=8) 253 | # ax[i].set_title(s[i]) 254 | 255 | j = y[3].argmax() + 1 256 | ax2.plot(y[6, 1:j], y[3, 1:j] * 1E2, '.-', linewidth=2, markersize=8, 257 | label=f.stem.replace('study_coco_', '').replace('yolo', 'YOLO')) 258 | 259 | ax2.plot(1E3 / np.array([209, 140, 97, 58, 35, 18]), [34.6, 40.5, 43.0, 47.5, 49.7, 51.5], 260 | 'k.-', linewidth=2, markersize=8, alpha=.25, label='EfficientDet') 261 | 262 | ax2.grid(alpha=0.2) 263 | ax2.set_yticks(np.arange(20, 60, 5)) 264 | ax2.set_xlim(0, 57) 265 | ax2.set_ylim(30, 55) 266 | ax2.set_xlabel('GPU Speed (ms/img)') 267 | ax2.set_ylabel('COCO AP val') 268 | ax2.legend(loc='lower right') 269 | plt.savefig(str(Path(path).name) + '.png', dpi=300) 270 | 271 | 272 | def plot_labels(labels, names=(), save_dir=Path(''), loggers=None): 273 | # plot dataset labels 274 | print('Plotting labels... ') 275 | c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes 276 | nc = int(c.max() + 1) # number of classes 277 | colors = color_list() 278 | x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height']) 279 | 280 | # seaborn correlogram 281 | sns.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9)) 282 | plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200) 283 | plt.close() 284 | 285 | # matplotlib labels 286 | matplotlib.use('svg') # faster 287 | ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel() 288 | ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8) 289 | ax[0].set_ylabel('instances') 290 | if 0 < len(names) < 30: 291 | ax[0].set_xticks(range(len(names))) 292 | ax[0].set_xticklabels(names, rotation=90, fontsize=10) 293 | else: 294 | ax[0].set_xlabel('classes') 295 | sns.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9) 296 | sns.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9) 297 | 298 | # rectangles 299 | labels[:, 1:3] = 0.5 # center 300 | labels[:, 1:] = xywh2xyxy(labels[:, 1:]) * 2000 301 | img = Image.fromarray(np.ones((2000, 2000, 3), dtype=np.uint8) * 255) 302 | for cls, *box in labels[:1000]: 303 | ImageDraw.Draw(img).rectangle(box, width=1, outline=colors[int(cls) % 10]) # plot 304 | ax[1].imshow(img) 305 | ax[1].axis('off') 306 | 307 | for a in [0, 1, 2, 3]: 308 | for s in ['top', 'right', 'left', 'bottom']: 309 | ax[a].spines[s].set_visible(False) 310 | 311 | plt.savefig(save_dir / 'labels.jpg', dpi=200) 312 | matplotlib.use('Agg') 313 | plt.close() 314 | 315 | # loggers 316 | for k, v in (loggers or {}).items(): # loggers may be None 317 | if k == 'wandb' and v: 318 | v.log({"Labels": [v.Image(str(x), caption=x.name) for x in save_dir.glob('*labels*.jpg')]}, commit=False) 319 | 320 | 321 | def plot_evolution(yaml_file='data/hyp.finetune.yaml'): # from utils.plots import *; plot_evolution() 322 | # Plot hyperparameter evolution results in evolve.txt 323 | with open(yaml_file) as f: 324 | hyp = yaml.load(f, Loader=yaml.SafeLoader) 325 | x = np.loadtxt('evolve.txt', ndmin=2) 326 | f = fitness(x) 327 | # weights = (f - f.min()) ** 2 # for weighted results 328 | plt.figure(figsize=(10, 12), tight_layout=True) 329 | matplotlib.rc('font', **{'size': 8}) 330 | for i, (k, v) in
enumerate(hyp.items()): 331 | y = x[:, i + 7] 332 | # mu = (y * weights).sum() / weights.sum() # best weighted result 333 | mu = y[f.argmax()] # best single result 334 | plt.subplot(6, 5, i + 1) 335 | plt.scatter(y, f, c=hist2d(y, f, 20), cmap='viridis', alpha=.8, edgecolors='none') 336 | plt.plot(mu, f.max(), 'k+', markersize=15) 337 | plt.title('%s = %.3g' % (k, mu), fontdict={'size': 9}) # limit to 40 characters 338 | if i % 5 != 0: 339 | plt.yticks([]) 340 | print('%15s: %.3g' % (k, mu)) 341 | plt.savefig('evolve.png', dpi=200) 342 | print('\nPlot saved as evolve.png') 343 | 344 | 345 | def profile_idetection(start=0, stop=0, labels=(), save_dir=''): 346 | # Plot iDetection '*.txt' per-image logs. from utils.plots import *; profile_idetection() 347 | ax = plt.subplots(2, 4, figsize=(12, 6), tight_layout=True)[1].ravel() 348 | s = ['Images', 'Free Storage (GB)', 'RAM Usage (GB)', 'Battery', 'dt_raw (ms)', 'dt_smooth (ms)', 'real-world FPS'] 349 | files = list(Path(save_dir).glob('frames*.txt')) 350 | for fi, f in enumerate(files): 351 | try: 352 | results = np.loadtxt(f, ndmin=2).T[:, 90:-30] # clip first and last rows 353 | n = results.shape[1] # number of rows 354 | x = np.arange(start, min(stop, n) if stop else n) 355 | results = results[:, x] 356 | t = (results[0] - results[0].min()) # set t0=0s 357 | results[0] = x 358 | for i, a in enumerate(ax): 359 | if i < len(results): 360 | label = labels[fi] if len(labels) else f.stem.replace('frames_', '') 361 | a.plot(t, results[i], marker='.', label=label, linewidth=1, markersize=5) 362 | a.set_title(s[i]) 363 | a.set_xlabel('time (s)') 364 | # if fi == len(files) - 1: 365 | # a.set_ylim(bottom=0) 366 | for side in ['top', 'right']: 367 | a.spines[side].set_visible(False) 368 | else: 369 | a.remove() 370 | except Exception as e: 371 | print('Warning: Plotting error for %s; %s' % (f, e)) 372 | 373 | ax[1].legend() 374 | plt.savefig(Path(save_dir) / 'idetection_profile.png', dpi=200) 375 | 376 | 377 | def plot_results_overlay(start=0, stop=0): # from utils.plots import *; plot_results_overlay() 378 | # Plot training 'results*.txt', overlaying train and val losses 379 | s = ['train', 'train', 'train', 'Precision', 'mAP@0.5', 'val', 'val', 'val', 'Recall', 'mAP@0.5:0.95'] # legends 380 | t = ['Box', 'Objectness', 'Classification', 'P-R', 'mAP-F1'] # titles 381 | for f in sorted(glob.glob('results*.txt') + glob.glob('../../Downloads/results*.txt')): 382 | results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T 383 | n = results.shape[1] # number of rows 384 | x = range(start, min(stop, n) if stop else n) 385 | fig, ax = plt.subplots(1, 5, figsize=(14, 3.5), tight_layout=True) 386 | ax = ax.ravel() 387 | for i in range(5): 388 | for j in [i, i + 5]: 389 | y = results[j, x] 390 | ax[i].plot(x, y, marker='.', label=s[j]) 391 | # y_smooth = butter_lowpass_filtfilt(y) 392 | # ax[i].plot(x, np.gradient(y_smooth), marker='.', label=s[j]) 393 | 394 | ax[i].set_title(t[i]) 395 | ax[i].legend() 396 | ax[i].set_ylabel(f) if i == 0 else None # add filename 397 | fig.savefig(f.replace('.txt', '.png'), dpi=200) 398 | 399 | 400 | def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''): 401 | # Plot training 'results*.txt'. 
from utils.plots import *; plot_results(save_dir='runs/train/exp') 402 | fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True) 403 | ax = ax.ravel() 404 | s = ['Box', 'Objectness', 'Classification', 'Precision', 'Recall', 405 | 'val Box', 'val Objectness', 'val Classification', 'mAP@0.5', 'mAP@0.5:0.95'] 406 | if bucket: 407 | # files = ['https://storage.googleapis.com/%s/results%g.txt' % (bucket, x) for x in id] 408 | files = ['results%g.txt' % x for x in id] 409 | c = ('gsutil cp ' + '%s ' * len(files) + '.') % tuple('gs://%s/results%g.txt' % (bucket, x) for x in id) 410 | os.system(c) 411 | else: 412 | files = list(Path(save_dir).glob('results*.txt')) 413 | assert len(files), 'No results.txt files found in %s, nothing to plot.' % os.path.abspath(save_dir) 414 | for fi, f in enumerate(files): 415 | try: 416 | results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T 417 | n = results.shape[1] # number of rows 418 | x = range(start, min(stop, n) if stop else n) 419 | for i in range(10): 420 | y = results[i, x] 421 | if i in [0, 1, 2, 5, 6, 7]: 422 | y[y == 0] = np.nan # don't show zero loss values 423 | # y /= y[0] # normalize 424 | label = labels[fi] if len(labels) else f.stem 425 | ax[i].plot(x, y, marker='.', label=label, linewidth=2, markersize=8) 426 | ax[i].set_title(s[i]) 427 | # if i in [5, 6, 7]: # share train and val loss y axes 428 | # ax[i].get_shared_y_axes().join(ax[i], ax[i - 5]) 429 | except Exception as e: 430 | print('Warning: Plotting error for %s; %s' % (f, e)) 431 | 432 | ax[1].legend() 433 | fig.savefig(Path(save_dir) / 'results.png', dpi=200) 434 | --------------------------------------------------------------------------------
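A minimal smoke test for the plotting helpers above, assuming the repository's utils package is importable; the image shapes, file names, and class names ('invoice', 'stamp') are illustrative placeholders, not values taken from this project. plot_images takes a batch of CHW images plus [batch_id, class_id, x, y, w, h] targets in normalised xywh form and writes a labelled mosaic:

# Illustrative sketch only (not part of the repository): exercise plot_images() on synthetic data.
import numpy as np
import torch
from utils.plots import plot_images

imgs = torch.rand(2, 3, 320, 320)                   # two fake RGB images, values in [0, 1]
targets = np.array([[0, 0, 0.5, 0.5, 0.4, 0.3],     # [batch_id, class_id, x, y, w, h], normalised
                    [1, 1, 0.3, 0.6, 0.2, 0.2]])
plot_images(imgs, targets, paths=['a.jpg', 'b.jpg'],
            names={0: 'invoice', 1: 'stamp'}, fname='mosaic_check.jpg')

This call should save mosaic_check.jpg with one labelled box per image; passing a seventh confidence column in targets instead switches the grid to prediction mode, where the 0.25 confidence threshold applies.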