├── README.md ├── 深度学习实战1-(keras框架)企业数据分析与预测 ├── GM11.csv ├── data.csv ├── enterprise_data_analysis.py ├── net.h5 └── result.csv ├── 深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex) ├── 123.png ├── LICENSE ├── README.md ├── __pycache__ │ ├── eval.cpython-38.pyc │ ├── models.cpython-37.pyc │ ├── models.cpython-38.pyc │ ├── pix2tex.cpython-37.pyc │ └── pix2tex.cpython-38.pyc ├── dataset │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── arxiv.cpython-37.pyc │ │ ├── arxiv.cpython-38.pyc │ │ ├── dataset.cpython-37.pyc │ │ ├── dataset.cpython-38.pyc │ │ ├── demacro.cpython-37.pyc │ │ ├── demacro.cpython-38.pyc │ │ ├── extract_latex.cpython-37.pyc │ │ ├── extract_latex.cpython-38.pyc │ │ ├── latex2png.cpython-37.pyc │ │ ├── latex2png.cpython-38.pyc │ │ ├── render.cpython-37.pyc │ │ ├── render.cpython-38.pyc │ │ ├── scraping.cpython-37.pyc │ │ └── scraping.cpython-38.pyc │ ├── arxiv.py │ ├── dataset.py │ ├── demacro.py │ ├── extract_latex.py │ ├── latex2png.py │ ├── postprocess.py │ ├── preprocessing │ │ ├── generate_latex_vocab.py │ │ ├── preprocess_formulas.py │ │ ├── preprocess_latex.js │ │ └── third_party │ │ │ ├── README.md │ │ │ ├── katex │ │ │ ├── LICENSE.txt │ │ │ ├── README.md │ │ │ ├── cli.js │ │ │ ├── katex.js │ │ │ ├── package.json │ │ │ └── src │ │ │ │ ├── Lexer.js │ │ │ │ ├── Options.js │ │ │ │ ├── ParseError.js │ │ │ │ ├── Parser.js │ │ │ │ ├── Settings.js │ │ │ │ ├── Style.js │ │ │ │ ├── buildCommon.js │ │ │ │ ├── buildHTML.js │ │ │ │ ├── buildMathML.js │ │ │ │ ├── buildTree.js │ │ │ │ ├── delimiter.js │ │ │ │ ├── domTree.js │ │ │ │ ├── environments.js │ │ │ │ ├── fontMetrics.js │ │ │ │ ├── fontMetricsData.js │ │ │ │ ├── functions.js │ │ │ │ ├── mathMLTree.js │ │ │ │ ├── parseData.js │ │ │ │ ├── parseTree.js │ │ │ │ ├── symbols.js │ │ │ │ └── utils.js │ │ │ └── match-at │ │ │ ├── README.md │ │ │ └── package.json │ ├── render.py │ ├── scraping.py │ └── tokenizer.json ├── eval.py ├── gui.py ├── models.py ├── newlatex │ ├── gongshi6.png │ └── img2latex.py ├── pix2tex.py ├── requirements.txt ├── resources │ ├── MathJax.js │ ├── __pycache__ │ │ ├── resources.cpython-37.pyc │ │ └── resources.cpython-38.pyc │ ├── icon.svg │ ├── processing-icon-anim.svg │ ├── resources.py │ └── resources.qrc ├── settings │ ├── config.yaml │ └── debug.yaml ├── setup_desktop.py ├── train.py ├── train_resizer.py └── utils │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── __init__.cpython-38.pyc │ ├── utils.cpython-37.pyc │ └── utils.cpython-38.pyc │ └── utils.py ├── 深度学习实战11(进阶版)-BERT模型的微调应用-文本分类案例 ├── bert_model.py └── data │ └── train.csv ├── 深度学习实战12(进阶版)-利用Dewarp实现文本扭曲矫正 ├── 123.png └── dewarp.py ├── 深度学习实战13(进阶版)-文本纠错功能,经常写错别字的小伙伴的福星 ├── correct1.py ├── correct2.py ├── correct3.py ├── correct4.py ├── correct_new.py └── kHanyuPinlu.txt ├── 深度学习实战14(进阶版)-手写文字OCR识别,手写笔记也可以识别了 ├── 123456.png ├── ocr_result │ └── ndarray_1671255339.407687.jpg ├── writeOCR.py └── writeOCR_new.py ├── 深度学习实战15(进阶版)-让机器进行阅读理解+你可以变成出题者提问 ├── __pycache__ │ └── utils.cpython-38.pyc ├── reading.py └── utils.py ├── 深度学习实战2-(keras框架)企业信用评级与预测 ├── enterprise_credit.py └── train_new.csv ├── 深度学习实战3-文本卷积神经网络(TextCNN)新闻文本分类 ├── Text_classif.py ├── __pycache__ │ └── config.cpython-38.pyc ├── config.py └── vocab.txt ├── 深度学习实战4-卷积神经网络(DenseNet)数学图形识别+题目模式识别 └── math_classif.py ├── 深度学习实战5-卷积神经网络(CNN)中文OCR识别项目 ├── DroidSansFallback.ttf ├── chinese.txt └── ocr.py ├── 深度学习实战6-卷积神经网络(Pytorch)+聚类分析实现空气质量与天气预测 ├── 1234.csv ├── net.pkl ├── weather - 副本.csv ├── weather.csv ├── 
weather.py ├── weather_new.csv └── weathers.csv ├── 深度学习实战7-电商产品评论的情感分析 ├── Sentiment.py ├── __pycache__ │ └── data_loader.cpython-38.pyc ├── data_loader.py └── online_shopping_10_cats.csv ├── 深度学习实战8-生活照片转化漫画照片应用 ├── img2cartoon.py ├── input.png ├── new_image.png ├── result0.png ├── result2.png └── result21.png └── 深度学习实战9-文本生成图像-本地电脑实现text2img ├── stable_diffusion_tf ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── autoencoder_kl.cpython-38.pyc │ ├── clip_encoder.cpython-38.pyc │ ├── constants.cpython-38.pyc │ ├── diffusion_model.cpython-38.pyc │ ├── layers.cpython-38.pyc │ └── stable_diffusion.cpython-38.pyc ├── autoencoder_kl.py ├── clip_encoder.py ├── clip_tokenizer │ ├── __init__.py │ ├── __pycache__ │ │ └── __init__.cpython-38.pyc │ └── bpe_simple_vocab_16e6.txt.gz ├── constants.py ├── diffusion_model.py ├── layers.py └── stable_diffusion.py ├── text2img.py └── text2img2.py /README.md: -------------------------------------------------------------------------------- 1 | # deep_learning 2 | 深度学习实战项目主要围绕各行各业人工智能需要开发的入门级别的项目,由浅入深,涉及到的深度学习框架有tensorflow、keras、pytorch. 3 | 4 | ​ 5 | 人工智能实战项目 6 | 大家好,我是微学AI,本项目将围绕人工智能实战项目进行展开,紧密贴近生活,实战项目设计多个领域包括:金融、教育、医疗、地理、生物、人文、自然语言处理等;帮助各位读者结合机器学习与深度学习构建智能而且实用的人工智能简单系统,创建有影响力的AI应用,项目中提供项目原码,一步一步地运行每行代码,了解每行代码在什么,由浅入深,不断地解决多领域的问题。 7 | 8 | 目录 9 | 10 | 一、人工智能基础部分 11 | 12 | 1.人工智能基础部分1-人工智能的初步认识 13 | 14 | 2.人工智能基础部分2-一元一次函数感知器 15 | 16 | 3.人工智能基础部分3-方差损失函数的概念 17 | 18 | 4.人工智能基础部分4-梯度下降和反向传播 19 | 20 | 5.人工智能基础部分5-激活函数的概念 21 | 22 | 6.人工智能基础部分6-神经网络初步认识 23 | 24 | 7.人工智能基础部分7-高维空间的神经网络认识 25 | 26 | 8.人工智能基础部分8-深度学习框架keras入门案例 27 | 28 | 9.人工智能基础部分9-深度学习深入了解 29 | 30 | 10.人工智能基础部分10-卷积神经网络初步认识 31 | 32 | 11.人工智能基础部分11-图像识别实战 33 | 34 | 12.人工智能基础部分12-循环神经网络初步认识 35 | 36 | 13.人工智能基础部分13-LSTM网络:预测上证指数走势 37 | 38 | (更新完) 39 | 40 | 二、机器学习实战项目 41 | 42 | 1.机器学习实战1-四种算法对比对客户信用卡还款情况进行预测 43 | 44 | 2.机器学习实战2-聚类算法分析亚洲足球梯队 45 | 46 | 3.机器学习实战3-利用决策树算法根据天气数据集做出决策 47 | 48 | 4.机器学习实战4-教育领域:学生成绩的可视化分析与成绩预测-详细分析 49 | 50 | 5.机器学习实战5-天气预测系列:利用数据集可视化分析数据,并预测某个城市的天气情况 51 | 52 | 6.机器学习实战6-电子商务网站用户行为分析及服务推荐 53 | 54 | 7.机器学习实战7-服务员公司客户价值分析与流失分析 55 | 56 | 8.机器学习实战8-基于基站定位数据的商圈分析 57 | 58 | 9.机器学习实战9-售车逃税店铺自动识别 59 | 60 | 10.机器学习实战10-企业关联规则挖掘 61 | 62 | ...(待更新) 63 | 64 | 三、深度学习实战项目 65 | 66 | 1.深度学习实战1-(keras框架)企业数据分析与预测 67 | 68 | 2.深度学习实战2-(keras框架)企业信用评级与预测 69 | 70 | 3.深度学习实战3-文本卷积神经网络(TextCNN)新闻文本分类 71 | 72 | 4.深度学习实战4-卷积神经网络(DenseNet)数学图形识别+题目模式识别 73 | 74 | 5.深度学习实战5-卷积神经网络(CNN)中文OCR识别项目 75 | 76 | 6.深度学习实战6-卷积神经网络(Pytorch)+聚类分析实现空气质量与天气预测 77 | 78 | 7.深度学习实战7-电商产品评论的情感分析 79 | 80 | 8.深度学习实战8-生活照片转化漫画照片应用 81 | 82 | 9.深度学习实战9-文本生成图像-本地电脑实现text2img 83 | 84 | 10.深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex) 85 | 86 | 11.深度学习实战11(进阶版)-BERT模型的微调应用-文本分类案例 87 | 88 | 12.深度学习实战12(进阶版)-利用Dewarp实现文本扭曲矫正 89 | 90 | 13.深度学习实战13(进阶版)-文本纠错功能,经常写错别字的小伙伴的福星 91 | 92 | 14.深度学习实战14(进阶版)-手写文字OCR识别,手写笔记也可以识别了 93 | 94 | 15.深度学习实战15(进阶版)-让机器进行阅读理解+你可以变成出题者提问 95 | 96 | 16.深度学习实战16(进阶版)-虚拟截图识别文字-可以做纸质合同和表格识别 97 | 98 | 17.深度学习实战17(进阶版)-智能辅助编辑平台系统的搭建与开发案例 99 | 100 | 18.深度学习实战18(进阶版)-NLP的15项任务大融合系统,可实现市面上你能想到的NLP任务 101 | 102 | 19.深度学习实战19(进阶版)-SpeakGPT的本地实现部署测试,基于ChatGPT在自己的平台实现SpeakGPT功能 103 | 104 | 20.深度学习实战20(进阶版)-文件智能搜索系统,可以根据文件内容进行关键词搜索,快速找到文件 105 | 106 | 21.深度学习实战21(进阶版)-AI实体百科搜索,任何名词都可搜索到的百科全书 107 | 108 | 22.深度学习实战22(进阶版)-AI漫画视频生成模型,做自己的漫画视频 109 | 110 | 23.深度学习实战23(进阶版)-语义分割实战,实现人物图像抠图的效果(计算机视觉) 111 | 112 | 24.深度学习实战24-人工智能(Pytorch)搭建transformer模型,真正跑通transformer模型,深刻了解transformer的架构 113 | 114 | 25.深度学习实战25-人工智能(Pytorch)搭建T5模型,真正跑通T5模型,用T5模型生成数字加减结果 115 | 116 | 26.深度学习实战26-(Pytorch)搭建TextCNN实现多标签文本分类的任务 117 | 
118 | 27.深度学习实战27-Pytorch框架+BERT实现中文文本的关系抽取 119 | 120 | ...(待更新) 121 | 122 | 四、深度学习技巧应用 123 | 124 | 1.深度学习技巧应用1-利用知识蒸馏技术做模型压缩 125 | 126 | 2.深度学习技巧应用2-神经网络中的‘残差连接’ 127 | 128 | 3.深度学习技巧应用3-神经网络中的超参数搜索 129 | 130 | 4.深度学习技巧应用4-模型融合:投票法、加权平均法、集成模型法 131 | 132 | 5.深度学习技巧应用5-神经网络中的模型剪枝技巧 133 | 134 | 6.深度学习技巧应用6-神经网络中模型冻结-迁移学习技巧 135 | 136 | 7.深度学习技巧应用7-K折交叉验证的实践操作 137 | 138 | 8.深度学习技巧应用8-各种数据类型的加载与处理,并输入神经网络进行训练 139 | 140 | 9.深度学习技巧应用9-模型训练中学习率的调整和假数据生成技巧与总结 141 | 142 | 10.深度学习技巧应用10-PyTorch框架中早停法类的构建与运用 143 | 144 | 11.深度学习技巧应用11-模型训练中稀疏化参数与稀疏损失函数的应用 145 | 146 | ...(待更新) 147 | 148 | 五、知识图谱实战项目 149 | 150 | 知识图谱开篇:知识图谱实战开篇-讲述知识图谱是什么,要学哪些知识,一文讲通 151 | 152 | 1.知识图谱实战应用1-知识图谱的构建与可视化应用 153 | 154 | 2.知识图谱实战应用2-知识图谱的知识融合与知识消歧 155 | 156 | 3.知识图谱实战应用3-知识图谱中的电影推荐算法 157 | 158 | 4.知识图谱实战应用4-知识图谱中寻找相似用户(协同过滤算法) 159 | 160 | 5.知识图谱实战应用5-基于知识图谱的创建语义搜索功能 161 | 162 | 6.知识图谱实战应用6-基于知识推理进行知识补全的功能 163 | 164 | 7.知识图谱实战应用7-最完整的常用Cypher查询语句与实际应用 165 | 166 | 8.知识图谱实战应用8-从文本关系抽取到知识图谱关系构建流程贯通 167 | 168 | 9.知识图谱实战应用9-基于neo4j的知识图谱框架设计与类模型构建 169 | 170 | ...(待更新) 171 | 172 | 173 | 以上已整理所有的代码与数据集,模型,可直接运行,可以关注:CSDN博客,微学AI,获取全套代码请联系: 174 | 也可微:shenqiang0601 175 | 或者联系邮箱:846514373@qq.com 176 | 177 | 178 | -------------------------------------------------------------------------------- /深度学习实战1-(keras框架)企业数据分析与预测/GM11.csv: -------------------------------------------------------------------------------- 1 | ,x3,x5,x7,y 2 | 2000,448.19,6212.7,525.71,64.87 3 | 2001,549.97,7601.73,618.25,99.75 4 | 2002,686.44,8092.82,638.94,88.11 5 | 2003,802.59,8767.98,656.58,106.07 6 | 2004,904.57,9422.33,758.83,137.32 7 | 2005,1000.69,9751.44,878.26,188.14 8 | 2006,1121.13,11349.47,923.67,219.91 9 | 2007,1248.29,11467.35,978.21,271.91 10 | 2008,1370.68,10671.78,1009.24,269.1 11 | 2009,1494.27,11570.58,1175.17,300.55 12 | 2010,1677.77,13120.83,1348.93,338.45 13 | 2011,1905.84,14468.24,1519.16,408.86 14 | 2012,2199.14,15444.93,1696.38,476.72 15 | 2013,2624.24,18951.32,1863.34,838.99 16 | 2014,3187.39,20835.95,2105.54,843.14 17 | 2015,3615.77,22820.89,2659.85,1107.67 18 | 2016,4476.38,25011.61,3263.57,1399.16 19 | 2017,5243.03,28209.74,3412.21,1535.14 20 | 2018,5977.27,30490.44,3758.39,1579.68 21 | 2019,6882.85,33156.83,4454.55,2088.14 22 | 2020,7042.31,35046.63,4600.4, 23 | 2021,8166.92,38384.22,5214.78, 24 | 2022,9471.11,42039.66,5911.2, 25 | -------------------------------------------------------------------------------- /深度学习实战1-(keras框架)企业数据分析与预测/data.csv: -------------------------------------------------------------------------------- 1 | x1,x2,x3,x4,x5,x6,x7,x8,x9,x10,x11,x12,x13,y 2 | 3831732,181.54,448.19,7571,6212.7,6370241,525.71,985.31,60.62,65.66,120,1.029,5321,64.87 3 | 3913824,214.63,549.97,9038.16,7601.73,6467115,618.25,1259.2,73.46,95.46,113.5,1.051,6529,99.75 4 | 3928907,239.56,686.44,9905.31,8092.82,6560508,638.94,1468.06,81.16,81.16,108.2,1.064,7008,88.11 5 | 4282130,261.58,802.59,10444.6,8767.98,6664862,656.58,1678.12,85.72,91.7,102.2,1.092,7694,106.07 6 | 4453911,283.14,904.57,11255.7,9422.33,6741400,758.83,1893.52,88.88,114.61,97.7,1.2,8027,137.32 7 | 4548852,308.58,1000.69,12018.52,9751.44,6850024,878.26,2139.18,92.85,152.78,98.5,1.198,8549,188.14 8 | 4962579,348.09,1121.13,13966.53,11349.47,7006896,923.67,2492.74,94.37,170.62,102.8,1.348,9566,219.91 9 | 5029338,387.81,1248.29,14694,11467.35,7125979,978.21,2841.65,97.28,214.53,98.9,1.467,10473,271.91 10 | 5070216,453.49,1370.68,13380.47,10671.78,7206229,1009.24,3203.96,103.07,202.18,97.6,1.56,11469,269.1 11 | 
5210706,533.55,1494.27,15002.59,11570.58,7251888,1175.17,3758.62,109.91,222.51,100.1,1.456,12360,300.55 12 | 5407087,598.33,1677.77,16884.16,13120.83,7376720,1348.93,4450.55,117.15,249.01,101.7,1.424,14174,338.45 13 | 5744550,665.32,1905.84,18287.24,14468.24,7505322,1519.16,5154.23,130.22,303.41,101.5,1.456,16394,408.86 14 | 5994973,738.97,2199.14,19850.66,15444.93,7607220,1696.38,6081.86,128.51,356.99,102.3,1.438,17881,476.72 15 | 6236312,877.07,2624.24,22469.22,18951.32,7734787,1863.34,7140.32,149.87,429.36,103.4,1.474,20058,838.99 16 | 6529045,1005.37,3187.39,25316.72,20835.95,7841695,2105.54,8287.38,169.19,508.84,105.9,1.515,22114,843.14 17 | 6791495,1118.03,3615.77,27609.59,22820.89,7946154,2659.85,9138.21,172.28,557.74,97.5,1.633,24190,1107.67 18 | 7110695,1304.48,4476.38,30658.49,25011.61,8061370,3263.57,10748.28,188.57,664.06,103.2,1.638,29549,1399.16 19 | 7431755,1700.87,5243.03,34438.08,28209.74,8145797,3412.21,12423.44,204.54,710.66,105.5,1.67,34214,1535.14 20 | 7512997,1969.51,5977.27,38053.52,30490.44,8222969,3758.39,13551.21,213.76,760.49,103,1.825,37934,1579.68 21 | 7599295,2110.78,6882.85,42049.14,33156.83,8323096,4454.55,15420.14,228.46,852.56,102.6,1.906,41972,2088.14 22 | -------------------------------------------------------------------------------- /深度学习实战1-(keras框架)企业数据分析与预测/enterprise_data_analysis.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from keras.models import Sequential 4 | from keras.layers.core import Dense, Activation 5 | import matplotlib.pylab as plt # 绘制图像库 6 | 7 | def GM11(x0): #自定义灰色预测函数 8 | import numpy as np 9 | x1 = x0.cumsum() #1-AGO序列 10 | z1 = (x1[:len(x1)-1] + x1[1:])/2.0 11 | z1 = z1.reshape((len(z1),1)) 12 | B = np.append(-z1, np.ones_like(z1), axis = 1) 13 | Yn = x0[1:].reshape((len(x0)-1, 1)) 14 | [[a],[b]] = np.dot(np.dot(np.linalg.inv(np.dot(B.T, B)), B.T), Yn) #计算参数 15 | f = lambda k: (x0[0]-b/a)*np.exp(-a*(k-1))-(x0[0]-b/a)*np.exp(-a*(k-2)) #还原值 16 | delta = np.abs(x0 - np.array([f(i) for i in range(1,len(x0)+1)])) 17 | C = delta.std()/x0.std() 18 | P = 1.0*(np.abs(delta - delta.mean()) < 0.6745*x0.std()).sum()/len(x0) 19 | return f, a, b, x0[0], C, P #返回灰色预测函数、a、b、首项、方差比、小残差概率 20 | 21 | data = pd.read_csv('data.csv') #读取数据 22 | data.index = range(2000,2020) # 标注索引信息年份 23 | 24 | data.loc[2020] = None 25 | data.loc[2021] = None 26 | data.loc[2022] = None 27 | l = ['x1', 'x2', 'x3', 'x4', 'x5', 'x7'] 28 | l1 = ['x3', 'x5', 'x7'] 29 | for i in l1: 30 | f, _, _, _, C, _ = GM11(data[i].loc[range(2000, 2020)].values) 31 | print("%s后验差比值:%0.4f" % (i, C)) # 后验差比值c,即:真实误差的方差同原始数据方差的比值。 32 | data[i].loc[2020] = f(len(data) - 2) # 2014年预测结果 33 | data[i].loc[2021] = f(len(data) - 1) # 2015年预测结果 34 | data[i].loc[2022] = f(len(data)) # 2016年预测结果 35 | data[i] = data[i].round(2) # 保留两位小数 36 | 37 | data[l1 + ['y']].to_csv('GM11.csv') # 结果输出 38 | 39 | data = pd.read_csv('GM11.csv',index_col = 0) #读取数据 40 | feature = ['x3','x5','x7'] # 提取特征 41 | 42 | data_train = data.loc[range(2000, 2020)] # 取2014年前的数据建模 43 | print(data_train) 44 | data_mean = data_train.mean() 45 | data_std = data_train.std() 46 | data_train = (data_train - data_mean) / data_std # 数据标准化 后进行训练 47 | 48 | x_train = data_train[feature].values # 特征数据 49 | y_train = data_train['y'].values # 标签数据 50 | 51 | model = Sequential() #建立模型 52 | model.add(Dense(12,activation='relu',input_dim=3)) 53 | model.add(Dense(24,activation='relu')) # 隐藏层 54 | model.add(Dense(1)) # 输出层 55 | 
model.compile(loss='mean_squared_error', optimizer='adam') #编译模型 56 | model.fit(x_train, y_train, epochs = 10000, batch_size = 16,verbose=2) #训练模型,训练1000次 57 | model.save('net.h5') #保存模型参数 58 | 59 | x = ((data[feature] - data_mean[feature])/data_std[feature]).values 60 | data[u'y_pred'] = model.predict(x) * data_std['y'] + data_mean['y'] 61 | data.to_csv('result.csv') 62 | 63 | p = pd.read_csv('result.csv') 64 | p = p[['y','y_pred']].copy() 65 | p.index=range(2000,2023) 66 | p.plot(style=['b-o','r-*'],xticks=p.index,figsize=(15,5)) 67 | plt.xlabel("Year") 68 | plt.show() -------------------------------------------------------------------------------- /深度学习实战1-(keras框架)企业数据分析与预测/net.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenqiang0601/deep_learning/6381803a6189390d4388804c2ade0bc0c9ca8350/深度学习实战1-(keras框架)企业数据分析与预测/net.h5 -------------------------------------------------------------------------------- /深度学习实战1-(keras框架)企业数据分析与预测/result.csv: -------------------------------------------------------------------------------- 1 | ,x3,x5,x7,y,y_pred 2 | 2000,448.19,6212.7,525.71,64.87,55.39966 3 | 2001,549.97,7601.73,618.25,99.75,86.10028 4 | 2002,686.44,8092.82,638.94,88.11,102.56537 5 | 2003,802.59,8767.98,656.58,106.07,117.864044 6 | 2004,904.57,9422.33,758.83,137.32,150.1366 7 | 2005,1000.69,9751.44,878.26,188.14,183.38416 8 | 2006,1121.13,11349.47,923.67,219.91,211.61807 9 | 2007,1248.29,11467.35,978.21,271.91,234.03189 10 | 2008,1370.68,10671.78,1009.24,269.1,246.47754 11 | 2009,1494.27,11570.58,1175.17,300.55,294.22055 12 | 2010,1677.77,13120.83,1348.93,338.45,352.38922 13 | 2011,1905.84,14468.24,1519.16,408.86,412.72867 14 | 2012,2199.14,15444.93,1696.38,476.72,478.17303 15 | 2013,2624.24,18951.32,1863.34,838.99,832.98267 16 | 2014,3187.39,20835.95,2105.54,843.14,837.74603 17 | 2015,3615.77,22820.89,2659.85,1107.67,1107.5432 18 | 2016,4476.38,25011.61,3263.57,1399.16,1416.7266 19 | 2017,5243.03,28209.74,3412.21,1535.14,1482.1152 20 | 2018,5977.27,30490.44,3758.39,1579.68,1663.0085 21 | 2019,6882.85,33156.83,4454.55,2088.14,2065.5095 22 | 2020,7042.31,35046.63,4600.4,,2189.3767 23 | 2021,8166.92,38384.22,5214.78,,2519.647 24 | 2022,9471.11,42039.66,5911.2,,2886.0935 25 | -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/123.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenqiang0601/deep_learning/6381803a6189390d4388804c2ade0bc0c9ca8350/深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/123.png -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Lukas Blecher 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/README.md: -------------------------------------------------------------------------------- 1 | # pix2tex - LaTeX OCR 2 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1ba_qCGJl29dFQqfBjdqMik3o_EqPE4fr) 3 | 4 | The goal of this project is to create a learning-based system that takes an image of a math formula and returns the corresponding LaTeX code. 5 | 6 | ![header](https://user-images.githubusercontent.com/55287601/109183599-69431f00-778e-11eb-9809-d42b9451e018.png) 7 | 8 | ## Requirements 9 | ### Model 10 | * PyTorch (tested on v1.7.1) 11 | * Python 3.7+ & dependencies (`requirements.txt`) 12 | ``` 13 | pip install -r requirements.txt 14 | ``` 15 | ### Dataset 16 | In order to render the math in many different fonts we use XeLaTeX, generate a PDF and finally convert it to a PNG. For the last step we need to use some third-party tools: 17 | * [XeLaTeX](https://www.ctan.org/pkg/xetex) 18 | * [ImageMagick](https://imagemagick.org/) with [Ghostscript](https://www.ghostscript.com/index.html) (for converting PDF to PNG) 19 | * [Node.js](https://nodejs.org/) to run [KaTeX](https://github.com/KaTeX/KaTeX) (for normalizing LaTeX code) 20 | * [`de-macro`](https://www.ctan.org/pkg/de-macro) >= 1.4 (only for parsing arXiv papers) 21 | * Python 3.7+ & dependencies (`requirements.txt`) 22 | 23 | ## Using the model 24 | 1. Download/Clone this repository 25 | 2. For now you need to install the Python dependencies specified in `requirements.txt` (see [above](#Requirements)) 26 | 3. Download the `weights.pth` (and optionally `image_resizer.pth`) file from the [Releases](https://github.com/lukas-blecher/LaTeX-OCR/releases/latest)->Assets section and place it in the `checkpoints` directory 27 | 28 | Thanks to [@katie-lim](https://github.com/katie-lim), you can use a nice user interface as a quick way to get the model prediction. Just call the GUI with `python gui.py`. From here you can take a screenshot and the predicted LaTeX code is rendered using [MathJax](https://www.mathjax.org/) and copied to your clipboard. 29 | 30 | ![demo](https://user-images.githubusercontent.com/55287601/117812740-77b7b780-b262-11eb-81f6-fc19766ae2ae.gif) 31 | 32 | If the model is unsure about what's in the image it might output a different prediction every time you click "Retry". With the `temperature` parameter you can control this behavior (a low temperature will produce the same result). 33 | 34 | Alternatively you can use `pix2tex.py`, which offers similar functionality to `gui.py` but as a command line tool. In this case you don't need to install PyQt5. Using this script you can also parse already existing images from the disk. 35 | 36 | **Note:** As of right now it works best with images of smaller resolution. Don't zoom in all the way before taking a picture. Double-check the result carefully.
You can try to redo the prediction with another resolution if the answer was wrong. 37 | 38 | **Update:** I have trained an image classifier on randomly scaled images of the training data to predict the original size. 39 | This model will automatically resize the custom image to best resemble the training data and thus increase performance on images found in the wild. To use this preprocessing step, all you have to do is download the second weights file mentioned above. You should be able to take bigger (or smaller) images of the formula and still get a satisfying result. 40 | 41 | ## Training the model [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1MqZSKzSgEnJB9lU7LyPma4bo4J3dnj1E) 42 | 43 | 1. First we need to combine the images with their ground truth labels. I wrote a dataset class (which needs further improving) that saves the relative paths to the images with the LaTeX code they were rendered with. To generate the dataset pickle file, run 44 | 45 | ``` 46 | python dataset/dataset.py --equations path_to_textfile --images path_to_images --tokenizer dataset/tokenizer.json --out dataset.pkl 47 | ``` 48 | 49 | You can find my generated training data on the [Google Drive](https://drive.google.com/drive/folders/13CA4vAmOmD_I_dSbvLp-Lf0s6KiaNfuO) as well (formulae.zip - images, math.txt - labels). Repeat the step for the validation and test data. All use the same label text file. 50 | 51 | 2. Edit the `data` (and `valdata`) entry in the config file to the newly generated `.pkl` file. Change other hyperparameters if you want to. See `settings/config.yaml` for a template. 52 | 3. Now, for the actual training, run 53 | ``` 54 | python train.py --config path_to_config_file 55 | ``` 56 | 57 | If you want to use your own data you might be interested in creating your own tokenizer with 58 | ``` 59 | python dataset/dataset.py --equations path_to_textfile --vocab-size 8000 --out tokenizer.json 60 | ``` 61 | Don't forget to update the path to the tokenizer in the config file and set `num_tokens` to your vocabulary size. 62 | 63 | ## Model 64 | The model consists of a ViT [[1](#References)] encoder with a ResNet backbone and a Transformer [[2](#References)] decoder. A rough, purely illustrative PyTorch sketch of this encoder-decoder wiring follows the Contribution section below. 65 | 66 | ### Performance 67 | | BLEU score | normed edit distance | 68 | | ---------- | -------------------- | 69 | | 0.88 | 0.10 | 70 | 71 | ## Data 72 | We need paired data for the network to learn. Luckily there is a lot of LaTeX code on the internet, e.g. [wikipedia](https://www.wikipedia.org), [arXiv](https://www.arxiv.org). We also use the formulae from the [im2latex-100k](https://zenodo.org/record/56198#.V2px0jXT6eA) dataset. 73 | All of it can be found [here](https://drive.google.com/drive/folders/13CA4vAmOmD_I_dSbvLp-Lf0s6KiaNfuO). 74 | 75 | ### Fonts 76 | Latin Modern Math, GFSNeohellenicMath.otf, Asana Math, XITS Math, Cambria Math 77 | 78 | 79 | ## TODO 80 | - [x] add more evaluation metrics 81 | - [x] create a GUI 82 | - [ ] add beam search 83 | - [ ] support handwritten formulae 84 | - [ ] reduce model size (distillation) 85 | - [ ] find optimal hyperparameters 86 | - [ ] tweak model structure 87 | - [ ] fix data scraping and scrape more data 88 | - [ ] trace the model 89 | 90 | 91 | ## Contribution 92 | Contributions of any kind are welcome.
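Following up on the Model section above: below is a minimal, self-contained PyTorch sketch of that encoder-decoder idea, a ViT-style patch encoder feeding an autoregressive Transformer decoder that predicts LaTeX token ids. It is an illustration only, not code taken from this repository's `models.py`; every class name and hyperparameter is an assumption chosen for brevity, and the ResNet backbone and the separate image-resizer model are omitted.

```python
# Illustrative sketch only (NOT the repository's models.py): a ViT-style patch
# encoder plus an autoregressive Transformer decoder over LaTeX token ids.
# All names and sizes are assumptions; needs PyTorch >= 1.9 for batch_first.
import torch
import torch.nn as nn


class TinyImg2LatexSketch(nn.Module):
    def __init__(self, vocab_size=8000, d_model=256, nhead=8,
                 num_encoder_layers=4, num_decoder_layers=4,
                 patch_size=16, max_len=512):
        super().__init__()
        # Encoder: cut the greyscale image into patches and embed each patch.
        # (Positional information for the patches is omitted to keep the sketch short.)
        self.patch_embed = nn.Conv2d(1, d_model, kernel_size=patch_size, stride=patch_size)
        enc_layer = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True)
        self.encoder = nn.TransformerEncoder(enc_layer, num_encoder_layers)
        # Decoder: embeds already-generated LaTeX tokens and cross-attends to the image memory.
        self.token_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max_len, d_model)
        dec_layer = nn.TransformerDecoderLayer(d_model, nhead, batch_first=True)
        self.decoder = nn.TransformerDecoder(dec_layer, num_decoder_layers)
        self.lm_head = nn.Linear(d_model, vocab_size)

    def forward(self, images, token_ids):
        # images: (B, 1, H, W); token_ids: (B, T) previously generated LaTeX tokens.
        memory = self.patch_embed(images).flatten(2).transpose(1, 2)  # (B, patches, d_model)
        memory = self.encoder(memory)
        t = token_ids.size(1)
        positions = torch.arange(t, device=token_ids.device)
        tgt = self.token_embed(token_ids) + self.pos_embed(positions)
        # Causal mask so each position only attends to earlier tokens.
        causal = torch.triu(torch.full((t, t), float("-inf"), device=token_ids.device), diagonal=1)
        out = self.decoder(tgt, memory, tgt_mask=causal)
        return self.lm_head(out)  # (B, T, vocab_size) logits over the LaTeX vocabulary


if __name__ == "__main__":
    # Shape check with dummy data: a fake 64x256 formula image and 20 token ids.
    model = TinyImg2LatexSketch()
    logits = model(torch.randn(2, 1, 64, 256), torch.randint(0, 8000, (2, 20)))
    print(logits.shape)  # torch.Size([2, 20, 8000])
```

In the full project the vocabulary would presumably come from the tokenizer built with `dataset/dataset.py`, training would use teacher forcing on the rendered image/LaTeX pairs, and inference would decode token by token until an end-of-sequence token appears.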
93 | 94 | ## Acknowledgment 95 | Code taken and modified from [lucidrains](https://github.com/lucidrains), [rwightman](https://github.com/rwightman/pytorch-image-models), [im2markup](https://github.com/harvardnlp/im2markup), [arxiv_leaks](https://github.com/soskek/arxiv_leaks), [pkra: Mathjax](https://github.com/pkra/MathJax-single-file), [harupy: snipping tool](https://github.com/harupy/snipping-tool) 96 | 97 | ## References 98 | [1] [An Image is Worth 16x16 Words](https://arxiv.org/abs/2010.11929) 99 | 100 | [2] [Attention Is All You Need](https://arxiv.org/abs/1706.03762) 101 | -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/__pycache__/eval.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenqiang0601/deep_learning/6381803a6189390d4388804c2ade0bc0c9ca8350/深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/__pycache__/eval.cpython-38.pyc -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/__pycache__/models.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenqiang0601/deep_learning/6381803a6189390d4388804c2ade0bc0c9ca8350/深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/__pycache__/models.cpython-37.pyc -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/__pycache__/models.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenqiang0601/deep_learning/6381803a6189390d4388804c2ade0bc0c9ca8350/深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/__pycache__/models.cpython-38.pyc -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/__pycache__/pix2tex.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenqiang0601/deep_learning/6381803a6189390d4388804c2ade0bc0c9ca8350/深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/__pycache__/pix2tex.cpython-37.pyc -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/__pycache__/pix2tex.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenqiang0601/deep_learning/6381803a6189390d4388804c2ade0bc0c9ca8350/深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/__pycache__/pix2tex.cpython-38.pyc -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | import dataset.arxiv 2 | import dataset.extract_latex 3 | import dataset.latex2png 4 | import dataset.render 5 | import dataset.scraping 6 | import dataset.dataset 7 | -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenqiang0601/deep_learning/6381803a6189390d4388804c2ade0bc0c9ca8350/深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/__pycache__/__init__.cpython-37.pyc 
-------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenqiang0601/deep_learning/6381803a6189390d4388804c2ade0bc0c9ca8350/深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/__pycache__/arxiv.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenqiang0601/deep_learning/6381803a6189390d4388804c2ade0bc0c9ca8350/深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/__pycache__/arxiv.cpython-37.pyc -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/__pycache__/arxiv.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenqiang0601/deep_learning/6381803a6189390d4388804c2ade0bc0c9ca8350/深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/__pycache__/arxiv.cpython-38.pyc -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/__pycache__/dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenqiang0601/deep_learning/6381803a6189390d4388804c2ade0bc0c9ca8350/深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/__pycache__/dataset.cpython-37.pyc -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/__pycache__/dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenqiang0601/deep_learning/6381803a6189390d4388804c2ade0bc0c9ca8350/深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/__pycache__/dataset.cpython-38.pyc -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/__pycache__/demacro.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenqiang0601/deep_learning/6381803a6189390d4388804c2ade0bc0c9ca8350/深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/__pycache__/demacro.cpython-37.pyc -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/__pycache__/demacro.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenqiang0601/deep_learning/6381803a6189390d4388804c2ade0bc0c9ca8350/深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/__pycache__/demacro.cpython-38.pyc -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/__pycache__/extract_latex.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenqiang0601/deep_learning/6381803a6189390d4388804c2ade0bc0c9ca8350/深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/__pycache__/extract_latex.cpython-37.pyc -------------------------------------------------------------------------------- 
/深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/__pycache__/extract_latex.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenqiang0601/deep_learning/6381803a6189390d4388804c2ade0bc0c9ca8350/深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/__pycache__/extract_latex.cpython-38.pyc -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/__pycache__/latex2png.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenqiang0601/deep_learning/6381803a6189390d4388804c2ade0bc0c9ca8350/深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/__pycache__/latex2png.cpython-37.pyc -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/__pycache__/latex2png.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenqiang0601/deep_learning/6381803a6189390d4388804c2ade0bc0c9ca8350/深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/__pycache__/latex2png.cpython-38.pyc -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/__pycache__/render.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenqiang0601/deep_learning/6381803a6189390d4388804c2ade0bc0c9ca8350/深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/__pycache__/render.cpython-37.pyc -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/__pycache__/render.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenqiang0601/deep_learning/6381803a6189390d4388804c2ade0bc0c9ca8350/深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/__pycache__/render.cpython-38.pyc -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/__pycache__/scraping.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenqiang0601/deep_learning/6381803a6189390d4388804c2ade0bc0c9ca8350/深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/__pycache__/scraping.cpython-37.pyc -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/__pycache__/scraping.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenqiang0601/deep_learning/6381803a6189390d4388804c2ade0bc0c9ca8350/深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/__pycache__/scraping.cpython-38.pyc -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/arxiv.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/soskek/arxiv_leaks 2 | 3 | import argparse 4 | import json 5 | import os 6 | import glob 7 | import re 8 | import sys 9 | import argparse 10 | import logging 11 | import shutil 12 | import subprocess 13 | import tarfile 14 | import tempfile 15 | import chardet 16 | import logging 17 | import requests 18 | import urllib.request 19 | from 
urllib.error import HTTPError 20 | try: 21 | from extract_latex import * 22 | from scraping import * 23 | from demacro import * 24 | except: 25 | from dataset.extract_latex import * 26 | from dataset.scraping import * 27 | from dataset.demacro import * 28 | 29 | # logging.getLogger().setLevel(logging.INFO) 30 | arxiv_id = re.compile(r'(? File written: {0}'.format(path)) 93 | 94 | 95 | if __name__ == '__main__': 96 | main() 97 | -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/extract_latex.py: -------------------------------------------------------------------------------- 1 | import re 2 | import numpy as np 3 | 4 | MIN_CHARS = 20 5 | MAX_CHARS = 3000 6 | dollar = re.compile(r'((? 0 and s[i-1] == '\\': # not perfect 21 | continue 22 | else: 23 | a.append(1) 24 | if i == 0: 25 | surrounding = True 26 | elif c == '}': 27 | if i > 0 and s[i-1] == '\\': 28 | continue 29 | else: 30 | a.append(-1) 31 | b = np.cumsum(a) 32 | if len(b) > 1 and b[-1] != 0: 33 | raise ValueError(s) 34 | surrounding = s[-1] == '}' and surrounding 35 | if not surrounding: 36 | return s 37 | elif (b == 0).sum() == 1: 38 | return s[1:-1] 39 | else: 40 | return s 41 | 42 | 43 | def clean_matches(matches, min_chars=MIN_CHARS): 44 | template = r'\\%s\s?\{(.*?)\}' 45 | sub = [re.compile(template % s) for s in ['ref', 'cite', 'label', 'caption']] 46 | faulty = [] 47 | for i in range(len(matches)): 48 | if 'tikz' in matches[i]: # do not support tikz at the moment 49 | faulty.append(i) 50 | continue 51 | for s in sub: 52 | matches[i] = re.sub(s, '', matches[i]) 53 | matches[i] = matches[i].replace('\n', '').replace(r'\notag', '').replace(r'\nonumber', '') 54 | matches[i] = re.sub(outer_whitespace, '', matches[i]) 55 | if len(matches[i]) < min_chars: 56 | faulty.append(i) 57 | continue 58 | # try: 59 | # matches[i] = check_brackets(matches[i]) 60 | # except ValueError: 61 | # faulty.append(i) 62 | if matches[i][-1] == '\\' or 'newcommand' in matches[i][-1]: 63 | faulty.append(i) 64 | 65 | matches = [m.strip() for i, m in enumerate(matches) if i not in faulty] 66 | return list(set(matches)) 67 | 68 | 69 | def find_math(s, wiki=False): 70 | matches = [] 71 | x = re.findall(inline, s) 72 | matches.extend([(g[1] if g[1] != '' else g[-1]) for g in x]) 73 | if not wiki: 74 | patterns = [dollar, equation, align] 75 | groups = [1, 1, 0] 76 | else: 77 | patterns = [displaymath] 78 | groups = [1] 79 | for i, pattern in zip(groups, patterns): 80 | x = re.findall(pattern, s) 81 | matches.extend([g[i] for g in x]) 82 | 83 | return clean_matches(matches) 84 | -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/latex2png.py: -------------------------------------------------------------------------------- 1 | # mostly taken from http://code.google.com/p/latexmath2png/ 2 | # install preview.sty 3 | import os 4 | import re 5 | import sys 6 | import io 7 | import glob 8 | import tempfile 9 | import subprocess 10 | from PIL import Image 11 | 12 | 13 | class Latex: 14 | BASE = r''' 15 | \documentclass[varwidth]{standalone} 16 | \usepackage{fontspec,unicode-math} 17 | \usepackage[active,tightpage,displaymath,textmath]{preview} 18 | \setmathfont{%s} 19 | \begin{document} 20 | \thispagestyle{empty} 21 | %s 22 | \end{document} 23 | ''' 24 | 25 | def __init__(self, math, dpi=250, font='Latin Modern Math'): 26 | '''takes list of math code. 
`returns each element as PNG with DPI=`dpi`''' 27 | self.math = math 28 | self.dpi = dpi 29 | self.font = font 30 | 31 | def write(self, return_bytes=False): 32 | # inline = bool(re.match('^\$[^$]*\$$', self.math)) and False 33 | try: 34 | workdir = tempfile.gettempdir() 35 | fd, texfile = tempfile.mkstemp('.tex', 'eq', workdir, True) 36 | # print(self.BASE % (self.font, self.math)) 37 | with os.fdopen(fd, 'w+') as f: 38 | document = self.BASE % (self.font, '\n'.join(self.math)) 39 | # print(document) 40 | f.write(document) 41 | 42 | png = self.convert_file(texfile, workdir, return_bytes=return_bytes) 43 | return png 44 | 45 | finally: 46 | if os.path.exists(texfile): 47 | try: 48 | os.remove(texfile) 49 | except PermissionError: 50 | pass 51 | 52 | def convert_file(self, infile, workdir, return_bytes=False): 53 | 54 | try: 55 | # Generate the PDF file 56 | cmd = 'xelatex -halt-on-error -output-directory %s %s' % (workdir, infile) 57 | 58 | p = subprocess.Popen( 59 | cmd, 60 | shell=True, 61 | stdin=subprocess.PIPE, 62 | stdout=subprocess.PIPE, 63 | stderr=subprocess.PIPE, 64 | ) 65 | sout, serr = p.communicate() 66 | # Something bad happened, abort 67 | if p.returncode != 0: 68 | raise Exception('latex error', serr, sout) 69 | 70 | # Convert the PDF file to PNG's 71 | pdffile = infile.replace('.tex', '.pdf') 72 | pngfile = os.path.join(workdir, infile.replace('.tex', '.png')) 73 | 74 | cmd = 'magick convert -density %i -colorspace gray %s -quality 90 %s' % ( 75 | self.dpi, 76 | pdffile, 77 | pngfile, 78 | ) # -bg Transparent -z 9 79 | p = subprocess.Popen( 80 | cmd, 81 | shell=True, 82 | stdin=subprocess.PIPE, 83 | stdout=subprocess.PIPE, 84 | stderr=subprocess.PIPE, 85 | ) 86 | 87 | sout, serr = p.communicate() 88 | if p.returncode != 0: 89 | raise Exception('PDFpng error', serr, cmd, os.path.exists(pdffile), os.path.exists(infile)) 90 | if return_bytes: 91 | if len(self.math) > 1: 92 | png = [open(pngfile.replace('.png', '')+'-%i.png' % i, 'rb').read() for i in range(len(self.math))] 93 | else: 94 | png = [open(pngfile.replace('.png', '')+'.png', 'rb').read()] 95 | return png 96 | else: 97 | if len(self.math) > 1: 98 | return [(pngfile.replace('.png', '')+'-%i.png' % i) for i in range(len(self.math))] 99 | else: 100 | return (pngfile.replace('.png', '')+'.png') 101 | finally: 102 | # Cleanup temporaries 103 | basefile = infile.replace('.tex', '') 104 | tempext = ['.aux', '.pdf', '.log'] 105 | if return_bytes: 106 | ims = glob.glob(basefile+'*.png') 107 | for im in ims: 108 | os.remove(im) 109 | for te in tempext: 110 | tempfile = basefile + te 111 | if os.path.exists(tempfile): 112 | os.remove(tempfile) 113 | 114 | 115 | __cache = {} 116 | 117 | 118 | def tex2png(eq, **kwargs): 119 | if not eq in __cache: 120 | __cache[eq] = Latex(eq, **kwargs).write(return_bytes=True) 121 | return __cache[eq] 122 | 123 | 124 | def tex2pil(tex, **kwargs): 125 | pngs = Latex(tex, **kwargs).write(return_bytes=True) 126 | images = [Image.open(io.BytesIO(d)) for d in pngs] 127 | return images 128 | 129 | 130 | if __name__ == '__main__': 131 | if len(sys.argv) > 1: 132 | src = sys.argv[1] 133 | else: 134 | src = r'\begin{equation}\mathcal{ L}\nonumber\end{equation}' 135 | 136 | print('Equation is: %s' % src) 137 | print(Latex(src).write()) 138 | -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/postprocess.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from tqdm.auto 
import tqdm 3 | 4 | if __name__ == '__main__': 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument('-i', '--input', required=True, help='input file') 7 | parser.add_argument('-o', '--output', default=None, help='output file') 8 | args = parser.parse_args() 9 | 10 | d = open(args.input, 'r').read().split('\n') 11 | reqs = ['\\', '_', '^', '(', ')', '{', '}'] 12 | deleted = 0 13 | for i in tqdm(reversed(range(len(d))), total=len(d)): 14 | if not any([r in d[i] for r in reqs]): 15 | del d[i] 16 | deleted += 1 17 | print('removed %i lines' % deleted) 18 | f = args.output 19 | if f is None: 20 | f = args.input 21 | open(f, 'w').write('\n'.join(d)) 22 | -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/preprocessing/generate_latex_vocab.py: -------------------------------------------------------------------------------- 1 | import sys, logging, argparse, os 2 | 3 | def process_args(args): 4 | parser = argparse.ArgumentParser(description='Generate vocabulary file.') 5 | 6 | parser.add_argument('--data-path', dest='data_path', 7 | type=str, required=True, 8 | help=('Input file containing per line. This should be the file used for training.' 9 | )) 10 | parser.add_argument('--label-path', dest='label_path', 11 | type=str, required=True, 12 | help=('Input file containing a tokenized formula per line.' 13 | )) 14 | parser.add_argument('--output-file', dest='output_file', 15 | type=str, required=True, 16 | help=('Output file for putting vocabulary.' 17 | )) 18 | parser.add_argument('--unk-threshold', dest='unk_threshold', 19 | type=int, default=1, 20 | help=('If the number of occurences of a token is less than (including) the threshold, then it will be excluded from the generated vocabulary.' 
21 | )) 22 | parser.add_argument('--log-path', dest="log_path", 23 | type=str, default='log.txt', 24 | help=('Log file path, default=log.txt' 25 | )) 26 | parameters = parser.parse_args(args) 27 | return parameters 28 | 29 | def main(args): 30 | parameters = process_args(args) 31 | logging.basicConfig( 32 | level=logging.INFO, 33 | format='%(asctime)-15s %(name)-5s %(levelname)-8s %(message)s', 34 | filename=parameters.log_path) 35 | 36 | console = logging.StreamHandler() 37 | console.setLevel(logging.INFO) 38 | formatter = logging.Formatter('%(asctime)-15s %(name)-5s %(levelname)-8s %(message)s') 39 | console.setFormatter(formatter) 40 | logging.getLogger('').addHandler(console) 41 | 42 | logging.info('Script being executed: %s'%__file__) 43 | 44 | label_path = parameters.label_path 45 | assert os.path.exists(label_path), label_path 46 | data_path = parameters.data_path 47 | assert os.path.exists(data_path), data_path 48 | 49 | formulas = open(label_path).readlines() 50 | vocab = {} 51 | max_len = 0 52 | with open(data_path) as fin: 53 | for line in fin: 54 | _, line_idx = line.strip().split() 55 | line_strip = formulas[int(line_idx)].strip() 56 | tokens = line_strip.split() 57 | tokens_out = [] 58 | for token in tokens: 59 | tokens_out.append(token) 60 | if token not in vocab: 61 | vocab[token] = 0 62 | vocab[token] += 1 63 | 64 | vocab_sort = sorted(list(vocab.keys())) 65 | vocab_out = [] 66 | num_unknown = 0 67 | for word in vocab_sort: 68 | if vocab[word] > parameters.unk_threshold: 69 | vocab_out.append(word) 70 | else: 71 | num_unknown += 1 72 | #vocab = ["'"+word.replace('\\','\\\\').replace('\'', '\\\'')+"'" for word in vocab_out] 73 | vocab = [word for word in vocab_out] 74 | 75 | with open(parameters.output_file, 'w') as fout: 76 | fout.write('\n'.join(vocab)) 77 | logging.info('#UNK\'s: %d'%num_unknown) 78 | 79 | if __name__ == '__main__': 80 | main(sys.argv[1:]) 81 | logging.info('Jobs finished') 82 | -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/preprocessing/preprocess_formulas.py: -------------------------------------------------------------------------------- 1 | # taken and modified from https://github.com/harvardnlp/im2markup 2 | # tokenize latex formulas 3 | import sys 4 | import os 5 | import re 6 | import argparse 7 | import logging 8 | import subprocess 9 | import shutil 10 | 11 | 12 | def process_args(args): 13 | parser = argparse.ArgumentParser(description='Preprocess (tokenize or normalize) latex formulas') 14 | 15 | parser.add_argument('--mode', '-m', dest='mode', 16 | choices=['tokenize', 'normalize'], default='normalize', 17 | help=('Tokenize (split to tokens seperated by space) or normalize (further translate to an equivalent standard form).' 18 | )) 19 | parser.add_argument('--input-file', '-i', dest='input_file', 20 | type=str, required=True, 21 | help=('Input file containing latex formulas. One formula per line.' 22 | )) 23 | parser.add_argument('--output-file', '-o', dest='output_file', 24 | type=str, required=True, 25 | help=('Output file.' 
26 | )) 27 | parser.add_argument('-n', '--num-threads', dest='num_threads', 28 | type=int, default=4, 29 | help=('Number of threads, default=4.')) 30 | parser.add_argument('--log-path', dest="log_path", 31 | type=str, default=None, 32 | help=('Log file path, default=log.txt')) 33 | parameters = parser.parse_args(args) 34 | return parameters 35 | 36 | 37 | def main(args): 38 | parameters = process_args(args) 39 | logging.basicConfig( 40 | level=logging.INFO, 41 | format='%(asctime)-15s %(name)-5s %(levelname)-8s %(message)s', 42 | filename=parameters.log_path) 43 | 44 | console = logging.StreamHandler() 45 | console.setLevel(logging.INFO) 46 | formatter = logging.Formatter('%(asctime)-15s %(name)-5s %(levelname)-8s %(message)s') 47 | console.setFormatter(formatter) 48 | logging.getLogger('').addHandler(console) 49 | 50 | logging.info('Script being executed: %s' % __file__) 51 | 52 | input_file = parameters.input_file 53 | output_file = parameters.output_file 54 | 55 | assert os.path.exists(input_file), input_file 56 | shutil.copy(input_file, output_file) 57 | operators = '\s?'.join('|'.join(['arccos', 'arcsin', 'arctan', 'arg', 'cos', 'cosh', 'cot', 'coth', 'csc', 'deg', 'det', 'dim', 'exp', 'gcd', 'hom', 'inf', 58 | 'injlim', 'ker', 'lg', 'lim', 'liminf', 'limsup', 'ln', 'log', 'max', 'min', 'Pr', 'projlim', 'sec', 'sin', 'sinh', 'sup', 'tan', 'tanh'])) 59 | ops = re.compile(r'\\operatorname {(%s)}' % operators) 60 | temp_file = output_file + '.tmp' 61 | with open(temp_file, 'w') as fout: 62 | prepre = open(output_file, 'r').read().replace('\r', ' ') # delete \r 63 | # replace split, align with aligned 64 | prepre = re.sub(r'\\begin{(split|align|alignedat|alignat|eqnarray)\*?}(.+?)\\end{\1\*?}', r'\\begin{aligned}\2\\end{aligned}', prepre, flags=re.S) 65 | prepre = re.sub(r'\\begin{(smallmatrix)\*?}(.+?)\\end{\1\*?}', r'\\begin{matrix}\2\\end{matrix}', prepre, flags=re.S) 66 | fout.write(prepre) 67 | 68 | # print(os.path.abspath(__file__)) 69 | cmd = r"cat %s | node %s %s > %s " % (temp_file, os.path.join(os.path.dirname(__file__), 'preprocess_latex.js'), parameters.mode, output_file) 70 | ret = subprocess.call(cmd, shell=True) 71 | os.remove(temp_file) 72 | if ret != 0: 73 | logging.error('FAILED: %s' % cmd) 74 | temp_file = output_file + '.tmp' 75 | shutil.move(output_file, temp_file) 76 | with open(temp_file, 'r') as fin: 77 | with open(output_file, 'w') as fout: 78 | for line in fin: 79 | tokens = line.strip().split() 80 | tokens_out = [] 81 | for token in tokens: 82 | tokens_out.append(token) 83 | if len(tokens_out) > 5: 84 | post = ' '.join(tokens_out) 85 | # use \sin instead of \operatorname{sin} 86 | names = ['\\'+x.replace(' ', '') for x in re.findall(ops, post)] 87 | post = re.sub(ops, lambda match: str(names.pop(0)), post).replace(r'\\ \end{array}', r'\end{array}') 88 | fout.write(post+'\n') 89 | os.remove(temp_file) 90 | 91 | 92 | if __name__ == '__main__': 93 | main(sys.argv[1:]) 94 | logging.info('Jobs finished') 95 | -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/preprocessing/third_party/README.md: -------------------------------------------------------------------------------- 1 | Directly taken from https://github.com/harvardnlp/im2markup 2 | -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/preprocessing/third_party/katex/LICENSE.txt: 
-------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2015 Khan Academy 4 | 5 | This software also uses portions of the underscore.js project, which is 6 | MIT licensed with the following copyright: 7 | 8 | Copyright (c) 2009-2015 Jeremy Ashkenas, DocumentCloud and Investigative 9 | Reporters & Editors 10 | 11 | Permission is hereby granted, free of charge, to any person obtaining a copy 12 | of this software and associated documentation files (the "Software"), to deal 13 | in the Software without restriction, including without limitation the rights 14 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 15 | copies of the Software, and to permit persons to whom the Software is 16 | furnished to do so, subject to the following conditions: 17 | 18 | The above copyright notice and this permission notice shall be included in all 19 | copies or substantial portions of the Software. 20 | 21 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 22 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 23 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 24 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 25 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 26 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 27 | SOFTWARE. 28 | -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/preprocessing/third_party/katex/README.md: -------------------------------------------------------------------------------- 1 | # [KaTeX](https://khan.github.io/KaTeX/) [![Build Status](https://travis-ci.org/Khan/KaTeX.svg?branch=master)](https://travis-ci.org/Khan/KaTeX) 2 | 3 | [![Join the chat at https://gitter.im/Khan/KaTeX](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/Khan/KaTeX?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) 4 | 5 | KaTeX is a fast, easy-to-use JavaScript library for TeX math rendering on the web. 6 | 7 | * **Fast:** KaTeX renders its math synchronously and doesn't need to reflow the page. See how it compares to a competitor in [this speed test](http://jsperf.com/katex-vs-mathjax/). 8 | * **Print quality:** KaTeX’s layout is based on Donald Knuth’s TeX, the gold standard for math typesetting. 9 | * **Self contained:** KaTeX has no dependencies and can easily be bundled with your website resources. 10 | * **Server side rendering:** KaTeX produces the same output regardless of browser or environment, so you can pre-render expressions using Node.js and send them as plain HTML. 11 | 12 | KaTeX supports all major browsers, including Chrome, Safari, Firefox, Opera, and IE 8 - IE 11. A list of supported commands can be on the [wiki](https://github.com/Khan/KaTeX/wiki/Function-Support-in-KaTeX). 13 | 14 | ## Usage 15 | 16 | You can [download KaTeX](https://github.com/khan/katex/releases) and host it on your server or include the `katex.min.js` and `katex.min.css` files on your page directly from a CDN: 17 | 18 | ```html 19 | 20 | 21 | ``` 22 | 23 | #### In-browser rendering 24 | 25 | Call `katex.render` with a TeX expression and a DOM element to render into: 26 | 27 | ```js 28 | katex.render("c = \\pm\\sqrt{a^2 + b^2}", element); 29 | ``` 30 | 31 | If KaTeX can't parse the expression, it throws a `katex.ParseError` error. 
32 | 33 | #### Server side rendering or rendering to a string 34 | 35 | To generate HTML on the server or to generate an HTML string of the rendered math, you can use `katex.renderToString`: 36 | 37 | ```js 38 | var html = katex.renderToString("c = \\pm\\sqrt{a^2 + b^2}"); 39 | // '...' 40 | ``` 41 | 42 | Make sure to include the CSS and font files, but there is no need to include the JavaScript. Like `render`, `renderToString` throws if it can't parse the expression. 43 | 44 | #### Rendering options 45 | 46 | You can provide an object of options as the last argument to `katex.render` and `katex.renderToString`. Available options are: 47 | 48 | - `displayMode`: `boolean`. If `true` the math will be rendered in display mode, which will put the math in display style (so `\int` and `\sum` are large, for example), and will center the math on the page on its own line. If `false` the math will be rendered in inline mode. (default: `false`) 49 | - `throwOnError`: `boolean`. If `true`, KaTeX will throw a `ParseError` when it encounters an unsupported command. If `false`, KaTeX will render the unsupported command as text in the color given by `errorColor`. (default: `true`) 50 | - `errorColor`: `string`. A color string given in the format `"#XXX"` or `"#XXXXXX"`. This option determines the color which unsupported commands are rendered in. (default: `#cc0000`) 51 | 52 | For example: 53 | 54 | ```js 55 | katex.render("c = \\pm\\sqrt{a^2 + b^2}", element, { displayMode: true }); 56 | ``` 57 | 58 | #### Automatic rendering of math on a page 59 | 60 | Math on the page can be automatically rendered using the auto-render extension. See [the Auto-render README](contrib/auto-render/README.md) for more information. 61 | 62 | ## Contributing 63 | 64 | See [CONTRIBUTING.md](CONTRIBUTING.md) 65 | 66 | ## License 67 | 68 | KaTeX is licensed under the [MIT License](http://opensource.org/licenses/MIT). 69 | -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/preprocessing/third_party/katex/cli.js: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env node 2 | // Simple CLI for KaTeX. 3 | // Reads TeX from stdin, outputs HTML to stdout. 4 | /* eslint no-console:0 */ 5 | 6 | var katex = require("./"); 7 | var input = ""; 8 | 9 | // Skip the first two args, which are just "node" and "cli.js" 10 | var args = process.argv.slice(2); 11 | 12 | if (args.indexOf("--help") !== -1) { 13 | console.log(process.argv[0] + " " + process.argv[1] + 14 | " [ --help ]" + 15 | " [ --display-mode ]"); 16 | 17 | console.log("\n" + 18 | "Options:"); 19 | console.log(" --help Display this help message"); 20 | console.log(" --display-mode Render in display mode (not inline mode)"); 21 | process.exit(); 22 | } 23 | 24 | process.stdin.on("data", function(chunk) { 25 | input += chunk.toString(); 26 | }); 27 | 28 | process.stdin.on("end", function() { 29 | var options = { displayMode: args.indexOf("--display-mode") !== -1 }; 30 | var output = katex.renderToString(input, options); 31 | console.log(output); 32 | }); 33 | -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/preprocessing/third_party/katex/katex.js: -------------------------------------------------------------------------------- 1 | /* eslint no-console:0 */ 2 | /** 3 | * This is the main entry point for KaTeX. 
Here, we expose functions for 4 | * rendering expressions either to DOM nodes or to markup strings. 5 | * 6 | * We also expose the ParseError class to check if errors thrown from KaTeX are 7 | * errors in the expression, or errors in javascript handling. 8 | */ 9 | 10 | var ParseError = require("./src/ParseError"); 11 | var Settings = require("./src/Settings"); 12 | 13 | var buildTree = require("./src/buildTree"); 14 | var parseTree = require("./src/parseTree"); 15 | var utils = require("./src/utils"); 16 | 17 | /** 18 | * Parse and build an expression, and place that expression in the DOM node 19 | * given. 20 | */ 21 | var render = function(expression, baseNode, options) { 22 | utils.clearNode(baseNode); 23 | 24 | var settings = new Settings(options); 25 | 26 | var tree = parseTree(expression, settings); 27 | var node = buildTree(tree, expression, settings).toNode(); 28 | 29 | baseNode.appendChild(node); 30 | }; 31 | 32 | // KaTeX's styles don't work properly in quirks mode. Print out an error, and 33 | // disable rendering. 34 | if (typeof document !== "undefined") { 35 | if (document.compatMode !== "CSS1Compat") { 36 | typeof console !== "undefined" && console.warn( 37 | "Warning: KaTeX doesn't work in quirks mode. Make sure your " + 38 | "website has a suitable doctype."); 39 | 40 | render = function() { 41 | throw new ParseError("KaTeX doesn't work in quirks mode."); 42 | }; 43 | } 44 | } 45 | 46 | /** 47 | * Parse and build an expression, and return the markup for that. 48 | */ 49 | var renderToString = function(expression, options) { 50 | var settings = new Settings(options); 51 | 52 | var tree = parseTree(expression, settings); 53 | return buildTree(tree, expression, settings).toMarkup(); 54 | }; 55 | 56 | /** 57 | * Parse an expression and return the parse tree. 58 | */ 59 | var generateParseTree = function(expression, options) { 60 | var settings = new Settings(options); 61 | return parseTree(expression, settings); 62 | }; 63 | 64 | module.exports = { 65 | render: render, 66 | renderToString: renderToString, 67 | /** 68 | * NOTE: This method is not currently recommended for public use. 69 | * The internal tree representation is unstable and is very likely 70 | * to change. Use at your own risk. 
71 | */ 72 | __parse: generateParseTree, 73 | ParseError: ParseError, 74 | }; 75 | -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/preprocessing/third_party/katex/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "_args": [ 3 | [ 4 | "katex", 5 | "/home/srush/Projects/im2latex" 6 | ] 7 | ], 8 | "_from": "katex@latest", 9 | "_id": "katex@0.6.0", 10 | "_inCache": true, 11 | "_installable": true, 12 | "_location": "/katex", 13 | "_nodeVersion": "4.2.1", 14 | "_npmOperationalInternal": { 15 | "host": "packages-12-west.internal.npmjs.com", 16 | "tmp": "tmp/katex-0.6.0.tgz_1460769444991_0.38667152682319283" 17 | }, 18 | "_npmUser": { 19 | "email": "kevinb7@gmail.com", 20 | "name": "kevinbarabash" 21 | }, 22 | "_npmVersion": "2.15.2", 23 | "_phantomChildren": {}, 24 | "_requested": { 25 | "name": "katex", 26 | "raw": "katex", 27 | "rawSpec": "", 28 | "scope": null, 29 | "spec": "latest", 30 | "type": "tag" 31 | }, 32 | "_requiredBy": [ 33 | "#USER" 34 | ], 35 | "_resolved": "https://registry.npmjs.org/katex/-/katex-0.6.0.tgz", 36 | "_shasum": "12418e09121c05c92041b6b3b9fb6bab213cb6f3", 37 | "_shrinkwrap": null, 38 | "_spec": "katex", 39 | "_where": "/home/srush/Projects/im2latex", 40 | "bin": { 41 | "katex": "cli.js" 42 | }, 43 | "bugs": { 44 | "url": "https://github.com/Khan/KaTeX/issues" 45 | }, 46 | "dependencies": { 47 | "match-at": "^0.1.0" 48 | }, 49 | "description": "Fast math typesetting for the web.", 50 | "devDependencies": { 51 | "browserify": "^10.2.4", 52 | "clean-css": "~2.2.15", 53 | "eslint": "^1.10.2", 54 | "express": "~3.3.3", 55 | "glob": "^5.0.15", 56 | "jasmine": "^2.3.2", 57 | "jasmine-core": "^2.3.4", 58 | "js-yaml": "^3.3.1", 59 | "jspngopt": "^0.1.0", 60 | "less": "~1.7.5", 61 | "nomnom": "^1.8.1", 62 | "pako": "0.2.7", 63 | "selenium-webdriver": "^2.46.1", 64 | "uglify-js": "~2.4.15" 65 | }, 66 | "directories": {}, 67 | "dist": { 68 | "shasum": "12418e09121c05c92041b6b3b9fb6bab213cb6f3", 69 | "tarball": "https://registry.npmjs.org/katex/-/katex-0.6.0.tgz" 70 | }, 71 | "files": [ 72 | "cli.js", 73 | "dist/", 74 | "katex.js", 75 | "src/" 76 | ], 77 | "gitHead": "b94fc6534d5c23f944906a52a592bee4e0090665", 78 | "homepage": "https://github.com/Khan/KaTeX#readme", 79 | "license": "MIT", 80 | "main": "katex.js", 81 | "maintainers": [ 82 | { 83 | "name": "kevinbarabash", 84 | "email": "kevinb7@gmail.com" 85 | }, 86 | { 87 | "name": "spicyj", 88 | "email": "ben@benalpert.com" 89 | }, 90 | { 91 | "name": "xymostech", 92 | "email": "xymostech@gmail.com" 93 | } 94 | ], 95 | "name": "katex", 96 | "optionalDependencies": {}, 97 | "readme": "ERROR: No README data found!", 98 | "repository": { 99 | "type": "git", 100 | "url": "git://github.com/Khan/KaTeX.git" 101 | }, 102 | "scripts": { 103 | "prepublish": "make dist", 104 | "start": "node server.js", 105 | "test": "make lint test" 106 | }, 107 | "version": "0.6.0" 108 | } 109 | -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/preprocessing/third_party/katex/src/Lexer.js: -------------------------------------------------------------------------------- 1 | /** 2 | * The Lexer class handles tokenizing the input in various ways. Since our 3 | * parser expects us to be able to backtrack, the lexer allows lexing from any 4 | * given starting point. 
5 | * 6 | * Its main exposed function is the `lex` function, which takes a position to 7 | * lex from and a type of token to lex. It defers to the appropriate `_innerLex` 8 | * function. 9 | * 10 | * The various `_innerLex` functions perform the actual lexing of different 11 | * kinds. 12 | */ 13 | 14 | var matchAt = require("../../match-at"); 15 | 16 | var ParseError = require("./ParseError"); 17 | 18 | // The main lexer class 19 | function Lexer(input) { 20 | this._input = input; 21 | } 22 | 23 | // The resulting token returned from `lex`. 24 | function Token(text, data, position) { 25 | this.text = text; 26 | this.data = data; 27 | this.position = position; 28 | } 29 | 30 | /* The following tokenRegex 31 | * - matches typical whitespace (but not NBSP etc.) using its first group 32 | * - matches symbol combinations which result in a single output character 33 | * - does not match any control character \x00-\x1f except whitespace 34 | * - does not match a bare backslash 35 | * - matches any ASCII character except those just mentioned 36 | * - does not match the BMP private use area \uE000-\uF8FF 37 | * - does not match bare surrogate code units 38 | * - matches any BMP character except for those just described 39 | * - matches any valid Unicode surrogate pair 40 | * - matches a backslash followed by one or more letters 41 | * - matches a backslash followed by any BMP character, including newline 42 | * Just because the Lexer matches something doesn't mean it's valid input: 43 | * If there is no matching function or symbol definition, the Parser will 44 | * still reject the input. 45 | */ 46 | var tokenRegex = new RegExp( 47 | "([ \r\n\t]+)|(" + // whitespace 48 | "---?" + // special combinations 49 | "|[!-\\[\\]-\u2027\u202A-\uD7FF\uF900-\uFFFF]" + // single codepoint 50 | "|[\uD800-\uDBFF][\uDC00-\uDFFF]" + // surrogate pair 51 | "|\\\\(?:[a-zA-Z]+|[^\uD800-\uDFFF])" + // function name 52 | ")" 53 | ); 54 | 55 | var whitespaceRegex = /\s*/; 56 | 57 | /** 58 | * This function lexes a single normal token. It takes a position and 59 | * whether it should completely ignore whitespace or not. 60 | */ 61 | Lexer.prototype._innerLex = function(pos, ignoreWhitespace) { 62 | var input = this._input; 63 | if (pos === input.length) { 64 | return new Token("EOF", null, pos); 65 | } 66 | var match = matchAt(tokenRegex, input, pos); 67 | if (match === null) { 68 | throw new ParseError( 69 | "Unexpected character: '" + input[pos] + "'", 70 | this, pos); 71 | } else if (match[2]) { // matched non-whitespace 72 | return new Token(match[2], null, pos + match[2].length); 73 | } else if (ignoreWhitespace) { 74 | return this._innerLex(pos + match[1].length, true); 75 | } else { // concatenate whitespace to a single space 76 | return new Token(" ", null, pos + match[1].length); 77 | } 78 | }; 79 | 80 | // A regex to match a CSS color (like #ffffff or BlueViolet) 81 | var cssColor = /#[a-z0-9]+|[a-z]+/i; 82 | 83 | /** 84 | * This function lexes a CSS color. 85 | */ 86 | Lexer.prototype._innerLexColor = function(pos) { 87 | var input = this._input; 88 | 89 | // Ignore whitespace 90 | var whitespace = matchAt(whitespaceRegex, input, pos)[0]; 91 | pos += whitespace.length; 92 | 93 | var match; 94 | if ((match = matchAt(cssColor, input, pos))) { 95 | // If we look like a color, return a color 96 | return new Token(match[0], null, pos + match[0].length); 97 | } else { 98 | throw new ParseError("Invalid color", this, pos); 99 | } 100 | }; 101 | 102 | // A regex to match a dimension. 
Dimensions look like 103 | // "1.2em" or ".4pt" or "1 ex" 104 | var sizeRegex = /(-?)\s*(\d+(?:\.\d*)?|\.\d+)\s*([a-z]{2})/; 105 | 106 | /** 107 | * This function lexes a dimension. 108 | */ 109 | Lexer.prototype._innerLexSize = function(pos) { 110 | var input = this._input; 111 | 112 | // Ignore whitespace 113 | var whitespace = matchAt(whitespaceRegex, input, pos)[0]; 114 | pos += whitespace.length; 115 | 116 | var match; 117 | if ((match = matchAt(sizeRegex, input, pos))) { 118 | var unit = match[3]; 119 | // We only currently handle "em" and "ex" units 120 | // if (unit !== "em" && unit !== "ex") { 121 | // throw new ParseError("Invalid unit: '" + unit + "'", this, pos); 122 | // } 123 | return new Token(match[0], { 124 | number: +(match[1] + match[2]), 125 | unit: unit, 126 | }, pos + match[0].length); 127 | } 128 | 129 | throw new ParseError("Invalid size", this, pos); 130 | }; 131 | 132 | /** 133 | * This function lexes a string of whitespace. 134 | */ 135 | Lexer.prototype._innerLexWhitespace = function(pos) { 136 | var input = this._input; 137 | 138 | var whitespace = matchAt(whitespaceRegex, input, pos)[0]; 139 | pos += whitespace.length; 140 | 141 | return new Token(whitespace[0], null, pos); 142 | }; 143 | 144 | /** 145 | * This function lexes a single token starting at `pos` and of the given mode. 146 | * Based on the mode, we defer to one of the `_innerLex` functions. 147 | */ 148 | Lexer.prototype.lex = function(pos, mode) { 149 | if (mode === "math") { 150 | return this._innerLex(pos, true); 151 | } else if (mode === "text") { 152 | return this._innerLex(pos, false); 153 | } else if (mode === "color") { 154 | return this._innerLexColor(pos); 155 | } else if (mode === "size") { 156 | return this._innerLexSize(pos); 157 | } else if (mode === "whitespace") { 158 | return this._innerLexWhitespace(pos); 159 | } 160 | }; 161 | 162 | module.exports = Lexer; 163 | -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/preprocessing/third_party/katex/src/Options.js: -------------------------------------------------------------------------------- 1 | /** 2 | * This file contains information about the options that the Parser carries 3 | * around with it while parsing. Data is held in an `Options` object, and when 4 | * recursing, a new `Options` object can be created with the `.with*` and 5 | * `.reset` functions. 6 | */ 7 | 8 | /** 9 | * This is the main options class. It contains the style, size, color, and font 10 | * of the current parse level. It also contains the style and size of the parent 11 | * parse level, so size changes can be handled efficiently. 12 | * 13 | * Each of the `.with*` and `.reset` functions passes its current style and size 14 | * as the parentStyle and parentSize of the new options class, so parent 15 | * handling is taken care of automatically. 16 | */ 17 | function Options(data) { 18 | this.style = data.style; 19 | this.color = data.color; 20 | this.size = data.size; 21 | this.phantom = data.phantom; 22 | this.font = data.font; 23 | 24 | if (data.parentStyle === undefined) { 25 | this.parentStyle = data.style; 26 | } else { 27 | this.parentStyle = data.parentStyle; 28 | } 29 | 30 | if (data.parentSize === undefined) { 31 | this.parentSize = data.size; 32 | } else { 33 | this.parentSize = data.parentSize; 34 | } 35 | } 36 | 37 | /** 38 | * Returns a new options object with the same properties as "this". Properties 39 | * from "extension" will be copied to the new options object. 
40 | */ 41 | Options.prototype.extend = function(extension) { 42 | var data = { 43 | style: this.style, 44 | size: this.size, 45 | color: this.color, 46 | parentStyle: this.style, 47 | parentSize: this.size, 48 | phantom: this.phantom, 49 | font: this.font, 50 | }; 51 | 52 | for (var key in extension) { 53 | if (extension.hasOwnProperty(key)) { 54 | data[key] = extension[key]; 55 | } 56 | } 57 | 58 | return new Options(data); 59 | }; 60 | 61 | /** 62 | * Create a new options object with the given style. 63 | */ 64 | Options.prototype.withStyle = function(style) { 65 | return this.extend({ 66 | style: style, 67 | }); 68 | }; 69 | 70 | /** 71 | * Create a new options object with the given size. 72 | */ 73 | Options.prototype.withSize = function(size) { 74 | return this.extend({ 75 | size: size, 76 | }); 77 | }; 78 | 79 | /** 80 | * Create a new options object with the given color. 81 | */ 82 | Options.prototype.withColor = function(color) { 83 | return this.extend({ 84 | color: color, 85 | }); 86 | }; 87 | 88 | /** 89 | * Create a new options object with "phantom" set to true. 90 | */ 91 | Options.prototype.withPhantom = function() { 92 | return this.extend({ 93 | phantom: true, 94 | }); 95 | }; 96 | 97 | /** 98 | * Create a new options objects with the give font. 99 | */ 100 | Options.prototype.withFont = function(font) { 101 | return this.extend({ 102 | font: font, 103 | }); 104 | }; 105 | 106 | /** 107 | * Create a new options object with the same style, size, and color. This is 108 | * used so that parent style and size changes are handled correctly. 109 | */ 110 | Options.prototype.reset = function() { 111 | return this.extend({}); 112 | }; 113 | 114 | /** 115 | * A map of color names to CSS colors. 116 | * TODO(emily): Remove this when we have real macros 117 | */ 118 | var colorMap = { 119 | "katex-blue": "#6495ed", 120 | "katex-orange": "#ffa500", 121 | "katex-pink": "#ff00af", 122 | "katex-red": "#df0030", 123 | "katex-green": "#28ae7b", 124 | "katex-gray": "gray", 125 | "katex-purple": "#9d38bd", 126 | "katex-blueA": "#c7e9f1", 127 | "katex-blueB": "#9cdceb", 128 | "katex-blueC": "#58c4dd", 129 | "katex-blueD": "#29abca", 130 | "katex-blueE": "#1c758a", 131 | "katex-tealA": "#acead7", 132 | "katex-tealB": "#76ddc0", 133 | "katex-tealC": "#5cd0b3", 134 | "katex-tealD": "#55c1a7", 135 | "katex-tealE": "#49a88f", 136 | "katex-greenA": "#c9e2ae", 137 | "katex-greenB": "#a6cf8c", 138 | "katex-greenC": "#83c167", 139 | "katex-greenD": "#77b05d", 140 | "katex-greenE": "#699c52", 141 | "katex-goldA": "#f7c797", 142 | "katex-goldB": "#f9b775", 143 | "katex-goldC": "#f0ac5f", 144 | "katex-goldD": "#e1a158", 145 | "katex-goldE": "#c78d46", 146 | "katex-redA": "#f7a1a3", 147 | "katex-redB": "#ff8080", 148 | "katex-redC": "#fc6255", 149 | "katex-redD": "#e65a4c", 150 | "katex-redE": "#cf5044", 151 | "katex-maroonA": "#ecabc1", 152 | "katex-maroonB": "#ec92ab", 153 | "katex-maroonC": "#c55f73", 154 | "katex-maroonD": "#a24d61", 155 | "katex-maroonE": "#94424f", 156 | "katex-purpleA": "#caa3e8", 157 | "katex-purpleB": "#b189c6", 158 | "katex-purpleC": "#9a72ac", 159 | "katex-purpleD": "#715582", 160 | "katex-purpleE": "#644172", 161 | "katex-mintA": "#f5f9e8", 162 | "katex-mintB": "#edf2df", 163 | "katex-mintC": "#e0e5cc", 164 | "katex-grayA": "#fdfdfd", 165 | "katex-grayB": "#f7f7f7", 166 | "katex-grayC": "#eeeeee", 167 | "katex-grayD": "#dddddd", 168 | "katex-grayE": "#cccccc", 169 | "katex-grayF": "#aaaaaa", 170 | "katex-grayG": "#999999", 171 | "katex-grayH": "#555555", 172 | "katex-grayI": 
"#333333", 173 | "katex-kaBlue": "#314453", 174 | "katex-kaGreen": "#639b24", 175 | }; 176 | 177 | /** 178 | * Gets the CSS color of the current options object, accounting for the 179 | * `colorMap`. 180 | */ 181 | Options.prototype.getColor = function() { 182 | if (this.phantom) { 183 | return "transparent"; 184 | } else { 185 | return colorMap[this.color] || this.color; 186 | } 187 | }; 188 | 189 | module.exports = Options; 190 | -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/preprocessing/third_party/katex/src/ParseError.js: -------------------------------------------------------------------------------- 1 | /** 2 | * This is the ParseError class, which is the main error thrown by KaTeX 3 | * functions when something has gone wrong. This is used to distinguish internal 4 | * errors from errors in the expression that the user provided. 5 | */ 6 | function ParseError(message, lexer, position) { 7 | var error = "KaTeX parse error: " + message; 8 | 9 | if (lexer !== undefined && position !== undefined) { 10 | // If we have the input and a position, make the error a bit fancier 11 | 12 | // Prepend some information 13 | error += " at position " + position + ": "; 14 | 15 | // Get the input 16 | var input = lexer._input; 17 | // Insert a combining underscore at the correct position 18 | input = input.slice(0, position) + "\u0332" + 19 | input.slice(position); 20 | 21 | // Extract some context from the input and add it to the error 22 | var begin = Math.max(0, position - 15); 23 | var end = position + 15; 24 | error += input.slice(begin, end); 25 | } 26 | 27 | // Some hackery to make ParseError a prototype of Error 28 | // See http://stackoverflow.com/a/8460753 29 | var self = new Error(error); 30 | self.name = "ParseError"; 31 | self.__proto__ = ParseError.prototype; 32 | 33 | self.position = position; 34 | return self; 35 | } 36 | 37 | // More hackery 38 | ParseError.prototype.__proto__ = Error.prototype; 39 | 40 | module.exports = ParseError; 41 | -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/preprocessing/third_party/katex/src/Settings.js: -------------------------------------------------------------------------------- 1 | /** 2 | * This is a module for storing settings passed into KaTeX. It correctly handles 3 | * default settings. 4 | */ 5 | 6 | /** 7 | * Helper function for getting a default value if the value is undefined 8 | */ 9 | function get(option, defaultValue) { 10 | return option === undefined ? defaultValue : option; 11 | } 12 | 13 | /** 14 | * The main Settings object 15 | * 16 | * The current options stored are: 17 | * - displayMode: Whether the expression should be typeset by default in 18 | * textstyle or displaystyle (default false) 19 | */ 20 | function Settings(options) { 21 | // allow null options 22 | options = options || {}; 23 | this.displayMode = get(options.displayMode, false); 24 | this.throwOnError = get(options.throwOnError, true); 25 | this.errorColor = get(options.errorColor, "#cc0000"); 26 | } 27 | 28 | module.exports = Settings; 29 | -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/preprocessing/third_party/katex/src/Style.js: -------------------------------------------------------------------------------- 1 | /** 2 | * This file contains information and classes for the various kinds of styles 3 | * used in TeX. 
It provides a generic `Style` class, which holds information 4 | * about a specific style. It then provides instances of all the different kinds 5 | * of styles possible, and provides functions to move between them and get 6 | * information about them. 7 | */ 8 | 9 | /** 10 | * The main style class. Contains a unique id for the style, a size (which is 11 | * the same for cramped and uncramped version of a style), a cramped flag, and a 12 | * size multiplier, which gives the size difference between a style and 13 | * textstyle. 14 | */ 15 | function Style(id, size, multiplier, cramped) { 16 | this.id = id; 17 | this.size = size; 18 | this.cramped = cramped; 19 | this.sizeMultiplier = multiplier; 20 | } 21 | 22 | /** 23 | * Get the style of a superscript given a base in the current style. 24 | */ 25 | Style.prototype.sup = function() { 26 | return styles[sup[this.id]]; 27 | }; 28 | 29 | /** 30 | * Get the style of a subscript given a base in the current style. 31 | */ 32 | Style.prototype.sub = function() { 33 | return styles[sub[this.id]]; 34 | }; 35 | 36 | /** 37 | * Get the style of a fraction numerator given the fraction in the current 38 | * style. 39 | */ 40 | Style.prototype.fracNum = function() { 41 | return styles[fracNum[this.id]]; 42 | }; 43 | 44 | /** 45 | * Get the style of a fraction denominator given the fraction in the current 46 | * style. 47 | */ 48 | Style.prototype.fracDen = function() { 49 | return styles[fracDen[this.id]]; 50 | }; 51 | 52 | /** 53 | * Get the cramped version of a style (in particular, cramping a cramped style 54 | * doesn't change the style). 55 | */ 56 | Style.prototype.cramp = function() { 57 | return styles[cramp[this.id]]; 58 | }; 59 | 60 | /** 61 | * HTML class name, like "displaystyle cramped" 62 | */ 63 | Style.prototype.cls = function() { 64 | return sizeNames[this.size] + (this.cramped ? " cramped" : " uncramped"); 65 | }; 66 | 67 | /** 68 | * HTML Reset class name, like "reset-textstyle" 69 | */ 70 | Style.prototype.reset = function() { 71 | return resetNames[this.size]; 72 | }; 73 | 74 | // IDs of the different styles 75 | var D = 0; 76 | var Dc = 1; 77 | var T = 2; 78 | var Tc = 3; 79 | var S = 4; 80 | var Sc = 5; 81 | var SS = 6; 82 | var SSc = 7; 83 | 84 | // String names for the different sizes 85 | var sizeNames = [ 86 | "displaystyle textstyle", 87 | "textstyle", 88 | "scriptstyle", 89 | "scriptscriptstyle", 90 | ]; 91 | 92 | // Reset names for the different sizes 93 | var resetNames = [ 94 | "reset-textstyle", 95 | "reset-textstyle", 96 | "reset-scriptstyle", 97 | "reset-scriptscriptstyle", 98 | ]; 99 | 100 | // Instances of the different styles 101 | var styles = [ 102 | new Style(D, 0, 1.0, false), 103 | new Style(Dc, 0, 1.0, true), 104 | new Style(T, 1, 1.0, false), 105 | new Style(Tc, 1, 1.0, true), 106 | new Style(S, 2, 0.7, false), 107 | new Style(Sc, 2, 0.7, true), 108 | new Style(SS, 3, 0.5, false), 109 | new Style(SSc, 3, 0.5, true), 110 | ]; 111 | 112 | // Lookup tables for switching from one style to another 113 | var sup = [S, Sc, S, Sc, SS, SSc, SS, SSc]; 114 | var sub = [Sc, Sc, Sc, Sc, SSc, SSc, SSc, SSc]; 115 | var fracNum = [T, Tc, S, Sc, SS, SSc, SS, SSc]; 116 | var fracDen = [Tc, Tc, Sc, Sc, SSc, SSc, SSc, SSc]; 117 | var cramp = [Dc, Dc, Tc, Tc, Sc, Sc, SSc, SSc]; 118 | 119 | // We only export some of the styles. Also, we don't export the `Style` class so 120 | // no more styles can be generated. 
121 | module.exports = { 122 | DISPLAY: styles[D], 123 | TEXT: styles[T], 124 | SCRIPT: styles[S], 125 | SCRIPTSCRIPT: styles[SS], 126 | }; 127 | -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/preprocessing/third_party/katex/src/buildTree.js: -------------------------------------------------------------------------------- 1 | var buildHTML = require("./buildHTML"); 2 | var buildMathML = require("./buildMathML"); 3 | var buildCommon = require("./buildCommon"); 4 | var Options = require("./Options"); 5 | var Settings = require("./Settings"); 6 | var Style = require("./Style"); 7 | 8 | var makeSpan = buildCommon.makeSpan; 9 | 10 | var buildTree = function(tree, expression, settings) { 11 | settings = settings || new Settings({}); 12 | 13 | var startStyle = Style.TEXT; 14 | if (settings.displayMode) { 15 | startStyle = Style.DISPLAY; 16 | } 17 | 18 | // Setup the default options 19 | var options = new Options({ 20 | style: startStyle, 21 | size: "size5", 22 | }); 23 | 24 | // `buildHTML` sometimes messes with the parse tree (like turning bins -> 25 | // ords), so we build the MathML version first. 26 | var mathMLNode = buildMathML(tree, expression, options); 27 | var htmlNode = buildHTML(tree, options); 28 | 29 | var katexNode = makeSpan(["katex"], [ 30 | mathMLNode, htmlNode, 31 | ]); 32 | 33 | if (settings.displayMode) { 34 | return makeSpan(["katex-display"], [katexNode]); 35 | } else { 36 | return katexNode; 37 | } 38 | }; 39 | 40 | module.exports = buildTree; 41 | -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/preprocessing/third_party/katex/src/fontMetrics.js: -------------------------------------------------------------------------------- 1 | /* eslint no-unused-vars:0 */ 2 | 3 | var Style = require("./Style"); 4 | 5 | /** 6 | * This file contains metrics regarding fonts and individual symbols. The sigma 7 | * and xi variables, as well as the metricMap map contain data extracted from 8 | * TeX, TeX font metrics, and the TTF files. These data are then exposed via the 9 | * `metrics` variable and the getCharacterMetrics function. 10 | */ 11 | 12 | // These font metrics are extracted from TeX by using 13 | // \font\a=cmmi10 14 | // \showthe\fontdimenX\a 15 | // where X is the corresponding variable number. These correspond to the font 16 | // parameters of the symbol fonts. In TeX, there are actually three sets of 17 | // dimensions, one for each of textstyle, scriptstyle, and scriptscriptstyle, 18 | // but we only use the textstyle ones, and scale certain dimensions accordingly. 19 | // See the TeXbook, page 441. 20 | var sigma1 = 0.025; 21 | var sigma2 = 0; 22 | var sigma3 = 0; 23 | var sigma4 = 0; 24 | var sigma5 = 0.431; 25 | var sigma6 = 1; 26 | var sigma7 = 0; 27 | var sigma8 = 0.677; 28 | var sigma9 = 0.394; 29 | var sigma10 = 0.444; 30 | var sigma11 = 0.686; 31 | var sigma12 = 0.345; 32 | var sigma13 = 0.413; 33 | var sigma14 = 0.363; 34 | var sigma15 = 0.289; 35 | var sigma16 = 0.150; 36 | var sigma17 = 0.247; 37 | var sigma18 = 0.386; 38 | var sigma19 = 0.050; 39 | var sigma20 = 2.390; 40 | var sigma21 = 1.01; 41 | var sigma21Script = 0.81; 42 | var sigma21ScriptScript = 0.71; 43 | var sigma22 = 0.250; 44 | 45 | // These font metrics are extracted from TeX by using 46 | // \font\a=cmex10 47 | // \showthe\fontdimenX\a 48 | // where X is the corresponding variable number. 
These correspond to the font 49 | // parameters of the extension fonts (family 3). See the TeXbook, page 441. 50 | var xi1 = 0; 51 | var xi2 = 0; 52 | var xi3 = 0; 53 | var xi4 = 0; 54 | var xi5 = 0.431; 55 | var xi6 = 1; 56 | var xi7 = 0; 57 | var xi8 = 0.04; 58 | var xi9 = 0.111; 59 | var xi10 = 0.166; 60 | var xi11 = 0.2; 61 | var xi12 = 0.6; 62 | var xi13 = 0.1; 63 | 64 | // This value determines how large a pt is, for metrics which are defined in 65 | // terms of pts. 66 | // This value is also used in katex.less; if you change it make sure the values 67 | // match. 68 | var ptPerEm = 10.0; 69 | 70 | // The space between adjacent `|` columns in an array definition. From 71 | // `\showthe\doublerulesep` in LaTeX. 72 | var doubleRuleSep = 2.0 / ptPerEm; 73 | 74 | /** 75 | * This is just a mapping from common names to real metrics 76 | */ 77 | var metrics = { 78 | xHeight: sigma5, 79 | quad: sigma6, 80 | num1: sigma8, 81 | num2: sigma9, 82 | num3: sigma10, 83 | denom1: sigma11, 84 | denom2: sigma12, 85 | sup1: sigma13, 86 | sup2: sigma14, 87 | sup3: sigma15, 88 | sub1: sigma16, 89 | sub2: sigma17, 90 | supDrop: sigma18, 91 | subDrop: sigma19, 92 | axisHeight: sigma22, 93 | defaultRuleThickness: xi8, 94 | bigOpSpacing1: xi9, 95 | bigOpSpacing2: xi10, 96 | bigOpSpacing3: xi11, 97 | bigOpSpacing4: xi12, 98 | bigOpSpacing5: xi13, 99 | ptPerEm: ptPerEm, 100 | emPerEx: sigma5 / sigma6, 101 | doubleRuleSep: doubleRuleSep, 102 | 103 | // TODO(alpert): Missing parallel structure here. We should probably add 104 | // style-specific metrics for all of these. 105 | delim1: sigma20, 106 | getDelim2: function(style) { 107 | if (style.size === Style.TEXT.size) { 108 | return sigma21; 109 | } else if (style.size === Style.SCRIPT.size) { 110 | return sigma21Script; 111 | } else if (style.size === Style.SCRIPTSCRIPT.size) { 112 | return sigma21ScriptScript; 113 | } 114 | throw new Error("Unexpected style size: " + style.size); 115 | }, 116 | }; 117 | 118 | // This map contains a mapping from font name and character code to character 119 | // metrics, including height, depth, italic correction, and skew (kern from the 120 | // character to the corresponding \skewchar) 121 | // This map is generated via `make metrics`. It should not be changed manually. 122 | var metricMap = require("./fontMetricsData"); 123 | 124 | /** 125 | * This function is a convenience function for looking up information in the 126 | * metricMap table. It takes a character as a string, and a style. 127 | * 128 | * Note: the `width` property may be undefined if fontMetricsData.js wasn't 129 | * built using `Make extended_metrics`. 130 | */ 131 | var getCharacterMetrics = function(character, style) { 132 | var metrics = metricMap[style][character.charCodeAt(0)]; 133 | if (metrics) { 134 | return { 135 | depth: metrics[0], 136 | height: metrics[1], 137 | italic: metrics[2], 138 | skew: metrics[3], 139 | width: metrics[4], 140 | }; 141 | } 142 | }; 143 | 144 | module.exports = { 145 | metrics: metrics, 146 | getCharacterMetrics: getCharacterMetrics, 147 | }; 148 | -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/preprocessing/third_party/katex/src/mathMLTree.js: -------------------------------------------------------------------------------- 1 | /** 2 | * These objects store data about MathML nodes. This is the MathML equivalent 3 | * of the types in domTree.js. 
Since MathML handles its own rendering, and 4 | * since we're mainly using MathML to improve accessibility, we don't manage 5 | * any of the styling state that the plain DOM nodes do. 6 | * 7 | * The `toNode` and `toMarkup` functions work simlarly to how they do in 8 | * domTree.js, creating namespaced DOM nodes and HTML text markup respectively. 9 | */ 10 | 11 | var utils = require("./utils"); 12 | 13 | /** 14 | * This node represents a general purpose MathML node of any type. The 15 | * constructor requires the type of node to create (for example, `"mo"` or 16 | * `"mspace"`, corresponding to `` and `` tags). 17 | */ 18 | function MathNode(type, children) { 19 | this.type = type; 20 | this.attributes = {}; 21 | this.children = children || []; 22 | } 23 | 24 | /** 25 | * Sets an attribute on a MathML node. MathML depends on attributes to convey a 26 | * semantic content, so this is used heavily. 27 | */ 28 | MathNode.prototype.setAttribute = function(name, value) { 29 | this.attributes[name] = value; 30 | }; 31 | 32 | /** 33 | * Converts the math node into a MathML-namespaced DOM element. 34 | */ 35 | MathNode.prototype.toNode = function() { 36 | var node = document.createElementNS( 37 | "http://www.w3.org/1998/Math/MathML", this.type); 38 | 39 | for (var attr in this.attributes) { 40 | if (Object.prototype.hasOwnProperty.call(this.attributes, attr)) { 41 | node.setAttribute(attr, this.attributes[attr]); 42 | } 43 | } 44 | 45 | for (var i = 0; i < this.children.length; i++) { 46 | node.appendChild(this.children[i].toNode()); 47 | } 48 | 49 | return node; 50 | }; 51 | 52 | /** 53 | * Converts the math node into an HTML markup string. 54 | */ 55 | MathNode.prototype.toMarkup = function() { 56 | var markup = "<" + this.type; 57 | 58 | // Add the attributes 59 | for (var attr in this.attributes) { 60 | if (Object.prototype.hasOwnProperty.call(this.attributes, attr)) { 61 | markup += " " + attr + "=\""; 62 | markup += utils.escape(this.attributes[attr]); 63 | markup += "\""; 64 | } 65 | } 66 | 67 | markup += ">"; 68 | 69 | for (var i = 0; i < this.children.length; i++) { 70 | markup += this.children[i].toMarkup(); 71 | } 72 | 73 | markup += ""; 74 | 75 | return markup; 76 | }; 77 | 78 | /** 79 | * This node represents a piece of text. 80 | */ 81 | function TextNode(text) { 82 | this.text = text; 83 | } 84 | 85 | /** 86 | * Converts the text node into a DOM text node. 87 | */ 88 | TextNode.prototype.toNode = function() { 89 | return document.createTextNode(this.text); 90 | }; 91 | 92 | /** 93 | * Converts the text node into HTML markup (which is just the text itself). 94 | */ 95 | TextNode.prototype.toMarkup = function() { 96 | return utils.escape(this.text); 97 | }; 98 | 99 | module.exports = { 100 | MathNode: MathNode, 101 | TextNode: TextNode, 102 | }; 103 | -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/preprocessing/third_party/katex/src/parseData.js: -------------------------------------------------------------------------------- 1 | /** 2 | * The resulting parse tree nodes of the parse tree. 
3 | */ 4 | function ParseNode(type, value, mode) { 5 | this.type = type; 6 | this.value = value; 7 | this.mode = mode; 8 | } 9 | 10 | module.exports = { 11 | ParseNode: ParseNode, 12 | }; 13 | 14 | -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/preprocessing/third_party/katex/src/parseTree.js: -------------------------------------------------------------------------------- 1 | /** 2 | * Provides a single function for parsing an expression using a Parser 3 | * TODO(emily): Remove this 4 | */ 5 | 6 | var Parser = require("./Parser"); 7 | 8 | /** 9 | * Parses an expression using a Parser, then returns the parsed result. 10 | */ 11 | var parseTree = function(toParse, settings) { 12 | var parser = new Parser(toParse, settings); 13 | 14 | return parser.parse(); 15 | }; 16 | 17 | module.exports = parseTree; 18 | -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/preprocessing/third_party/katex/src/utils.js: -------------------------------------------------------------------------------- 1 | /** 2 | * This file contains a list of utility functions which are useful in other 3 | * files. 4 | */ 5 | 6 | /** 7 | * Provide an `indexOf` function which works in IE8, but defers to native if 8 | * possible. 9 | */ 10 | var nativeIndexOf = Array.prototype.indexOf; 11 | var indexOf = function(list, elem) { 12 | if (list == null) { 13 | return -1; 14 | } 15 | if (nativeIndexOf && list.indexOf === nativeIndexOf) { 16 | return list.indexOf(elem); 17 | } 18 | var i = 0; 19 | var l = list.length; 20 | for (; i < l; i++) { 21 | if (list[i] === elem) { 22 | return i; 23 | } 24 | } 25 | return -1; 26 | }; 27 | 28 | /** 29 | * Return whether an element is contained in a list 30 | */ 31 | var contains = function(list, elem) { 32 | return indexOf(list, elem) !== -1; 33 | }; 34 | 35 | /** 36 | * Provide a default value if a setting is undefined 37 | */ 38 | var deflt = function(setting, defaultIfUndefined) { 39 | return setting === undefined ? defaultIfUndefined : setting; 40 | }; 41 | 42 | // hyphenate and escape adapted from Facebook's React under Apache 2 license 43 | 44 | var uppercase = /([A-Z])/g; 45 | var hyphenate = function(str) { 46 | return str.replace(uppercase, "-$1").toLowerCase(); 47 | }; 48 | 49 | var ESCAPE_LOOKUP = { 50 | "&": "&", 51 | ">": ">", 52 | "<": "<", 53 | "\"": """, 54 | "'": "'", 55 | }; 56 | 57 | var ESCAPE_REGEX = /[&><"']/g; 58 | 59 | function escaper(match) { 60 | return ESCAPE_LOOKUP[match]; 61 | } 62 | 63 | /** 64 | * Escapes text to prevent scripting attacks. 65 | * 66 | * @param {*} text Text value to escape. 67 | * @return {string} An escaped string. 68 | */ 69 | function escape(text) { 70 | return ("" + text).replace(ESCAPE_REGEX, escaper); 71 | } 72 | 73 | /** 74 | * A function to set the text content of a DOM element in all supported 75 | * browsers. Note that we don't define this if there is no document. 76 | */ 77 | var setTextContent; 78 | if (typeof document !== "undefined") { 79 | var testNode = document.createElement("span"); 80 | if ("textContent" in testNode) { 81 | setTextContent = function(node, text) { 82 | node.textContent = text; 83 | }; 84 | } else { 85 | setTextContent = function(node, text) { 86 | node.innerText = text; 87 | }; 88 | } 89 | } 90 | 91 | /** 92 | * A function to clear a node. 
93 | */ 94 | function clearNode(node) { 95 | setTextContent(node, ""); 96 | } 97 | 98 | module.exports = { 99 | contains: contains, 100 | deflt: deflt, 101 | escape: escape, 102 | hyphenate: hyphenate, 103 | indexOf: indexOf, 104 | setTextContent: setTextContent, 105 | clearNode: clearNode, 106 | }; 107 | -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/preprocessing/third_party/match-at/README.md: -------------------------------------------------------------------------------- 1 | # match-at [![Build Status](https://travis-ci.org/spicyj/match-at.svg?branch=master)](https://travis-ci.org/spicyj/match-at) 2 | -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/preprocessing/third_party/match-at/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "match-at", 3 | "version": "0.1.0", 4 | "description": "Relocatable regular expressions.", 5 | "repository": { 6 | "type": "git", 7 | "url": "https://github.com/spicyj/match-at" 8 | }, 9 | "main": "lib/matchAt.js", 10 | "files": [ 11 | "lib/" 12 | ], 13 | "devDependencies": { 14 | "babel": "^4.7.16", 15 | "jest-cli": "^0.4.0", 16 | "react-tools": "^0.13.1" 17 | }, 18 | "jest": { 19 | "scriptPreprocessor": "/jestSupport/preprocessor.js", 20 | "unmockedModulePathPatterns": [ 21 | "" 22 | ] 23 | }, 24 | "scripts": { 25 | "prepublish": "babel -d lib/ src/", 26 | "test": "jest" 27 | }, 28 | "gitHead": "4197daff69720734c72ba3321ed68a41c0527fb2", 29 | "bugs": { 30 | "url": "https://github.com/spicyj/match-at/issues" 31 | }, 32 | "homepage": "https://github.com/spicyj/match-at", 33 | "_id": "match-at@0.1.0", 34 | "_shasum": "f561e7709ff9a105b85cc62c6b8ee7c15bf24f31", 35 | "_from": "match-at@", 36 | "_npmVersion": "2.2.0", 37 | "_nodeVersion": "0.10.35", 38 | "_npmUser": { 39 | "name": "spicyj", 40 | "email": "ben@benalpert.com" 41 | }, 42 | "maintainers": [ 43 | { 44 | "name": "spicyj", 45 | "email": "ben@benalpert.com" 46 | } 47 | ], 48 | "dist": { 49 | "shasum": "f561e7709ff9a105b85cc62c6b8ee7c15bf24f31", 50 | "tarball": "https://registry.npmjs.org/match-at/-/match-at-0.1.0.tgz" 51 | }, 52 | "directories": {}, 53 | "_resolved": "https://registry.npmjs.org/match-at/-/match-at-0.1.0.tgz" 54 | } 55 | -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/render.py: -------------------------------------------------------------------------------- 1 | # try: 2 | # from dataset.latex2png import * 3 | # except ModuleNotFoundError: 4 | # from dataset.latex2png import * 5 | from dataset.latex2png import * 6 | import argparse 7 | import sys 8 | import os 9 | import glob 10 | import shutil 11 | from tqdm.auto import tqdm 12 | import cv2 13 | import numpy as np 14 | 15 | 16 | def render_dataset(dataset: np.ndarray, names: np.ndarray, args): 17 | '''Renders a list of tex equations 18 | 19 | Args: 20 | dataset (numpy.ndarray): List of equations 21 | names (numpy.ndarray): List of integers of size `dataset` that give the name of the saved image 22 | args (Union[Namespace, Munch]): additional arguments: mode (equation or inline), out (output directory), divable (common factor ) 23 | batchsize (how many samples to render at once), dpi, font (Math font), preprocess (crop, alpha off) 24 | shuffle (bool) 25 | 26 | Returns: 27 | list: equation indices that could not be rendered. 
28 | ''' 29 | assert len(names) == len(dataset), 'names and dataset must be of equal size' 30 | math_mode = '$$'if args.mode == 'equation' else '$' 31 | os.makedirs(args.out, exist_ok=True) 32 | indices = np.array([int(os.path.basename(img).split('.')[0]) for img in glob.glob(os.path.join(args.out, '*.png'))]) 33 | 34 | valid = [i for i, j in enumerate(names) if j not in indices] 35 | dataset = dataset[valid] 36 | names = names[valid] 37 | order = np.random.permutation(len(dataset)) if args.shuffle else np.arange(len(dataset)) 38 | faulty = [] 39 | for i in tqdm(range(0, len(dataset), args.batchsize)): 40 | batch = dataset[order[i:i+args.batchsize]] 41 | #batch = [x for j, x in enumerate(batch) if order[i+j] not in indices] 42 | if len(batch) == 0: 43 | continue 44 | math = [math_mode+x+math_mode for x in batch if x != ''] 45 | #print('\n', i, len(math), '\n'.join(math)) 46 | if len(args.font) > 1: 47 | font = np.random.choice(args.font) 48 | else: 49 | font = args.font[0] 50 | if len(args.dpi) > 1: 51 | dpi = np.random.choice(np.arange(min(args.dpi), max(args.dpi))) 52 | else: 53 | dpi = args.dpi[0] 54 | if len(math) > 0: 55 | try: 56 | if args.preprocess: 57 | pngs = tex2pil(math, dpi=dpi, font=font) 58 | else: 59 | pngs = Latex(math, dpi=dpi, font=font).write(return_bytes=False) 60 | except Exception as e: 61 | #print(e) 62 | #print(math) 63 | #raise e 64 | faulty.extend(list(names[order[i:i+args.batchsize]])) 65 | continue 66 | 67 | for j, k in enumerate(range(i, i+len(pngs))): 68 | outpath = os.path.join(args.out, '%07d.png' % names[order[k]]) 69 | if args.preprocess: 70 | try: 71 | data = np.asarray(pngs[j]) 72 | # print(data.shape) 73 | gray = 255*(data[..., 0] < 128).astype(np.uint8) # To invert the text to white 74 | coords = cv2.findNonZero(gray) # Find all non-zero points (text) 75 | a, b, w, h = cv2.boundingRect(coords) # Find minimum spanning bounding box 76 | rect = data[b:b+h, a:a+w] 77 | im = Image.fromarray((255-rect[..., -1]).astype(np.uint8)).convert('L') 78 | dims = [] 79 | for x in [w, h]: 80 | div, mod = divmod(x, args.divable) 81 | dims.append(args.divable*(div + (1 if mod > 0 else 0))) 82 | padded = Image.new('L', dims, 255) 83 | padded.paste(im, im.getbbox()) 84 | padded.save(outpath) 85 | except Exception as e: 86 | print(e) 87 | pass 88 | else: 89 | shutil.move(pngs[j], outpath) 90 | 91 | return np.array(faulty) 92 | 93 | 94 | if __name__ == '__main__': 95 | 96 | parser = argparse.ArgumentParser(description='Render dataset') 97 | parser.add_argument('-i', '--data', type=str, required=True, help='file of list of latex code') 98 | parser.add_argument('-o', '--out', type=str, required=True, help='output directory') 99 | parser.add_argument('-b', '--batchsize', type=int, default=100, help='How many equations to render at once') 100 | parser.add_argument('-f', '--font', nargs='+', type=str, default=['Latin Modern Math', 'GFSNeohellenicMath.otf', 'Asana Math', 'XITS Math', 101 | 'Cambria Math', 'Latin Modern Math', 'Latin Modern Math', 'Latin Modern Math'], help='font to use. 
default = Latin Modern Math') 102 | parser.add_argument('-m', '--mode', choices=['inline', 'equation'], default='equation', help='render as inline or equation') 103 | parser.add_argument('--dpi', type=int, default=[110, 170], nargs='+', help='dpi range to render in') 104 | parser.add_argument('-p', '--no-preprocess', dest='preprocess', default=True, action='store_false', help='crop, remove alpha channel, padding') 105 | parser.add_argument('-d', '--divable', type=int, default=32, help='To what factor to pad the images') 106 | parser.add_argument('-s', '--shuffle', action='store_true', help='Whether to shuffle the equations in the first iteration') 107 | args = parser.parse_args(sys.argv[1:]) 108 | 109 | dataset = np.array(open(args.data, 'r').read().split('\n'), dtype=object) 110 | names = np.arange(len(dataset)) 111 | prev_names = None 112 | for i in range(12): 113 | if len(names) == 0: 114 | break 115 | prev_names = names 116 | names = render_dataset(dataset[names], names, args) 117 | same = names == prev_names 118 | if (type(same) == bool and same) or (type(same) == np.ndarray and same.all()) or (args.batchsize == 1): 119 | break 120 | if len(names) < 50*args.batchsize: 121 | args.batchsize = max([1, args.batchsize//2]) 122 | args.shuffle = True 123 | -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/dataset/scraping.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import random 4 | from tqdm import tqdm 5 | import html 6 | import requests 7 | import re 8 | import tempfile 9 | try: 10 | from arxiv import * 11 | from extract_latex import * 12 | except: 13 | from dataset.arxiv import * 14 | from dataset.extract_latex import * 15 | 16 | wikilinks = re.compile(r'href="/wiki/(.*?)"') 17 | htmltags = re.compile(r'<(noscript|script)>.*?<\/\1>', re.S) 18 | wiki_base = 'https://en.wikipedia.org/wiki/' 19 | 20 | 21 | def parse_url(url, encoding=None): 22 | r = requests.get(url) 23 | if r.ok: 24 | if encoding: 25 | r.encoding = encoding 26 | return html.unescape(re.sub(htmltags, '', r.text)) 27 | 28 | 29 | def parse_wiki(url): 30 | text = parse_url(url) 31 | linked = list(set([l for l in re.findall(wikilinks, text) if not ':' in l])) 32 | return find_math(text, wiki=True), linked 33 | 34 | 35 | # recursive search 36 | def recursive_search(parser, seeds, depth=2, skip=[], unit='links', base_url=None): 37 | visited, links = set(skip), set(seeds) 38 | math = [] 39 | try: 40 | for i in range(int(depth)): 41 | link_list = list(links) 42 | random.shuffle(link_list) 43 | t_bar = tqdm(link_list, initial=len(visited), unit=unit) 44 | for link in t_bar: 45 | if not link in visited: 46 | t_bar.set_description('searching %s' % (link)) 47 | if base_url: 48 | m, l = parser(base_url+link) 49 | else: 50 | m, l = parser(link) 51 | # check if we got any math from this wiki page and 52 | # if not terminate the tree 53 | if len(m) > 0: 54 | for li in l: 55 | links.add(li) 56 | t_bar.total = len(links) 57 | math.extend(m) 58 | visited.add(link) 59 | return list(visited), list(set(math)) 60 | except Exception as e: 61 | raise(e) 62 | return list(visited), list(set(math)) 63 | except KeyboardInterrupt: 64 | return list(visited), list(set(math)) 65 | 66 | # recursive wiki search 67 | 68 | 69 | def recursive_wiki(seeds, depth=4, skip=[]): 70 | '''Recursivley search wikipedia for math. 
Every link on the starting page `start` will be visited in the next round and so on, until there is no 71 | math in the child page anymore. This will be repeated `depth` times.''' 72 | start = [s.split('/')[-1] for s in seeds] 73 | return recursive_search(parse_wiki, start, depth, skip, base_url=wiki_base, unit='links') 74 | 75 | 76 | if __name__ == '__main__': 77 | if len(sys.argv) > 2: 78 | url = [sys.argv[1]] 79 | else: 80 | url = ['https://en.wikipedia.org/wiki/Mathematics', 'https://en.wikipedia.org/wiki/Physics'] 81 | visited, math = recursive_wiki(url) 82 | for l, name in zip([visited, math], ['visited_wiki.txt', 'math_wiki.txt']): 83 | f = open(os.path.join(sys.path[0], 'dataset', 'data', name), 'a', encoding='utf-8') 84 | for element in l: 85 | f.write(element) 86 | f.write('\n') 87 | f.close() 88 | -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/eval.py: -------------------------------------------------------------------------------- 1 | from dataset.dataset import Im2LatexDataset 2 | import os 3 | import sys 4 | import argparse 5 | import logging 6 | import yaml 7 | 8 | import numpy as np 9 | import torch 10 | from torchtext.data import metrics 11 | from munch import Munch 12 | from tqdm.auto import tqdm 13 | import wandb 14 | from Levenshtein import distance 15 | 16 | from models import get_model, Model 17 | from utils import * 18 | 19 | 20 | def detokenize(tokens, tokenizer): 21 | toks = [tokenizer.convert_ids_to_tokens(tok) for tok in tokens] 22 | for b in range(len(toks)): 23 | for i in reversed(range(len(toks[b]))): 24 | if toks[b][i] is None: 25 | toks[b][i] = '' 26 | toks[b][i] = toks[b][i].replace('Ġ', ' ').strip() 27 | if toks[b][i] in (['[BOS]', '[EOS]', '[PAD]']): 28 | del toks[b][i] 29 | return toks 30 | 31 | 32 | @torch.no_grad() 33 | def evaluate(model: Model, dataset: Im2LatexDataset, args: Munch, num_batches: int = None, name: str = 'test'): 34 | """evaluates the model. Returns bleu score on the dataset 35 | 36 | Args: 37 | model (torch.nn.Module): the model 38 | dataset (Im2LatexDataset): test dataset 39 | args (Munch): arguments 40 | num_batches (int): How many batches to evaluate on. Defaults to None (all batches). 41 | name (str, optional): name of the test e.g. val or test for wandb. Defaults to 'test'. 42 | 43 | Returns: 44 | bleu_score: BLEU score of validation set. 
45 | """ 46 | assert len(dataset) > 0 47 | device = args.device 48 | log = {} 49 | bleus, edit_dists = [], [] 50 | bleu_score, edit_distance = 0, 1 51 | pbar = tqdm(enumerate(iter(dataset)), total=len(dataset)) 52 | for i, (seq, im) in pbar: 53 | if seq is None or im is None: 54 | continue 55 | tgt_seq, tgt_mask = seq['input_ids'].to(device), seq['attention_mask'].bool().to(device) 56 | encoded = model.encoder(im.to(device)) 57 | #loss = decoder(tgt_seq, mask=tgt_mask, context=encoded) 58 | dec = model.decoder.generate(torch.LongTensor([args.bos_token]*len(encoded))[:, None].to(device), args.max_seq_len, 59 | eos_token=args.pad_token, context=encoded, temperature=args.get('temperature', .2)) 60 | pred = detokenize(dec, dataset.tokenizer) 61 | truth = detokenize(seq['input_ids'], dataset.tokenizer) 62 | bleus.append(metrics.bleu_score(pred, [alternatives(x) for x in truth])) 63 | for predi, truthi in zip(token2str(dec, dataset.tokenizer), token2str(seq['input_ids'], dataset.tokenizer)): 64 | ts = post_process(truthi) 65 | if len(ts) > 0: 66 | edit_dists.append(distance(post_process(predi), ts)/len(ts)) 67 | pbar.set_description('BLEU: %.3f, ED: %.2e' % (np.mean(bleus), np.mean(edit_dists))) 68 | if num_batches is not None and i >= num_batches: 69 | break 70 | if len(bleus) > 0: 71 | bleu_score = np.mean(bleus) 72 | log[name+'/bleu'] = bleu_score 73 | if len(edit_dists) > 0: 74 | edit_distance = np.mean(edit_dists) 75 | log[name+'/edit_distance'] = edit_distance 76 | if args.wandb: 77 | # samples 78 | pred = token2str(dec, dataset.tokenizer) 79 | truth = token2str(seq['input_ids'], dataset.tokenizer) 80 | table = wandb.Table(columns=["Truth", "Prediction"]) 81 | for k in range(min([len(pred), args.test_samples])): 82 | table.add_data(post_process(truth[k]), post_process(pred[k])) 83 | log[name+'/examples'] = table 84 | wandb.log(log) 85 | else: 86 | print('\n%s\n%s' % (truth, pred)) 87 | print('BLEU: %.2f' % bleu_score) 88 | return bleu_score, edit_distance 89 | 90 | 91 | if __name__ == '__main__': 92 | parser = argparse.ArgumentParser(description='Test model') 93 | parser.add_argument('--config', default='settings/config.yaml', help='path to yaml config file', type=argparse.FileType('r')) 94 | parser.add_argument('-c', '--checkpoint', default='checkpoints/weights.pth', type=str, help='path to model checkpoint') 95 | parser.add_argument('-d', '--data', default='dataset/data/val.pkl', type=str, help='Path to Dataset pkl file') 96 | parser.add_argument('--no-cuda', action='store_true', help='Use CPU') 97 | parser.add_argument('-b', '--batchsize', type=int, default=10, help='Batch size') 98 | parser.add_argument('--debug', action='store_true', help='DEBUG') 99 | parser.add_argument('-t', '--temperature', type=float, default=.333, help='sampling emperature') 100 | parser.add_argument('-n', '--num-batches', type=int, default=None, help='how many batches to evaluate on. 
Defaults to None (all)') 101 | 102 | parsed_args = parser.parse_args() 103 | with parsed_args.config as f: 104 | params = yaml.load(f, Loader=yaml.FullLoader) 105 | args = parse_args(Munch(params)) 106 | args.testbatchsize = parsed_args.batchsize 107 | args.wandb = False 108 | args.temperature = parsed_args.temperature 109 | logging.getLogger().setLevel(logging.DEBUG if parsed_args.debug else logging.WARNING) 110 | seed_everything(args.seed if 'seed' in args else 42) 111 | model = get_model(args) 112 | if parsed_args.checkpoint is not None: 113 | model.load_state_dict(torch.load(parsed_args.checkpoint, args.device)) 114 | dataset = Im2LatexDataset().load(parsed_args.data) 115 | valargs = args.copy() 116 | valargs.update(batchsize=args.testbatchsize, keep_smaller_batches=True, test=True) 117 | dataset.update(**valargs) 118 | evaluate(model, dataset, args, num_batches=parsed_args.num_batches) 119 | -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/newlatex/gongshi6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenqiang0601/deep_learning/6381803a6189390d4388804c2ade0bc0c9ca8350/深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/newlatex/gongshi6.png -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/newlatex/img2latex.py: -------------------------------------------------------------------------------- 1 | from pix2text import Pix2Text 2 | 3 | img_fp = 'gongshi6.png' 4 | p2t = Pix2Text(analyzer_config=dict(model_name='mfd')) 5 | outs = p2t(img_fp, resized_shape=600) # 也可以使用 `p2t.recognize(img_fp)` 获得相同的结果 6 | 7 | # 如果只需要识别出的文字和Latex表示,可以使用下面行的代码合并所有结果 8 | only_text = '\n'.join([out['text'] for out in outs]) 9 | 10 | print(only_text) -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm>=4.47.0 2 | munch>=2.5.0 3 | torch>=1.7.1 4 | torchvision>=0.8.1 5 | opencv_python_headless>=4.1.1.26 6 | requests>=2.22.0 7 | einops>=0.3.0 8 | chardet>=3.0.4 9 | x_transformers==0.15.0 10 | imagesize>=1.2.0 11 | transformers==4.2.2 12 | tokenizers==0.9.4 13 | numpy>=1.19.5 14 | Pillow>=8.1.0 15 | PyYAML>=5.4.1 16 | torchtext>=0.6.0 17 | albumentations>=0.5.2 18 | pandas>=1.0.0 19 | timm 20 | python-Levenshtein>=0.12.2 21 | 22 | # GUI 23 | PyQt5 24 | PyQtWebEngine 25 | pynput 26 | screeninfo 27 | -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/resources/__pycache__/resources.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenqiang0601/deep_learning/6381803a6189390d4388804c2ade0bc0c9ca8350/深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/resources/__pycache__/resources.cpython-37.pyc -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/resources/__pycache__/resources.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenqiang0601/deep_learning/6381803a6189390d4388804c2ade0bc0c9ca8350/深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/resources/__pycache__/resources.cpython-38.pyc -------------------------------------------------------------------------------- 
/深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/resources/icon.svg: -------------------------------------------------------------------------------- 1 | 2 | 18 | 20 | 21 | 23 | image/svg+xml 24 | 26 | 27 | 28 | 29 | 30 | 54 | 56 | 60 | 64 | 65 | 73 | 78 | 86 | 87 | 92 | 100 | 101 | 102 | -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/resources/processing-icon-anim.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/resources/resources.qrc: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | resources/icon.svg 5 | resources/processing-icon-anim.svg 6 | 7 | 8 | resources/MathJax.js 9 | 10 | -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/settings/config.yaml: -------------------------------------------------------------------------------- 1 | backbone_layers: 2 | - 2 3 | - 3 4 | - 7 5 | betas: 6 | - 0.9 7 | - 0.999 8 | batchsize: 10 9 | bos_token: 1 10 | channels: 1 11 | data: dataset/data/train.pkl 12 | debug: false 13 | decoder_args: 14 | attn_on_attn: true 15 | cross_attend: true 16 | ff_glu: true 17 | rel_pos_bias: false 18 | use_scalenorm: false 19 | dim: 256 20 | encoder_depth: 4 21 | eos_token: 2 22 | epochs: 10 23 | gamma: 0.9995 24 | heads: 8 25 | id: null 26 | load_chkpt: null 27 | lr: 0.001 28 | lr_step: 30 29 | max_height: 192 30 | max_seq_len: 512 31 | max_width: 672 32 | min_height: 32 33 | min_width: 32 34 | model_path: checkpoints 35 | name: pix2tex 36 | num_layers: 4 37 | num_tokens: 8000 38 | optimizer: Adam 39 | output_path: outputs 40 | pad: false 41 | pad_token: 0 42 | patch_size: 16 43 | sample_freq: 3000 44 | save_freq: 5 45 | scheduler: StepLR 46 | seed: 42 47 | temperature: 0.2 48 | test_samples: 5 49 | testbatchsize: 20 50 | tokenizer: dataset/tokenizer.json 51 | valbatches: 100 52 | valdata: dataset/data/val.pkl 53 | -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/settings/debug.yaml: -------------------------------------------------------------------------------- 1 | # Input/Output/Name 2 | data: "dataset/data/dataset.pkl" 3 | valdata: "dataset/data/val.pkl" 4 | tokenizer: "dataset/tokenizer.json" 5 | output_path: "outputs" 6 | model_path: "checkpoints" 7 | load_chkpt: null 8 | save_freq: 5 # save every nth epoch 9 | name: "pix2tex" 10 | 11 | # Training parameters 12 | epochs: 10 13 | batchsize: 8 14 | 15 | # Testing parameters 16 | testbatchsize: 20 17 | valbatches: 100 18 | temperature: 0.2 19 | 20 | # Optimizer configurations 21 | optimizer: "Adam" 22 | scheduler: "StepLR" 23 | lr: 0.001 24 | gamma: 0.9995 25 | lr_step: 30 26 | betas: 27 | - 0.9 28 | - 0.999 29 | 30 | # Parameters for model architectures 31 | max_width: 128 32 | max_height: 128 33 | min_width: 32 34 | min_height: 32 35 | channels: 1 36 | patch_size: 32 37 | # Encoder / Decoder 38 | dim: 128 39 | backbone_layers: 40 | - 3 41 | - 4 42 | - 9 43 | encoder_depth: 4 44 | num_layers: 4 45 | decoder_args: 46 | cross_attend: true 47 | ff_glu: true 48 | attn_on_attn: false 49 | use_scalenorm: true 50 | rel_pos_bias: false 51 | heads: 8 52 | num_tokens: 8000 53 | max_seq_len: 1024 54 | 55 | # Other 56 | seed: 42 57 | id: null 58 | sample_freq: 50 59 | test_samples: 5 60 | 
debug: True 61 | pad: False 62 | 63 | # Token ids 64 | pad_token: 0 65 | bos_token: 1 66 | eos_token: 2 67 | -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/setup_desktop.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | '''Simple installer for the graphical user interface of pix2tex''' 4 | 5 | import argparse 6 | import os 7 | import sys 8 | 9 | 10 | def _check_file( 11 | main_file 12 | ): 13 | if os.path.exists(main_file): 14 | return 15 | raise FileNotFoundError( 16 | f'Unable to find file {main_file}' 17 | ) 18 | 19 | 20 | def _make_desktop_file( 21 | desktop_path, 22 | desktop_entry 23 | ): 24 | with open(desktop_path, 'w') as desktop_file: 25 | desktop_file.write(desktop_entry) 26 | 27 | 28 | def setup_desktop( 29 | gui_file = 'gui.py', 30 | icon_file = 'resources/icon.svg', 31 | ): 32 | '''Main function for setting up .desktop files (on Linux)''' 33 | parser = argparse.ArgumentParser( 34 | description='Simple installer for the pix2tex GUI' 35 | ) 36 | 37 | parser.add_argument( 38 | 'pix2tex_dir', 39 | default='.', 40 | nargs='?', 41 | help='The directory where pix2tex was downloaded' 42 | ) 43 | 44 | parser.add_argument( 45 | '--uninstall', '-u', 46 | action='store_true', 47 | help='Uninstalls the desktop entry' 48 | ) 49 | 50 | parser.add_argument( 51 | '--venv_dir', '-e', 52 | help='In case a virtual environment is needed for running pix2tex, specifies its directory' 53 | ) 54 | 55 | parser.add_argument( 56 | '--overwrite', '-o', 57 | action='store_true', 58 | help='Unconditionally overwrite .desktop file (if it exists)' 59 | ) 60 | 61 | args = parser.parse_args() 62 | 63 | # where the desktop file will be created 64 | desktop_dir = os.environ.get( 65 | 'XDG_DATA_HOME', 66 | os.path.join(os.environ.get('HOME'), '.local/share/applications') 67 | ) 68 | desktop_path = os.path.abspath(os.path.join(desktop_dir, 'pix2tex.desktop')) 69 | 70 | # check if we want to uninstall it instead 71 | if args.uninstall: 72 | if os.path.exists(desktop_path): 73 | remove = input( 74 | f'Are you sure you want to remove the pix2tex desktop entry {desktop_path}? 
[y/n]' 75 | ) 76 | if remove.lower() == 'y': 77 | try: 78 | os.remove(desktop_path) 79 | print('Successfully uninstalled the desktop entry') 80 | return 0 81 | except: 82 | raise OSError( 83 | f'Something went wrong, unable to remove the desktop entry {desktop_path}' 84 | ) 85 | elif remove.lower() == 'n': 86 | print( 87 | 'Not removing the desktop entry;' \ 88 | 'if you wish to install/uninstall pix2tex, please run this script again' 89 | ) 90 | return 0 91 | else: 92 | print('No file to remove') 93 | return 0 94 | 95 | _check_file(os.path.join(args.pix2tex_dir, gui_file)) 96 | _check_file(os.path.join(args.pix2tex_dir, icon_file)) 97 | 98 | pix2tex_dir = os.path.abspath(args.pix2tex_dir) 99 | gui_path = os.path.join(pix2tex_dir, gui_file) 100 | icon_path = os.path.join(pix2tex_dir, icon_file) 101 | 102 | interpreter_path = \ 103 | os.path.join(args.venv_dir, 'bin/python3') \ 104 | if (args.venv_dir and os.path.exists(os.path.join(args.venv_dir, 'bin/python3'))) \ 105 | else sys.executable 106 | interpreter_path = os.path.abspath(interpreter_path) 107 | 108 | desktop_entry = f"""[Desktop Entry] 109 | Version=1.0 110 | Name=pix2tex 111 | Comment=LaTeX math recognition using machine learning 112 | Exec={interpreter_path} {gui_path} 113 | Icon={icon_path} 114 | Terminal=false 115 | Type=Application 116 | Categories=Utility; 117 | """ 118 | 119 | if os.path.exists(desktop_path): 120 | if not args.overwrite: 121 | overwrite = input( 122 | f'Desktop entry {desktop_path} exists, do you wish to overwrite it? [y/n]' 123 | ) 124 | if overwrite.lower() == 'y': 125 | _make_desktop_file(desktop_path, desktop_entry) 126 | elif overwrite.lower() == 'n': 127 | print('Not overwriting existing desktop entry, exiting...', file=sys.stderr) 128 | return 1 129 | else: 130 | print('Unable to understand input, exiting...', file=sys.stderr) 131 | return 255 132 | else: 133 | _make_desktop_file(desktop_path, desktop_entry) 134 | else: 135 | _make_desktop_file(desktop_path, desktop_entry) 136 | 137 | return 0 138 | 139 | 140 | if __name__ == '__main__': 141 | setup_desktop() 142 | -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/train.py: -------------------------------------------------------------------------------- 1 | from dataset.dataset import Im2LatexDataset 2 | import os 3 | import sys 4 | import argparse 5 | import logging 6 | import yaml 7 | 8 | import numpy as np 9 | import torch 10 | import torch.optim as optim 11 | import torch.nn as nn 12 | from munch import Munch 13 | from tqdm.auto import tqdm 14 | import wandb 15 | 16 | from eval import evaluate 17 | from models import get_model 18 | from utils import * 19 | 20 | 21 | def train(args): 22 | dataloader = Im2LatexDataset().load(args.data) 23 | dataloader.update(**args, test=False) 24 | valdataloader = Im2LatexDataset().load(args.valdata) 25 | valargs = args.copy() 26 | valargs.update(batchsize=args.testbatchsize, keep_smaller_batches=True, test=True) 27 | valdataloader.update(**valargs) 28 | device = args.device 29 | model = get_model(args, training=True) 30 | if args.load_chkpt is not None: 31 | model.load_state_dict(torch.load(args.load_chkpt, map_location=device)) 32 | encoder, decoder = model.encoder, model.decoder 33 | 34 | def save_models(e): 35 | torch.save(model.state_dict(), os.path.join(args.out_path, '%s_e%02d.pth' % (args.name, e+1))) 36 | yaml.dump(dict(args), open(os.path.join(args.out_path, 'config.yaml'), 'w+')) 37 | 38 | opt = 
get_optimizer(args.optimizer)(model.parameters(), args.lr, betas=args.betas) 39 | scheduler = get_scheduler(args.scheduler)(opt, step_size=args.lr_step, gamma=args.gamma) 40 | try: 41 | for e in range(args.epoch, args.epochs): 42 | args.epoch = e 43 | dset = tqdm(iter(dataloader)) 44 | for i, (seq, im) in enumerate(dset): 45 | if seq is not None and im is not None: 46 | opt.zero_grad() 47 | tgt_seq, tgt_mask = seq['input_ids'].to(device), seq['attention_mask'].bool().to(device) 48 | encoded = encoder(im.to(device)) 49 | loss = decoder(tgt_seq, mask=tgt_mask, context=encoded) 50 | loss.backward() 51 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1) 52 | opt.step() 53 | scheduler.step() 54 | dset.set_description('Loss: %.4f' % loss.item()) 55 | if args.wandb: 56 | wandb.log({'train/loss': loss.item()}) 57 | if (i+1) % args.sample_freq == 0: 58 | evaluate(model, valdataloader, args, num_batches=int(args.valbatches*e/args.epochs), name='val') 59 | if (e+1) % args.save_freq == 0: 60 | save_models(e) 61 | if args.wandb: 62 | wandb.log({'train/epoch': e+1}) 63 | except KeyboardInterrupt: 64 | if e >= 2: 65 | save_models(e) 66 | raise KeyboardInterrupt 67 | save_models(e) 68 | 69 | 70 | if __name__ == '__main__': 71 | parser = argparse.ArgumentParser(description='Train model') 72 | parser.add_argument('--config', default='settings/debug.yaml', help='path to yaml config file', type=argparse.FileType('r')) 73 | parser.add_argument('-d', '--data', default='dataset/data/train.pkl', type=str, help='Path to Dataset pkl file') 74 | parser.add_argument('--no_cuda', action='store_true', help='Use CPU') 75 | parser.add_argument('--debug', action='store_true', help='DEBUG') 76 | parser.add_argument('--resume', help='path to checkpoint folder', action='store_true') 77 | 78 | parsed_args = parser.parse_args() 79 | with parsed_args.config as f: 80 | params = yaml.load(f, Loader=yaml.FullLoader) 81 | args = parse_args(Munch(params), **vars(parsed_args)) 82 | logging.getLogger().setLevel(logging.DEBUG if parsed_args.debug else logging.WARNING) 83 | seed_everything(args.seed) 84 | if args.wandb: 85 | if not parsed_args.resume: 86 | args.id = wandb.util.generate_id() 87 | wandb.init(config=dict(args), resume='allow', name=args.name, id=args.id) 88 | train(args) 89 | -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/train_resizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.optim import Adam 5 | from torch.optim.lr_scheduler import OneCycleLR 6 | from timm.models.resnetv2 import ResNetV2 7 | from timm.models.layers import StdConv2dSame 8 | import numpy as np 9 | from PIL import Image 10 | import cv2 11 | import imagesize 12 | import yaml 13 | from tqdm.auto import tqdm 14 | from utils import * 15 | from dataset.dataset import * 16 | from munch import Munch 17 | import argparse 18 | 19 | 20 | def prepare_data(dataloader): 21 | _, ims = dataloader.pairs[dataloader.i-1].T 22 | images = [] 23 | scale = None 24 | c = 0 25 | width, height = imagesize.get(ims[0]) 26 | while True: 27 | c += 1 28 | s = np.array([width, height]) 29 | scale = 5*(np.random.random()+.02) 30 | if all((s*scale) <= dataloader.max_dimensions[0]) and all((s*scale) >= 16): 31 | break 32 | if c > 25: 33 | return None, None 34 | x, y = 0, 0 35 | for path in list(ims): 36 | im = Image.open(path) 37 | modes = [Image.BICUBIC, 38 | Image.BILINEAR] 39 
| if scale < 1: 40 | modes.append(Image.LANCZOS) 41 | m = modes[int(len(modes)*np.random.random())] 42 | im = im.resize((int(width*scale), int(height*scale)), m) 43 | try: 44 | im = pad(im) 45 | except: 46 | return None, None 47 | if im is None: 48 | print(path, 'not found!') 49 | continue 50 | im = np.array(im) 51 | im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) 52 | images.append(dataloader.transform(image=im)['image'][:1]) 53 | if images[-1].shape[-1] > x: 54 | x = images[-1].shape[-1] 55 | if images[-1].shape[-2] > y: 56 | y = images[-1].shape[-2] 57 | if x > dataloader.max_dimensions[0] or y > dataloader.max_dimensions[1]: 58 | return None, None 59 | for i in range(len(images)): 60 | h, w = images[i].shape[1:] 61 | images[i] = F.pad(images[i], (0, x-w, 0, y-h), value=0) 62 | try: 63 | images = torch.cat(images).float().unsqueeze(1) 64 | except RuntimeError as e: 65 | #print(e, 'Images not working: %s' % (' '.join(list(ims)))) 66 | return None, None 67 | dataloader.i += 1 68 | labels = torch.tensor(width//32-1).repeat(len(ims)).long() 69 | return images, labels 70 | 71 | 72 | def val(val, model, num_samples=400, device='cuda'): 73 | model.eval() 74 | c, t = 0, 0 75 | iter(val) 76 | with torch.no_grad(): 77 | for i in range(num_samples): 78 | im, l = prepare_data(val) 79 | if im is None: 80 | continue 81 | p = model(im.to(device)).argmax(-1).detach().cpu().numpy() 82 | c += (p == l[0].item()).sum() 83 | t += len(im) 84 | model.train() 85 | return c/t 86 | 87 | 88 | def main(args): 89 | # data 90 | dataloader = Im2LatexDataset().load(args.data) 91 | dataloader.update(batchsize=args.batchsize, test=False, max_dimensions=args.max_dimensions, keep_smaller_batches=True, device=args.device) 92 | valloader = Im2LatexDataset().load(args.valdata) 93 | valloader.update(batchsize=args.batchsize, test=True, max_dimensions=args.max_dimensions, keep_smaller_batches=True, device=args.device) 94 | 95 | # model 96 | model = ResNetV2(layers=[2, 3, 3], num_classes=int(max(args.max_dimensions)//32), global_pool='avg', in_chans=args.channels, drop_rate=.05, 97 | preact=True, stem_type='same', conv_layer=StdConv2dSame).to(args.device) 98 | if args.resume: 99 | model.load_state_dict(torch.load(args.resume)) 100 | opt = Adam(model.parameters(), lr=args.lr) 101 | crit = nn.CrossEntropyLoss() 102 | sched = OneCycleLR(opt, .005, total_steps=args.num_epochs*len(dataloader)) 103 | global bestacc 104 | bestacc = val(valloader, model, args.valbatches, args.device) 105 | 106 | def train_epoch(sched=None): 107 | iter(dataloader) 108 | dset = tqdm(range(len(dataloader))) 109 | for i in dset: 110 | im, label = prepare_data(dataloader) 111 | if im is not None: 112 | if im.shape[-1] > dataloader.max_dimensions[0] or im.shape[-2] > dataloader.max_dimensions[1]: 113 | continue 114 | opt.zero_grad() 115 | label = label.to(args.device) 116 | 117 | pred = model(im.to(args.device)) 118 | loss = crit(pred, label) 119 | if i % 2 == 0: 120 | dset.set_description('Loss: %.4f' % loss.item()) 121 | loss.backward() 122 | opt.step() 123 | if sched is not None: 124 | sched.step() 125 | if (i+1) % args.sample_freq == 0 or i+1 == len(dset): 126 | acc = val(valloader, model, args.valbatches, args.device) 127 | print('Accuracy %.2f' % (100*acc), '%') 128 | global bestacc 129 | if acc > bestacc: 130 | torch.save(model.state_dict(), args.out) 131 | bestacc = acc 132 | for _ in range(args.num_epochs): 133 | train_epoch(sched) 134 | 135 | 136 | if __name__ == '__main__': 137 | parser = argparse.ArgumentParser(description='Train size classification 
model') 138 | parser.add_argument('--config', default='settings/debug.yaml', help='path to yaml config file', type=argparse.FileType('r')) 139 | parser.add_argument('--no_cuda', action='store_true', help='Use CPU') 140 | parser.add_argument('--lr', type=float, default=5e-4, help='learning rate') 141 | parser.add_argument('--resume', help='path to checkpoint folder', type=str, default='') 142 | parser.add_argument('--out', type=str, default='checkpoints/image_resizer.pth', help='output destination for trained model') 143 | parser.add_argument('--num_epochs', type=int, default=10, help='number of epochs to train') 144 | parser.add_argument('--batchsize', type=int, default=10) 145 | parsed_args = parser.parse_args() 146 | with parsed_args.config as f: 147 | params = yaml.load(f, Loader=yaml.FullLoader) 148 | args = parse_args(Munch(params), **vars(parsed_args)) 149 | args.update(**vars(parsed_args)) 150 | main(args) 151 | -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from utils.utils import * -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/utils/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenqiang0601/deep_learning/6381803a6189390d4388804c2ade0bc0c9ca8350/深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/utils/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenqiang0601/deep_learning/6381803a6189390d4388804c2ade0bc0c9ca8350/深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/utils/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenqiang0601/deep_learning/6381803a6189390d4388804c2ade0bc0c9ca8350/深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/utils/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/utils/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenqiang0601/deep_learning/6381803a6189390d4388804c2ade0bc0c9ca8350/深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/utils/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)/utils/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import os 3 | import cv2 4 | import re 5 | from PIL import Image 6 | import numpy as np 7 | import torch 8 | from munch import Munch 9 | from inspect import isfunction 10 | 11 | operators = '|'.join(['arccos', 'arcsin', 'arctan', 'arg', 'cos', 'cosh', 'cot', 'coth', 'csc', 'deg', 'det', 'dim', 'exp', 'gcd', 'hom', 'inf', 12 | 'injlim', 'ker', 'lg', 'lim', 'liminf', 'limsup', 'ln', 'log', 'max', 'min', 'Pr', 'projlim', 
'sec', 'sin', 'sinh', 'sup', 'tan', 'tanh']) 13 | ops = re.compile(r'\\operatorname{(%s)}' % operators) 14 | 15 | 16 | class EmptyStepper: 17 | def __init__(self, *args, **kwargs): 18 | pass 19 | 20 | def step(self, *args, **kwargs): 21 | pass 22 | 23 | # helper functions from lucidrains 24 | 25 | 26 | def exists(val): 27 | return val is not None 28 | 29 | 30 | def default(val, d): 31 | if exists(val): 32 | return val 33 | return d() if isfunction(d) else d 34 | 35 | 36 | def seed_everything(seed: int): 37 | """Seed all RNGs 38 | 39 | Args: 40 | seed (int): seed 41 | """ 42 | random.seed(seed) 43 | os.environ['PYTHONHASHSEED'] = str(seed) 44 | np.random.seed(seed) 45 | torch.manual_seed(seed) 46 | torch.cuda.manual_seed(seed) 47 | torch.backends.cudnn.deterministic = True 48 | torch.backends.cudnn.benchmark = True 49 | 50 | 51 | def parse_args(args, **kwargs): 52 | args = Munch({'epoch': 0}, **args) 53 | kwargs = Munch({'no_cuda': False, 'debug': False}, **kwargs) 54 | args.wandb = not kwargs.debug and not args.debug 55 | args.device = 'cuda' if torch.cuda.is_available() and not kwargs.no_cuda else 'cpu' 56 | args.max_dimensions = [args.max_width, args.max_height] 57 | args.min_dimensions = [args.get('min_width', 32), args.get('min_height', 32)] 58 | if 'decoder_args' not in args or args.decoder_args is None: 59 | args.decoder_args = {} 60 | if 'model_path' in args: 61 | args.out_path = os.path.join(args.model_path, args.name) 62 | os.makedirs(args.out_path, exist_ok=True) 63 | return args 64 | 65 | 66 | def token2str(tokens, tokenizer): 67 | if len(tokens.shape) == 1: 68 | tokens = tokens[None, :] 69 | dec = [tokenizer.decode(tok) for tok in tokens] 70 | return [''.join(detok.split(' ')).replace('Ġ', ' ').replace('[EOS]', '').replace('[BOS]', '').replace('[PAD]', '').strip() for detok in dec] 71 | 72 | 73 | def pad(img: Image, divable=32): 74 | """Pad an Image to the next full divisible value of `divable`. Also normalizes the image and invert if needed. 75 | 76 | Args: 77 | img (PIL.Image): input image 78 | divable (int, optional): . Defaults to 32. 79 | 80 | Returns: 81 | PIL.Image 82 | """ 83 | data = np.array(img.convert('LA')) 84 | data = (data-data.min())/(data.max()-data.min())*255 85 | if data[..., 0].mean() > 128: 86 | gray = 255*(data[..., 0] < 128).astype(np.uint8) # To invert the text to white 87 | else: 88 | gray = 255*(data[..., 0] > 128).astype(np.uint8) 89 | data[..., 0] = 255-data[..., 0] 90 | 91 | coords = cv2.findNonZero(gray) # Find all non-zero points (text) 92 | a, b, w, h = cv2.boundingRect(coords) # Find minimum spanning bounding box 93 | rect = data[b:b+h, a:a+w] 94 | if rect[..., -1].var() == 0: 95 | im = Image.fromarray((rect[..., 0]).astype(np.uint8)).convert('L') 96 | else: 97 | im = Image.fromarray((255-rect[..., -1]).astype(np.uint8)).convert('L') 98 | dims = [] 99 | for x in [w, h]: 100 | div, mod = divmod(x, divable) 101 | dims.append(divable*(div + (1 if mod > 0 else 0))) 102 | padded = Image.new('L', dims, 255) 103 | padded.paste(im, im.getbbox()) 104 | return padded 105 | 106 | 107 | def post_process(s: str): 108 | """Remove unnecessary whitespace from LaTeX code. 109 | 110 | Args: 111 | s (str): Input string 112 | 113 | Returns: 114 | str: Processed image 115 | """ 116 | text_reg = r'(\\(operatorname|mathrm|text|mathbf)\s?\*? 
{.*?})' 117 | letter = '[a-zA-Z]' 118 | noletter = '[\W_^\d]' 119 | names = [x[0].replace(' ', '') for x in re.findall(text_reg, s)] 120 | s = re.sub(text_reg, lambda match: str(names.pop(0)), s) 121 | news = s 122 | while True: 123 | s = news 124 | news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, noletter), r'\1\2', s) 125 | news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, letter), r'\1\2', news) 126 | news = re.sub(r'(%s)\s+?(%s)' % (letter, noletter), r'\1\2', news) 127 | if news == s: 128 | break 129 | return s 130 | 131 | 132 | def alternatives(s): 133 | # TODO takes list of list of tokens 134 | # try to generate equivalent code eg \ne \neq or \to \rightarrow 135 | # alts = [s] 136 | # names = ['\\'+x for x in re.findall(ops, s)] 137 | # alts.append(re.sub(ops, lambda match: str(names.pop(0)), s)) 138 | 139 | # return alts 140 | return [s] 141 | 142 | 143 | def get_optimizer(optimizer): 144 | return getattr(torch.optim, optimizer) 145 | 146 | 147 | def get_scheduler(scheduler): 148 | if scheduler is None: 149 | return EmptyStepper 150 | return getattr(torch.optim.lr_scheduler, scheduler) 151 | 152 | 153 | def num_model_params(model): 154 | return sum([p.numel() for p in model.parameters()]) 155 | -------------------------------------------------------------------------------- /深度学习实战11(进阶版)-BERT模型的微调应用-文本分类案例/bert_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from datasets import load_dataset 3 | import torch.nn.functional as F 4 | from transformers import BertTokenizer 5 | 6 | # 加载字典和分词工具 7 | token = BertTokenizer.from_pretrained('bert-base-chinese') 8 | 9 | # 定义数据集 10 | class Dataset(torch.utils.data.Dataset): 11 | def __init__(self, split): 12 | self.dataset = load_dataset(path='data', split=split) 13 | 14 | def __len__(self): 15 | return len(self.dataset) 16 | 17 | def __getitem__(self, i): 18 | text = self.dataset[i]['text'] 19 | label = self.dataset[i]['label'] 20 | 21 | return text, label 22 | 23 | 24 | dataset = Dataset('train') 25 | print(len(dataset), dataset[0]) 26 | 27 | 28 | def collate_fn(data): 29 | sents = [i[0] for i in data] 30 | labels = [i[1] for i in data] 31 | 32 | # 编码 33 | data = token.batch_encode_plus(batch_text_or_text_pairs=sents, 34 | truncation=True, 35 | padding='max_length', 36 | max_length=500, 37 | return_tensors='pt', 38 | return_length=True) 39 | 40 | # input_ids:编码之后的数字 41 | # attention_mask:是补零的位置是0,其他位置是1 42 | input_ids = data['input_ids'] 43 | attention_mask = data['attention_mask'] 44 | token_type_ids = data['token_type_ids'] 45 | labels = torch.LongTensor(labels) 46 | 47 | # print(data['length'], data['length'].max()) 48 | return input_ids, attention_mask, token_type_ids, labels 49 | 50 | 51 | # 数据加载器 52 | loader = torch.utils.data.DataLoader(dataset=dataset, 53 | batch_size=10, 54 | collate_fn=collate_fn, 55 | shuffle=True, 56 | drop_last=True) 57 | 58 | for i, (input_ids, attention_mask, token_type_ids, 59 | labels) in enumerate(loader): 60 | break 61 | 62 | print(len(loader)) 63 | print(input_ids.shape, attention_mask.shape, token_type_ids.shape, labels) 64 | 65 | from transformers import BertModel 66 | 67 | # 加载预训练模型 68 | pretrained = BertModel.from_pretrained('bert-base-chinese') 69 | 70 | # 不训练,不需要计算梯度 71 | for param in pretrained.parameters(): 72 | param.requires_grad_(False) 73 | 74 | # 模型试算 75 | out = pretrained(input_ids=input_ids, 76 | attention_mask=attention_mask, 77 | token_type_ids=token_type_ids) 78 | 79 | print(out.last_hidden_state.shape) 80 | 81 | 82 | # 
定义下游任务模型 83 | class Model(torch.nn.Module): 84 | def __init__(self): 85 | super().__init__() 86 | self.fc = torch.nn.Linear(768, 2) 87 | # 可加入CNN卷积层,可以自行操作 88 | # self.conv1D = torch.nn.Conv1d(in_channels=500, out_channels=500, kernel_size=1) 89 | # self.MaxPool1D = torch.nn.MaxPool1d(4, stride=2) 90 | # self.Dropout = torch.nn.Dropout(p=0.5, inplace=False) 91 | 92 | def forward(self, input_ids, attention_mask, token_type_ids): 93 | with torch.no_grad(): 94 | out = pretrained(input_ids=input_ids, 95 | attention_mask=attention_mask, 96 | token_type_ids=token_type_ids) 97 | out = self.fc(out.last_hidden_state[:, 0]) 98 | out = out.softmax(dim=1) 99 | print(out.shape) 100 | return out 101 | 102 | 103 | model = Model() 104 | print(model) 105 | # model.summary() 106 | model(input_ids=input_ids, 107 | attention_mask=attention_mask, 108 | token_type_ids=token_type_ids).shape 109 | 110 | from transformers import AdamW 111 | 112 | # 训练 113 | optimizer = AdamW(model.parameters(), lr=5e-4) 114 | criterion = torch.nn.CrossEntropyLoss() 115 | 116 | model.train() 117 | epochs = 30 118 | 119 | for i, (input_ids, attention_mask, token_type_ids, 120 | labels) in enumerate(loader): 121 | out = model(input_ids=input_ids, 122 | attention_mask=attention_mask, 123 | token_type_ids=token_type_ids) 124 | 125 | loss = criterion(out, labels) 126 | loss.backward() 127 | optimizer.step() 128 | optimizer.zero_grad() 129 | 130 | if i % 1 == 0: 131 | out = out.argmax(dim=1) 132 | accuracy = (out == labels).sum().item() / len(labels) 133 | 134 | print('epochs:', i, 'loss:', loss.item(), 'accuracy:', accuracy) 135 | 136 | if i == epochs: 137 | torch.save(model, 'text_classfiy.model') 138 | # model_load = torch.load('model/命名实体识别_中文.model') 139 | break 140 | 141 | 142 | # 测试函数 143 | def test(): 144 | model.eval() 145 | correct = 0 146 | total = 0 147 | 148 | loader_test = torch.utils.data.DataLoader(dataset=Dataset('validation'), 149 | batch_size=10, 150 | collate_fn=collate_fn, 151 | shuffle=True, 152 | drop_last=True) 153 | 154 | for i, (input_ids, attention_mask, token_type_ids, 155 | labels) in enumerate(loader_test): 156 | 157 | if i == 5: 158 | break 159 | 160 | with torch.no_grad(): 161 | out = model(input_ids=input_ids, 162 | attention_mask=attention_mask, 163 | token_type_ids=token_type_ids) 164 | 165 | out = out.argmax(dim=1) 166 | correct += (out == labels).sum().item() 167 | total += len(labels) 168 | 169 | print(correct / total) -------------------------------------------------------------------------------- /深度学习实战12(进阶版)-利用Dewarp实现文本扭曲矫正/123.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenqiang0601/deep_learning/6381803a6189390d4388804c2ade0bc0c9ca8350/深度学习实战12(进阶版)-利用Dewarp实现文本扭曲矫正/123.png -------------------------------------------------------------------------------- /深度学习实战12(进阶版)-利用Dewarp实现文本扭曲矫正/dewarp.py: -------------------------------------------------------------------------------- 1 | from page_dewarp import __main__ 2 | 3 | 4 | if __name__ == "__main__": 5 | imgfile = '123.png' 6 | __main__.main(imgfile) # 图片文字扭曲处理 7 | 8 | # 注意保存的地址设置 对应的地址 -------------------------------------------------------------------------------- /深度学习实战13(进阶版)-文本纠错功能,经常写错别字的小伙伴的福星/correct1.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append("..") 4 | 5 | import pycorrector 6 | 7 | if __name__ == '__main__': 8 | 9 | error_sentences = [ 10 | '他是有明的侦探', 11 | '这场比赛我甘败下风', 12 | 
'这家伙还蛮格尽职守的', 13 | '报应接中迩来', 14 | '今天我很高形', 15 | '少先队员因该为老人让坐', 16 | '老是在较书。' 17 | ] 18 | for line in error_sentences: 19 | correct_sent, err = pycorrector.correct(line) 20 | print("{} => {} {}".format(line, correct_sent, err)) -------------------------------------------------------------------------------- /深度学习实战13(进阶版)-文本纠错功能,经常写错别字的小伙伴的福星/correct2.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | import operator 3 | import torch 4 | from transformers import BertTokenizer, BertForMaskedLM 5 | 6 | tokenizer = BertTokenizer.from_pretrained("shibing624/macbert4csc-base-chinese") 7 | model = BertForMaskedLM.from_pretrained("shibing624/macbert4csc-base-chinese") 8 | 9 | 10 | def ai_text(text): 11 | with torch.no_grad(): 12 | outputs = model(**tokenizer([text], padding=True, return_tensors='pt')) 13 | 14 | def to_highlight(corrected_sent, errs): 15 | output = [{"entity": "纠错", "word": err[1], "start": err[2], "end": err[3]} for i, err in 16 | enumerate(errs)] 17 | return {"text": corrected_sent, "entities": output} 18 | 19 | def get_errors(corrected_text, origin_text): 20 | sub_details = [] 21 | for i, ori_char in enumerate(origin_text): 22 | if ori_char in [' ', '“', '”', '‘', '’', '琊', '\n', '…', '—', '擤']: 23 | # add unk word 24 | corrected_text = corrected_text[:i] + ori_char + corrected_text[i:] 25 | continue 26 | if i >= len(corrected_text): 27 | continue 28 | if ori_char != corrected_text[i]: 29 | if ori_char.lower() == corrected_text[i]: 30 | # pass english upper char 31 | corrected_text = corrected_text[:i] + ori_char + corrected_text[i + 1:] 32 | continue 33 | sub_details.append((ori_char, corrected_text[i], i, i + 1)) 34 | sub_details = sorted(sub_details, key=operator.itemgetter(2)) 35 | return corrected_text, sub_details 36 | 37 | _text = tokenizer.decode(torch.argmax(outputs.logits[0], dim=-1), skip_special_tokens=True).replace(' ', '') 38 | corrected_text = _text[:len(text)] 39 | corrected_text, details = get_errors(corrected_text, text) 40 | print(text, ' => ', corrected_text, details) 41 | return to_highlight(corrected_text, details), details 42 | 43 | 44 | if __name__ == '__main__': 45 | print(ai_text('少先队员因该为老人让坐')) 46 | 47 | examples = [ 48 | ['真麻烦你了。希望你们好好的跳无'], 49 | ['少先队员因该为老人让坐'], 50 | ['他是有明的侦探'], 51 | ['今天心情很不搓'], 52 | ['他法语说的很好,的语也不错'], 53 | ['这场比赛我甘败下风'], 54 | ] 55 | 56 | gr.Interface( 57 | ai_text, 58 | inputs="textbox", 59 | outputs=[ 60 | gr.outputs.HighlightedText( 61 | label="Output", 62 | show_legend=True, 63 | ), 64 | gr.outputs.JSON( 65 | label="JSON Output" 66 | ) 67 | ], 68 | title="中文纠错模型", 69 | description="输入一段话,判断这段话中是否有错别字或语法错误", 70 | article="Link to Github REPO", 71 | examples=examples 72 | ).launch() -------------------------------------------------------------------------------- /深度学习实战13(进阶版)-文本纠错功能,经常写错别字的小伙伴的福星/correct3.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import sys 4 | 5 | sys.path.append("..") 6 | from pycorrector.macbert.macbert_corrector import MacBertCorrector 7 | 8 | 9 | def use_origin_transformers(): 10 | # 原生transformers库调用 11 | import operator 12 | import torch 13 | from transformers import BertTokenizer, BertForMaskedLM 14 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 15 | 16 | tokenizer = BertTokenizer.from_pretrained("shibing624/macbert4csc-base-chinese") 17 | model = BertForMaskedLM.from_pretrained("shibing624/macbert4csc-base-chinese") 18 | model.to(device) 
19 | 20 | texts = ["今天新情很好", "你找到你最喜欢的工作,我也很高心。", "我不唉“看 琅擤琊榜”"] 21 | 22 | text_tokens = tokenizer(texts, padding=True, return_tensors='pt').to(device) 23 | with torch.no_grad(): 24 | outputs = model(**text_tokens) 25 | 26 | def get_errors(corrected_text, origin_text): 27 | sub_details = [] 28 | for i, ori_char in enumerate(origin_text): 29 | if ori_char in [' ', '“', '”', '‘', '’', '\n', '…', '—', '擤']: 30 | # add unk word 31 | corrected_text = corrected_text[:i] + ori_char + corrected_text[i:] 32 | continue 33 | if i >= len(corrected_text): 34 | break 35 | if ori_char != corrected_text[i]: 36 | if ori_char.lower() == corrected_text[i]: 37 | # pass english upper char 38 | corrected_text = corrected_text[:i] + ori_char + corrected_text[i + 1:] 39 | continue 40 | sub_details.append((ori_char, corrected_text[i], i, i + 1)) 41 | sub_details = sorted(sub_details, key=operator.itemgetter(2)) 42 | return corrected_text, sub_details 43 | 44 | result = [] 45 | for ids, (i, text) in zip(outputs.logits, enumerate(texts)): 46 | _text = tokenizer.decode((torch.argmax(ids, dim=-1) * text_tokens.attention_mask[i]), 47 | skip_special_tokens=True).replace(' ', '') 48 | corrected_text, details = get_errors(_text, text) 49 | print(text, ' => ', corrected_text, details) 50 | result.append((corrected_text, details)) 51 | print(result) 52 | return result 53 | 54 | 55 | if __name__ == '__main__': 56 | # 原生transformers库调用 57 | use_origin_transformers() 58 | 59 | # pycorrector封装调用 60 | error_sentences = [ 61 | '他是有明的侦探', 62 | '这场比赛我甘败下风', 63 | '这家伙还蛮格尽职守的', 64 | '报应接中迩来', 65 | '今天我很高形', 66 | '少先队员因该为老人让坐', 67 | '老是在较书。' 68 | ] 69 | 70 | m = MacBertCorrector() 71 | for line in error_sentences: 72 | correct_sent, err = m.macbert_correct(line) 73 | print("query:{} => {} err:{}".format(line, correct_sent, err)) -------------------------------------------------------------------------------- /深度学习实战13(进阶版)-文本纠错功能,经常写错别字的小伙伴的福星/correct4.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append("..") 4 | from pycorrector.t5.t5_corrector import T5Corrector 5 | 6 | if __name__ == '__main__': 7 | # pycorrector封装调用 8 | error_sentences = [ 9 | '他是有明的侦探', 10 | '这场比赛我甘败下风', 11 | '这家伙还蛮格尽职守的', 12 | '报应接中迩来', 13 | '今天我很高形', 14 | '少先队员因该为老人让坐', 15 | '老是在较书。' 16 | ] 17 | 18 | m = T5Corrector() 19 | res = m.batch_t5_correct(error_sentences) 20 | for line, r in zip(error_sentences, res): 21 | correct_sent, err = r[0], r[1] 22 | print("query:{} => {} err:{}".format(line, correct_sent, err)) -------------------------------------------------------------------------------- /深度学习实战14(进阶版)-手写文字OCR识别,手写笔记也可以识别了/123456.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenqiang0601/deep_learning/6381803a6189390d4388804c2ade0bc0c9ca8350/深度学习实战14(进阶版)-手写文字OCR识别,手写笔记也可以识别了/123456.png -------------------------------------------------------------------------------- /深度学习实战14(进阶版)-手写文字OCR识别,手写笔记也可以识别了/ocr_result/ndarray_1671255339.407687.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenqiang0601/deep_learning/6381803a6189390d4388804c2ade0bc0c9ca8350/深度学习实战14(进阶版)-手写文字OCR识别,手写笔记也可以识别了/ocr_result/ndarray_1671255339.407687.jpg -------------------------------------------------------------------------------- /深度学习实战14(进阶版)-手写文字OCR识别,手写笔记也可以识别了/writeOCR.py: -------------------------------------------------------------------------------- 1 | 
#模型导入 2 | import paddlehub as hub 3 | 4 | ocr = hub.Module(name="chinese_ocr_db_crnn_server") 5 | 6 | import cv2 7 | image_path = '123456.png' 8 | # 读取测试文件夹test.txt中的照片路径 9 | np_images =[cv2.imread(image_path)] 10 | 11 | results = ocr.recognize_text( 12 | images=np_images, # 图片数据,ndarray.shape 为 [H, W, C],BGR格式; 13 | use_gpu=False, # 是否使用 GPU;若使用GPU,请先设置CUDA_VISIBLE_DEVICES环境变量 14 | output_dir='ocr_result', # 图片的保存路径,默认设为 ocr_result; 15 | visualization=True, # 是否将识别结果保存为图片文件; 16 | box_thresh=0.5, # 检测文本框置信度的阈值; 17 | text_thresh=0.5) # 识别中文文本置信度的阈值; 18 | 19 | for result in results: 20 | data = result['data'] 21 | save_path = result['save_path'] 22 | for infomation in data: 23 | print('text: ', infomation['text'], '\nconfidence: ', infomation['confidence'], '\ntext_box_position: ', infomation['text_box_position']) -------------------------------------------------------------------------------- /深度学习实战14(进阶版)-手写文字OCR识别,手写笔记也可以识别了/writeOCR_new.py: -------------------------------------------------------------------------------- 1 | from modelscope.pipelines import pipeline 2 | from modelscope.utils.constant import Tasks 3 | from modelscope.outputs import OutputKeys 4 | 5 | # ModelScope Library >= 1.2.0 6 | ocr_recognize = pipeline(Tasks.ocr_recognition, model='damo/ofa_ocr-recognition_handwriting_base_zh', model_revision='v1.0.1') 7 | result = ocr_recognize('https://xingchen-data.oss-cn-zhangjiakou.aliyuncs.com/maas/ocr/ocr_handwriting_demo.png') 8 | print(result[OutputKeys.TEXT]) -------------------------------------------------------------------------------- /深度学习实战15(进阶版)-让机器进行阅读理解+你可以变成出题者提问/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenqiang0601/deep_learning/6381803a6189390d4388804c2ade0bc0c9ca8350/深度学习实战15(进阶版)-让机器进行阅读理解+你可以变成出题者提问/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /深度学习实战15(进阶版)-让机器进行阅读理解+你可以变成出题者提问/reading.py: -------------------------------------------------------------------------------- 1 | import paddle 2 | from paddlenlp.data import Stack, Dict, Pad 3 | import paddlenlp 4 | from paddlenlp.datasets import load_dataset 5 | from utils import prepare_train_features, prepare_validation_features 6 | from functools import partial 7 | 8 | train_ds, dev_ds, test_ds = load_dataset('dureader_robust', splits=('train', 'dev', 'test')) 9 | 10 | for idx in range(2): 11 | print(train_ds[idx]['question']) 12 | print(train_ds[idx]['context']) 13 | print(train_ds[idx]['answers']) 14 | print(train_ds[idx]['answer_starts']) 15 | print() 16 | 17 | # 设置模型名称 18 | MODEL_NAME = 'ernie-1.0' 19 | tokenizer = paddlenlp.transformers.ErnieTokenizer.from_pretrained(MODEL_NAME) 20 | 21 | max_seq_length = 512 22 | doc_stride = 128 23 | 24 | train_trans_func = partial(prepare_train_features, 25 | max_seq_length=max_seq_length, 26 | doc_stride=doc_stride, 27 | tokenizer=tokenizer) 28 | 29 | train_ds.map(train_trans_func, batched=True, num_workers=4) 30 | 31 | dev_trans_func = partial(prepare_validation_features, 32 | max_seq_length=max_seq_length, 33 | doc_stride=doc_stride, 34 | tokenizer=tokenizer) 35 | 36 | dev_ds.map(dev_trans_func, batched=True, num_workers=4) 37 | test_ds.map(dev_trans_func, batched=True, num_workers=4) 38 | 39 | for idx in range(2): 40 | print(train_ds[idx]['input_ids']) 41 | print(train_ds[idx]['token_type_ids']) 42 | print(train_ds[idx]['overflow_to_sample']) 43 | print(train_ds[idx]['offset_mapping']) 44 | 
print(train_ds[idx]['start_positions']) 45 | print(train_ds[idx]['end_positions']) 46 | 47 | batch_size = 12 48 | 49 | # 定义BatchSampler 50 | train_batch_sampler = paddle.io.DistributedBatchSampler( 51 | train_ds, batch_size=batch_size, shuffle=True) 52 | 53 | dev_batch_sampler = paddle.io.BatchSampler( 54 | dev_ds, batch_size=batch_size, shuffle=False) 55 | 56 | test_batch_sampler = paddle.io.BatchSampler( 57 | test_ds, batch_size=batch_size, shuffle=False) 58 | 59 | # 定义batchify_fn 60 | train_batchify_fn = lambda samples, fn=Dict({ 61 | "input_ids": Pad(axis=0, pad_val=tokenizer.pad_token_id), 62 | "token_type_ids": Pad(axis=0, pad_val=tokenizer.pad_token_type_id), 63 | "start_positions": Stack(dtype="int64"), 64 | "end_positions": Stack(dtype="int64") 65 | }): fn(samples) 66 | 67 | dev_batchify_fn = lambda samples, fn=Dict({ 68 | "input_ids": Pad(axis=0, pad_val=tokenizer.pad_token_id), 69 | "token_type_ids": Pad(axis=0, pad_val=tokenizer.pad_token_type_id) 70 | }): fn(samples) 71 | 72 | # 构造DataLoader 73 | train_data_loader = paddle.io.DataLoader( 74 | dataset=train_ds, 75 | batch_sampler=train_batch_sampler, 76 | collate_fn=train_batchify_fn, 77 | return_list=True) 78 | 79 | dev_data_loader = paddle.io.DataLoader( 80 | dataset=dev_ds, 81 | batch_sampler=dev_batch_sampler, 82 | collate_fn=dev_batchify_fn, 83 | return_list=True) 84 | 85 | test_data_loader = paddle.io.DataLoader( 86 | dataset=test_ds, 87 | batch_sampler=test_batch_sampler, 88 | collate_fn=dev_batchify_fn, 89 | return_list=True) 90 | 91 | for step, batch in enumerate(train_data_loader, start=1): 92 | input_ids, segment_ids, start_positions, end_positions = batch 93 | print(input_ids) 94 | break 95 | 96 | from paddlenlp.transformers import ErnieForQuestionAnswering 97 | 98 | # 模型加载 99 | model = ErnieForQuestionAnswering.from_pretrained(MODEL_NAME) 100 | 101 | 102 | # 损失函数设定 103 | class CrossEntropyLossForRobust(paddle.nn.Layer): 104 | def __init__(self): 105 | super(CrossEntropyLossForRobust, self).__init__() 106 | 107 | def forward(self, y, label): 108 | start_logits, end_logits = y # both shape are [batch_size, seq_len] 109 | start_position, end_position = label 110 | start_position = paddle.unsqueeze(start_position, axis=-1) 111 | end_position = paddle.unsqueeze(end_position, axis=-1) 112 | start_loss = paddle.nn.functional.softmax_with_cross_entropy( 113 | logits=start_logits, label=start_position, soft_label=False) 114 | start_loss = paddle.mean(start_loss) 115 | end_loss = paddle.nn.functional.softmax_with_cross_entropy( 116 | logits=end_logits, label=end_position, soft_label=False) 117 | end_loss = paddle.mean(end_loss) 118 | 119 | loss = (start_loss + end_loss) / 2 120 | return loss 121 | 122 | 123 | # 训练过程中的最大学习率 124 | learning_rate = 3e-5 125 | 126 | # 训练轮次 127 | epochs = 2 128 | 129 | # 学习率预热比例 130 | warmup_proportion = 0.1 131 | 132 | # 权重衰减系数,类似模型正则项策略,避免模型过拟合 133 | weight_decay = 0.01 134 | 135 | num_training_steps = len(train_data_loader) * epochs 136 | 137 | # 学习率衰减策略 138 | lr_scheduler = paddlenlp.transformers.LinearDecayWithWarmup(learning_rate, num_training_steps, warmup_proportion) 139 | 140 | decay_params = [ 141 | p.name for n, p in model.named_parameters() 142 | if not any(nd in n for nd in ["bias", "norm"]) 143 | ] 144 | optimizer = paddle.optimizer.AdamW( 145 | learning_rate=lr_scheduler, 146 | parameters=model.parameters(), 147 | weight_decay=weight_decay, 148 | apply_decay_param_fun=lambda x: x in decay_params) 149 | 150 | from utils import evaluate 151 | 152 | criterion = 
CrossEntropyLossForRobust() 153 | global_step = 0 154 | for epoch in range(1, epochs + 1): 155 | for step, batch in enumerate(train_data_loader, start=1): 156 | 157 | global_step += 1 158 | input_ids, segment_ids, start_positions, end_positions = batch 159 | logits = model(input_ids=input_ids, token_type_ids=segment_ids) 160 | loss = criterion(logits, (start_positions, end_positions)) 161 | 162 | if global_step % 100 == 0: 163 | print("global step %d, epoch: %d, batch: %d, loss: %.5f" % (global_step, epoch, step, loss)) 164 | 165 | loss.backward() 166 | optimizer.step() 167 | lr_scheduler.step() 168 | optimizer.clear_grad() 169 | 170 | evaluate(model=model, data_loader=dev_data_loader) 171 | -------------------------------------------------------------------------------- /深度学习实战2-(keras框架)企业信用评级与预测/enterprise_credit.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from keras.models import Sequential 3 | from keras.layers.core import Dense, Dropout, Activation 4 | import matplotlib .pyplot as plt 5 | import pandas as pd 6 | import csv 7 | 8 | data = pd.read_csv('train_new.csv',encoding = 'utf-8') 9 | 10 | # 提取数据特性x1,x2,x3,x4,作为训练集 11 | train = data[['x1', 'x2', 'x3', 'x4']] 12 | 13 | # 设置标签值 one-hot编码 14 | y_train = np.zeros((len(data), 5), dtype=np.int) 15 | for i in range(len(data)): 16 | y_train[i][data['class'][i]] = 1 17 | print(np.array(y_train)) 18 | 19 | model=Sequential() 20 | model.add(Dense(input_dim=4,units=666,activation='relu')) 21 | model.add(Dropout(0.5)) # Dropout(0.5) 表示随机丢弃50%的神经元,防止过拟合 22 | model.add(Dense(units=666,activation='relu')) 23 | model.add(Dropout(0.5)) 24 | model.add(Dense(units=666,activation='relu')) 25 | model.add(Dropout(0.5)) 26 | model.add(Dense(units=5,activation='softmax')) #输出层 输出5个等级结果 27 | 28 | model.compile(loss='mse',optimizer='adam',metrics=['acc']) 29 | history = model.fit(train,y_train,batch_size=123,epochs=500,validation_split=0.2) #训练500次 30 | 31 | weights = np.array(model.get_weights()) 32 | result2 = model.evaluate(train, y_train) 33 | 34 | # 绘制图形函数 35 | def show_train_history(history, train, validation): 36 | plt.plot(history.history[train]) 37 | plt.plot(history.history[validation]) 38 | plt.title('Train History') 39 | plt.ylabel(train) 40 | plt.xlabel('Epoch') 41 | plt.legend(['train', 'validation'], loc='upper left') 42 | plt.show() 43 | 44 | show_train_history(history,'acc','val_acc') 45 | 46 | show_train_history(history,'loss','val_loss') 47 | 48 | -------------------------------------------------------------------------------- /深度学习实战2-(keras框架)企业信用评级与预测/train_new.csv: -------------------------------------------------------------------------------- 1 | x1,x2,x3,x4,class 2 | 0.2457,0.0869,0.0393,0.1299,0 3 | 0.3208,0.0702,0.0202,0.5188,1 4 | 0.996,0.9047,0.0814,0.5806,1 5 | 0.0052,0.6423,0.0434,0.6413,1 6 | 0.0423,0.0869,0.0534,0.2042,0 7 | 0.2819,0.0482,0.0269,0.1622,0 8 | 0.1125,0.4617,0.0571,0.0824,0 9 | 0.1823,0.069,0.0235,0.1754,0 10 | 0.1994,0.0071,0.062,0.5745,1 11 | 0.0131,0.0911,0.0473,0.7423,3 12 | 0.0816,0.3036,0.0293,0.6389,1 13 | 0.0757,0.1208,0.0657,0.2987,0 14 | 0.0667,0.0869,0.0758,0.0823,4 15 | 0.1449,0.0127,0.0607,0.2766,0 16 | 0.0343,0.0869,0.0563,0.5554,0 17 | 0.002,0.2194,0.0688,0.322,0 18 | 0.0116,0.5628,0.0676,0.7386,1 19 | 0.0055,0.1466,0.054,0.4641,0 20 | 0.0279,0.0844,0.0902,0.6001,1 21 | 0.008,0.1005,0.0326,0.2532,0 22 | 0.0387,0.1011,0.0488,0.4903,0 23 | 0.0219,0.0753,0.0766,0.2867,4 24 | 0.0131,0.0911,0.0477,0.7433,3 25 | 
0.0131,0.0901,0.0473,0.7723,3 26 | 0.0131,0.0911,0.0493,0.7423,3 27 | 0.0131,0.0911,0.0472,0.7423,3 28 | 0.0121,0.0911,0.0473,0.7423,3 29 | 0.0131,0.0911,0.0483,0.7623,3 30 | 0.0131,0.0911,0.0473,0.7423,3 31 | 0.0131,0.0911,0.0463,0.7523,3 32 | 0.0131,0.0911,0.0477,0.7423,3 33 | 0.0343,0.0819,0.0563,0.5554,0 34 | 0.002,0.2194,0.0688,0.222,0 35 | 0.0055,0.1426,0.054,0.4641,0 36 | 0.0387,0.1021,0.0488,0.4503,0 37 | 0.0229,0.0723,0.0706,0.288,4 38 | 0.0199,0.0753,0.0736,0.2877,4 39 | 0.0219,0.0713,0.0736,0.2967,4 40 | 0.0229,0.0733,0.0714,0.2867,4 41 | 0.0229,0.0723,0.0756,0.2807,4 42 | 0.0259,0.0752,0.0738,0.2467,4 43 | 0.1994,0.0071,0.062,0.5745,1 44 | 0.0816,0.3036,0.0291,0.6289,1 45 | 0.0126,0.5628,0.0676,0.7386,1 46 | 0.0259,0.0834,0.0902,0.6301,1 47 | 0.0219,0.0753,0.0736,0.2967,4 48 | 0.0229,0.0733,0.0734,0.2867,4 49 | 0.0219,0.0723,0.0756,0.2807,4 50 | 0.0329,0.0652,0.0738,0.2467,4 51 | -------------------------------------------------------------------------------- /深度学习实战3-文本卷积神经网络(TextCNN)新闻文本分类/__pycache__/config.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenqiang0601/deep_learning/6381803a6189390d4388804c2ade0bc0c9ca8350/深度学习实战3-文本卷积神经网络(TextCNN)新闻文本分类/__pycache__/config.cpython-38.pyc -------------------------------------------------------------------------------- /深度学习实战3-文本卷积神经网络(TextCNN)新闻文本分类/config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | Created on 2020-07-19 00:20 5 | @Author : Justin Jiang 6 | @Email : jw_jiang@pku.edu.com 7 | 8 | 配置模型、路径、与训练相关参数 9 | """ 10 | 11 | class Config(object): 12 | def __init__(self): 13 | self.config_dict = { 14 | "data_path": { 15 | "vocab_path": "cnews.vocab.txt", 16 | "trainingSet_path": "cnews.train.txt", 17 | "valSet_path": "cnews.val.txt", 18 | "testingSet_path": "cnews.test.txt" 19 | }, 20 | "CNN_training_rule": { 21 | "embedding_dim": 64, 22 | "seq_length": 600, 23 | "num_classes": 3, 24 | 25 | "conv1_num_filters": 64, 26 | "conv1_kernel_size": 1, 27 | 28 | "conv2_num_filters": 64, 29 | "conv2_kernel_size": 1, 30 | 31 | "vocab_size": 5000, 32 | 33 | "hidden_dim": 128, 34 | 35 | "dropout_keep_prob": 0.5, 36 | "learning_rate": 1e-3, 37 | 38 | "batch_size": 10, 39 | "epochs": 5, 40 | 41 | "print_per_batch": 100, 42 | "save_per_batch": 1000 43 | }, 44 | "LSTM": { 45 | "seq_length": 600, 46 | "num_classes": 10, 47 | "vocab_size": 5000, 48 | "batch_size": 64 49 | }, 50 | "result": { 51 | "CNN_model_path": "../result/CNN_model.h5", 52 | "LSTM_model_path": "../result/LSTM_model.h5" 53 | } 54 | } 55 | 56 | def get(self, section, name): 57 | return self.config_dict[section][name] -------------------------------------------------------------------------------- /深度学习实战4-卷积神经网络(DenseNet)数学图形识别+题目模式识别/math_classif.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | # from keras import keras.layers 3 | import matplotlib.pyplot as plt 4 | from time import * 5 | # from tensorflow.keras.models import Sequential 6 | # from tensorflow.keras.layers import Dense,Dropout,MaxPooling2D,Flatten,Conv2D 7 | 8 | data_dir = "./data" 9 | 10 | batch_size = 12 11 | img_height = 224 12 | img_width = 224 13 | 14 | train_ds = tf.keras.preprocessing.image_dataset_from_directory( 15 | data_dir, 16 | validation_split=0.2, 17 | subset="training", 18 | seed=12, 19 | image_size=(img_height, img_width), 20 | batch_size=batch_size) 21 | 22 | 
val_ds = tf.keras.preprocessing.image_dataset_from_directory( 23 | data_dir, 24 | validation_split=0.2, 25 | subset="validation", 26 | seed=12, 27 | image_size=(img_height, img_width), 28 | batch_size=batch_size) 29 | 30 | model = tf.keras.applications.DenseNet121(weights='imagenet') 31 | model.summary() 32 | 33 | # 设置初始学习率 34 | initial_learning_rate = 1e-3 35 | 36 | lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay( 37 | initial_learning_rate, 38 | decay_steps=5, # 敲黑板!!!这里是指 steps,不是指epochs 39 | decay_rate=0.96, # lr经过一次衰减就会变成 decay_rate*lr 40 | staircase=True) 41 | 42 | # 将指数衰减学习率送入优化器 43 | optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule) 44 | model.compile(optimizer=optimizer, 45 | loss ='sparse_categorical_crossentropy', 46 | metrics =['accuracy']) 47 | 48 | def show_loss_acc(history): 49 | # 从history中提取模型训练集和验证集准确率信息和误差信息 50 | acc = history.history['accuracy'] 51 | val_acc = history.history['val_accuracy'] 52 | loss = history.history['loss'] 53 | val_loss = history.history['val_loss'] 54 | 55 | # 按照上下结构将图画输出 56 | plt.figure(figsize=(8, 8)) 57 | plt.subplot(2, 1, 1) 58 | plt.plot(acc, label='Training Accuracy') 59 | plt.plot(val_acc, label='Validation Accuracy') 60 | plt.legend(loc='lower right') 61 | plt.ylabel('Accuracy') 62 | plt.ylim([min(plt.ylim()), 1]) 63 | plt.title('Training and Validation Accuracy') 64 | 65 | plt.subplot(2, 1, 2) 66 | plt.plot(loss, label='Training Loss') 67 | plt.plot(val_loss, label='Validation Loss') 68 | plt.legend(loc='upper right') 69 | plt.ylabel('Cross Entropy') 70 | plt.title('Training and Validation Loss') 71 | plt.xlabel('epoch') 72 | plt.savefig('results/results_cnn.png', dpi=100) 73 | plt.show() 74 | 75 | def train(epochs): 76 | # 开始训练,记录开始时间 77 | begin_time = time() 78 | AUTOTUNE = tf.data.AUTOTUNE 79 | train_ds1 = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE) 80 | val_ds1 = val_ds.cache().prefetch(buffer_size=AUTOTUNE) 81 | # print(class_names) 82 | # 加载模型 83 | # model = model_load(class_num=len(class_names)) 84 | # 指明训练的轮数epoch,开始训练 85 | history = model.fit(train_ds1,validation_data=val_ds1,epochs=epochs) 86 | # todo 保存模型, 修改为你要保存的模型的名称 87 | model.save("cnn_fv.h5") 88 | # 记录结束时间 89 | end_time = time() 90 | run_time = end_time - begin_time 91 | print('该循环程序运行时间:', run_time, "s") # 该循环程序运行时间: 1.4201874732 92 | # 绘制模型训练过程图 93 | show_loss_acc(history) 94 | 95 | 96 | train(epochs=6) 97 | 98 | -------------------------------------------------------------------------------- /深度学习实战5-卷积神经网络(CNN)中文OCR识别项目/DroidSansFallback.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenqiang0601/deep_learning/6381803a6189390d4388804c2ade0bc0c9ca8350/深度学习实战5-卷积神经网络(CNN)中文OCR识别项目/DroidSansFallback.ttf -------------------------------------------------------------------------------- /深度学习实战5-卷积神经网络(CNN)中文OCR识别项目/chinese.txt: -------------------------------------------------------------------------------- 1 | 
0123456789QWERTYUIOPASDFGHJKLZXCVBNM“”《》的一是不人有了在你我个大中要这为上生时会以就子到来可能和自们年多发心好用家出关长他成天对也小后下学都点国过地行信方得最说二业分作如看女于面注别经动公开现而美么还事己理维没之情高法全很日体里工微者实力做等水加定果去所新活着让起市身间码品进孩前想道种识按同车本然月机性与那无手爱样因老内部每更意号电其重化当只文入产合些她三费通但感常明给主名保提将元话气从教相平物场量资知或外度金正次期问放头位安比真务男第解原制区消路及色网花把打吃系回此应友选什表商再万妈被并两题服少风食变容员交儿质建民价养房门需影请利管白简司代口受图处才特报城单西完使已目收十候山数展快强式精结东师求接至海片清各直带程世向先任记持格总运联计觉何太线又免热件权调专医乐效神击设钱健流由见台几增病投易南导功介证走今光朋即视造您立改母推眼复政买传认非基宝营院四习越包游转技条息血科难规众喜便创干界示广红住欢源指该观读享深油达告具取轻康型周装张五满店亲标查育配字类优始整据考案北它客火必购办社命味步护术阅吧素户往菜适边却失节料较形近级准皮衣书马超照值父怎试空切找华供米企助反望香足福且排阳统未治决确项除低根岁则百备像早领酒款防集环富财跟致瘦速择温销团离呢议论吗王州态思参许远责布编随细春克听减言招组景穿黄药肉售股首限检修验共约段笑洗况续底园帮引婚份历济险士错语村伤局票善校战际益职够晚极支存旅故含算送诉留角松积省仅江境称半星升象材预群获青终害肤属显卡餐银声站队落假县饭补研连德哪钟遇黑双待毒断充智演讲压农愿尽拉粉响死牌古货玩苦率千施蛋器楼痛究睡状订义绝石亮势音搭委斯居李紧坚脸独依丽严止疗右喝鸡牛林板某负京丰句评融军懂吸划念夫层降哦税豆彩官络胸拿画尔龙察班构秘否叫球幸座慢兴佛室啊均付模协互置般英净换短左版课茶策毛停河肥答良久承控激范章云普套另奖须例写灵担志顾草镇退希谢爸采六鱼围密庭脑奇八卖童土圈谁拥糖监甚怕贵顺鲜冷差梦警拍铁亿争夜背永街律饮继刻初突倒聘木熟婆列频虽刚妆举尚汽曾脚奶破静驾块蓝酸核锅艺绿博额陈坐靠巧掉飞盘币腿巴培若闻史亚纸症季叶乡丝询剧礼七址添织略虚迎摄余乎缺胃爆域妻练荐临佳府追患树颜诚伴湖贴午困似测肝归宁暖纳宜阿异卫录液私谈泡惊索盐漂损稳休折讯堂怀惠汤纪散藏湿透令冰妇麻醒宣抗典执秀肌训刘急赶播苏淡革阴批盖腰肠脱印硬促冲床努脏跑雅厅罗惯族姐犯罪赛趣骨烧哈避征劳载润炒软慧驶妹占租馆累签副键煮尊予缘港雨兰斤呼申障坏竟疑顶饰九炎歌审戏借误辆端沙掌恶疾露括固移脂武寒零烟毕雪登朝聚笔姓波涨救厂央咨党延耳危斑汉沉夏侧鞋牙媒腹龄励瓜敢忙宽箱释操输抱野癌守搞染姜默翻哥洁娘挑凉末潮违附杀宫迷杂弱岛础贫析乱乳辣弃桃轮浪赏抽镜盛胜玉烦植绍恋冒缓渐虑肯赚绩忘珍恩针猪既聊蜜握舞甜败汇抓刺骗杯啦灯赞寻仍陪涉椒荣哭欲词巨圆刷概沟幼尤偏斗胡启尼述弟屋田判触柔忍架吉肾狗欧遍甘瓶综曲威齐桥纯阶贷丁伙眠罚逐韩封扎厚著督冬舒杨惜汁庆迪洋洲旧映疼席暴漫辈射鼓葱侵羊倍挂束幅碗裤胖旺川搜航弹嘴派脾届托库唯奥菌君途讨券距粗诗授祛谓序账凡晓峰剂筑敏肚暗辑访岗腐痘摩烈扬谷纹遗偿穷帝尿腾禁竞豪苹跳挥抢卷胆递珠敬甲乘孕绪纷隐滑浓膜姑探宗姻诺摆狂篇睛闲勇蒜尾旦庄窗扫辛陆塑幕聪详污圳扮肿楚忆匀炼耐衡措铺薪泰懒贝磨怨鼻圣孙眉泉洞焦毫戴旁符泪邮爷钢混厨抵灰献扣怪碎擦胎缩扶恐欣顿伟丈皇蒙胞尝寿攻曲威齐桥纯阶贷丁伙眠罚逐韩封扎厚著督冬舒杨惜汁庆迪洋洲旧映疼席暴漫辈射鼓葱侵羊倍挂束幅碗裤胖旺川搜航弹嘴派脾届托库唯奥菌君途讨券距粗诗授祛谓序账凡晓峰剂筑敏肚暗辑访岗腐痘摩烈扬谷纹遗偿穷帝尿腾禁竞豪苹跳挥抢卷胆递珠敬甲乘孕绪纷隐滑浓膜姑探宗姻诺摆狂篇睛闲勇蒜尾旦庄窗扫辛陆塑幕聪详污圳扮肿楚忆匀炼耐衡措铺薪泰懒贝磨怨鼻圣孙眉泉洞焦毫戴旁符泪邮爷钢混厨抵灰献扣怪碎擦胎缩扶恐欣顿伟丈皇蒙胞尝寿攻仁津潜滴晨颗舍秒刀酱悲妙隔桌册迹仔闭奋袋墙嫌萝唐跌尖莫拌赔忽宿扩胶雷燕衰挺宋湾脉凭丹繁拒肺涂郁剩仪紫滋泽薄森唱残虎档猫麦劲偶秋疯俗悉弄船雄兵晒扰蒸悟肪览籍丑拼诊吴循偷灭伸赢魅勤旗亡乏估仁津潜滴晨颗舍秒刀酱悲妙隔桌册迹仔闭奋袋墙嫌萝唐跌尖莫拌赔忽宿扩胶雷燕衰挺宋湾脉凭丹繁拒肺涂郁剩仪紫滋泽薄森唱残虎档猫麦劲偶秋疯俗悉弄船雄兵晒扰蒸悟肪览籍丑拼诊吴循偷灭伸赢魅勤旗亡乏估替吐碰淘彻逼氧梅遭孔稿嘉卜赵姿储呈乌娱闹裙倾震忧貌萨塞鬼池沿畅盟仙醋炸粥咖瑜返稍灾肩殊逃荷描朱朵横徐杰陷迟莱纠榜债烂伽拟匙圾巾恼誉垃颈壁链糊悦屏浮魔毁拜宾迅芝燃迫疫柜烤塔赠伪阻绕饱辅醉抑撒粘丢卧徒奔锁董枣截番蔬摇亦趋冠呀疲婴诸贸泥伦嫁祖朗琴拔孤摸壮帅阵梁宅啥伊鲁怒熊艾裁犹撑莲吹纤昨谱咳蜂闪嫩瞬霸兼恨昌踏瑞樱萌厕郑寺愈傻慈汗奉缴暂爽堆浙忌慎坦撞耗粒仿诱凤矿锻馨尘兄杭虫熬赖恰恒鸟猛唇幻窍浸诀填亏覆盆彼腺胀苗竹魂吵刑惑岸傲垂呵荒页抹揭贪宇泛劣臭呆梯啡径咱筹娃鉴禅召艳澳恢践迁废燥裂兔溪梨饼勺碍穴坛诈宏井仓删挣柳戒腔涵寸弯隆插祝氏泌盒邀煤膏棒跨拖葡骂喷肖洛盈浅逆夹贤晶厌侠欺敌钙冻盾桂仰滚萄厦牵疏齿挡孝滨吨渠囊慕捷淋桶脆沫辉耀渴邪轨悔猎煎沈虾醇贯衫荡谋携晋糕玻肃杜皆秦盗臂舌杆俱棉挤拨剪阔稀腻骑玛忠伯伍狠宠勒浴勿媳晕佩屈纵奈抬栏菲坑茄雾坡幽跃坊枝凝拳谨筋菇锋璃郭钻酷愁摘捐谐遵苍飘搅漏泄祥锦衬矛猜凌挖喊猴芳曼痕鼠允叔牢绘嘛吓振墨烫厉昆拓卵凯淀皱枪尺疆姆笋粮邻菩署柠遮艰芽爬夸捞叹缝妨奏岩寄吊狮剑驻洪夺募凶辨崇莓斜檬悬瘤欠刊曝傅悠椅戳棋慰丧拆绵炉徽驱曹履俄兑闷赋狼愉纽膝饿窝辞躺瓦逢堪薯哟袭壳咽岭槽雕昂闺御旋寨抛祸殖喂俩贡狐弥遥桑搬陌陶乃寂滩妥辰堵蛇侣邦蝙陵洒浆蹲惧霜丸娜扔肢姨援炫岳迈躁蝠埋泻巡溶氛械翠陕乔漠滞哲浩驰摊糟铜赤谅蕉昏劝杞扭骤杏娇渡抚羡娶串碧叉廉膀柱垫伏痒捕咸瓣庙敷卑碑沸鸭纱赴盲辽疮浦逛愤黎耍咬仲枸催喉瑰勾臀泼椎翼奢郎杠碳谎悄瓷绑酬菠朴割砖惨埃霍耶仇嗽塘邓漆蹈鹰披橘薇溃拾押炖霉痰袖巢帽宴卓踪屁刮晰豫玫驭羞讼茫厘扑亩鸣罐澡惩刹啤揉纲腥沾陀蟹枕逸橙梳浑毅吕泳碱缠柿砂羽黏芹馈柴侦卢憾疹贺牧俊峡硕倡蓄赌吞躲旨崩寞碌堡唤韭趁惹衷蛮译彭掩攀慌牲栋鼎鹅弘敲诞撕卦腌葛舟寓氨弗伞罩芒沃棚契巷磁浇逻廊棵溢箭匹矩颇爹玲绒雀鸿贩锐曰蕾竭剥沪畏渣歉摔旬颖茂擅铃淮叠挫逗晴柏舰翁框涌琳罢辩勃霞裹烹庸臣莉匆熙轩骄煲掘搓乙痴恭韵渗薏炭痣锡丨脊夕丘苑蔡裸灌庞龟窄兽敦辟牺僧怜湘篮妖喘瘾蓬挽芦谦踩辱辖捧坠滤炮撩狱亭虹吻煌谊枯脐螺扇抖戚怖帐盼冯劫墓崔酵殿蝶袁袜枚芯绳颠耕壶叨乖呕筷捡鹿潘笨扁渔株斥砸涩倦沥丛翔吼裕翰蒂尸莴暑肴凰馅阁誓匠侯韧钥哒狸媚壤驴逝渍嘲颁谜翅笼冈蓉脖甩扯宙叛帖萧芬潭涛闯泊宰梗鑫祭嚼卸尬尴怡咒晾嚣哄掏哀盯腊灿涯钞轰髦斌茅骚咋茨蝇枢捣顽彰拘坎役砍皂汪孟筱愚滥妒塌轿窃喻胁钓墅糙浏愧赫捏妮溜谣膳郊睫沧撤搏汰鹏菊帘秤衔捉鹤贿廷撼钾绽轴凸魄晃磷蒋栽荆蠢魏蜡缸筒遂茎芭伐邵瞎帕凑唠祈赁秩辫玄酶潇稻兜婷栓屡削钉拭蕴糯煞坪兹妃兆沂纺酿柚瀑稠腕勉疡贱冀跪凹辜铭赐绎灶弛嫉姚慨褐翘饶焯蒲哎僵隙犬剖昧湛矮舆吾剔甄逾虐粹牡莎罕蠕拐琪瑟霾辐帆拇榨冤绣痔筛雇祷歪贼肛垢抄饺琐裔黛睁捂萎酥饥衍靓榄嗯肆咯槛寡诵贬瞧乞贾弓珊眸屑熏籽乾聆狭韦锈毯蹄涤磊赘歇坝豹橄葬竖奴磅蝴淑柯敛侈叙惫俞翡叮蜀逊葩拯咪喔灸橱函厢瑶橡俯沛嘱佣陇莞妄榆淫靖俏敞嫂烘腑崖扒洽宵膨亨妞硫剁秉淤婉稣筝屌挨儒哑铅斩阱钩睿彬啪琼桩萍蔓焖踢铝仗荠棍棕铸榴惕巩杉芋攒髓拦蝎飙栗畜挪冥藤坤嘿磕椰憋荟坞屯饲懈梭夷嘘沐蔗蚕粕吁卉昭饪钮恳睦讶穆拣傍岂蘸噪戈靴瑕龈讽泣浊哇趾蔽丫歧蚊暨钠芪艇暮擎畔禽拧惟俭蔚恤蚀尹侍馒锌骼咏堕渊桐窒焕阀藕耻躯薛菱谭豁昕喧藉丙鸦驼拢奸爪睹绸暧佐颊澜禄缀煸趟揽蘑瘀阜拎屎颤邑胰肇哺噢矫讳雌怠楂苛暇酪佑妍婿耿妊萃灼丶澄撰弊挚庐雯靡牟硝酮醛苓紊肘趴廓卤昔鄂哮赣汕貅渝媛貔彦荫觅蹭巅岚甸漓迦邂稚濮陋逅窑笈弧颐禾瘙脓刃愣拴旭蚁滔仕荔琢澈睐隶粤盏遣汾镁硅枫淹仆胺娠舅弦殷惰麟苔芙堤旱蛙驳羯涕侨铲糜烯扛腮猿烛昵韶莹洱诠襄棠鸽仑峻啃瞒喇绊胱咙踝褶娩鲍掀漱绅奠芡蜗疤兮矣熔俺掰拱骏贞姥哼倘栖屉眷渭幢芜溺茯袍淳沦绞倪缚碟雁孵粪崛舱褪诡悍芸宪壹诫窟葵呐锤摧碾鞭嗓呱芥 -------------------------------------------------------------------------------- /深度学习实战6-卷积神经网络(Pytorch)+聚类分析实现空气质量与天气预测/net.pkl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/shenqiang0601/deep_learning/6381803a6189390d4388804c2ade0bc0c9ca8350/深度学习实战6-卷积神经网络(Pytorch)+聚类分析实现空气质量与天气预测/net.pkl -------------------------------------------------------------------------------- /深度学习实战6-卷积神经网络(Pytorch)+聚类分析实现空气质量与天气预测/weather.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.utils.data as Data 4 | import numpy as np 5 | import pymysql 6 | import datetime 7 | import csv 8 | import time 9 | import matplotlib.pyplot as plt 10 | import pandas as pd 11 | from sklearn.cluster import KMeans 12 | from sklearn.decomposition import PCA 13 | import os 14 | 15 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' 16 | data =pd.read_csv("weather.csv",encoding='gb18030') 17 | data = data.drop(columns=data.columns[-1]) 18 | print(data) 19 | 20 | pca = PCA(n_components=2) 21 | new_pca = pd.DataFrame(pca.fit_transform(data)) 22 | X = new_pca.values 23 | print(new_pca) 24 | 25 | kms = KMeans(n_clusters=6) # 6表示聚类的个数 26 | #获取类别标签 27 | Y= kms.fit_predict(data) 28 | data['class'] = Y 29 | data.to_csv("weather_new.csv",index=False) #保存文件 30 | 31 | #绘制聚类发布图 32 | d = new_pca[Y == 0] 33 | plt.plot(d[0], d[1], 'r.') 34 | d = new_pca[Y == 1] 35 | plt.plot(d[0], d[1], 'g.') 36 | d = new_pca[Y == 2] 37 | plt.plot(d[0], d[1], 'b.') 38 | d = new_pca[Y == 3] 39 | plt.plot(d[0], d[1], 'y.') 40 | d = new_pca[Y == 4] 41 | plt.plot(d[0], d[1],'c.') 42 | d = new_pca[Y == 5] 43 | plt.plot(d[0], d[1],'k.') 44 | #plt.show() 45 | 46 | 47 | class MyNet(nn.Module): 48 | def __init__(self): 49 | super(MyNet, self).__init__() 50 | self.con1 = nn.Sequential( 51 | nn.Conv1d(in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=1), 52 | nn.MaxPool1d(kernel_size=1), 53 | nn.ReLU(), 54 | ) 55 | self.con2 = nn.Sequential( 56 | nn.Conv1d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1), 57 | nn.MaxPool1d(kernel_size=1), 58 | nn.ReLU(), 59 | ) 60 | self.fc = nn.Sequential( 61 | # 线性分类器 62 | nn.Linear(128 * 6 * 1, 128), 63 | nn.ReLU(), 64 | nn.Linear(128, 6), 65 | # nn.Softmax(dim=1), 66 | ) 67 | self.mls = nn.MSELoss() 68 | self.opt = torch.optim.Adam(params=self.parameters(), lr=1e-3) 69 | self.start = datetime.datetime.now() 70 | 71 | def forward(self, inputs): 72 | out = self.con1(inputs) 73 | out = self.con2(out) 74 | out = out.view(out.size(0), -1) # 展开成一维 75 | out = self.fc(out) 76 | return out 77 | 78 | def train(self, x, y): 79 | out = self.forward(x) 80 | loss = self.mls(out, y) 81 | self.opt.zero_grad() 82 | loss.backward() 83 | self.opt.step() 84 | 85 | return loss 86 | 87 | def test(self, x): 88 | out = self.forward(x) 89 | return out 90 | 91 | def get_data(self): 92 | with open('weather_new.csv', 'r') as f: 93 | results = csv.reader(f) 94 | results = [row for row in results] 95 | results = results[1:1500] 96 | inputs = [] 97 | labels = [] 98 | for result in results: 99 | # one-hot独热编码 100 | one_hot = [0 for i in range(6)] 101 | index = int(result[6]) - 1 102 | one_hot[index] = 1 103 | labels.append(one_hot) 104 | input = result[:6] 105 | input = [float(x) for x in input] 106 | 107 | inputs.append(input) 108 | 109 | inputs = np.array(inputs) 110 | labels = np.array(labels) 111 | inputs = torch.from_numpy(inputs).float() 112 | inputs = torch.unsqueeze(inputs, 1) 113 | 114 | labels = torch.from_numpy(labels).float() 115 | return inputs, labels 116 | 117 | def get_test_data(self): 118 | with open('weather_new.csv', 'r') as f: 119 | results = csv.reader(f) 120 | 
results = [row for row in results] 121 | results = results[1500: 1817] 122 | inputs = [] 123 | labels = [] 124 | for result in results: 125 | label = [result[6]] 126 | input = result[:6] 127 | input = [float(x) for x in input] 128 | label = [float(y) for y in label] 129 | inputs.append(input) 130 | labels.append(label) 131 | inputs = np.array(inputs) 132 | 133 | inputs = torch.from_numpy(inputs).float() 134 | inputs = torch.unsqueeze(inputs, 1) 135 | labels = np.array(labels) 136 | labels = torch.from_numpy(labels).float() 137 | return inputs, labels 138 | 139 | if __name__ == '__main__': 140 | EPOCH = 100 141 | BATCH_SIZE = 50 142 | 143 | net = MyNet() 144 | x_data, y_data = net.get_data() 145 | torch_dataset = Data.TensorDataset(x_data, y_data) 146 | loader = Data.DataLoader( 147 | dataset=torch_dataset, 148 | batch_size=BATCH_SIZE, 149 | shuffle=True, 150 | num_workers=2, 151 | ) 152 | for epoch in range(EPOCH): 153 | for step, (batch_x, batch_y) in enumerate(loader): 154 | # print(step) 155 | # print(step,'batch_x={}; batch_y={}'.format(batch_x, batch_y)) 156 | a = net.train(batch_x, batch_y) 157 | print('step:',step,a) 158 | # 保存模型 159 | torch.save(net, 'net.pkl') 160 | 161 | # 加载模型 162 | net = torch.load('net.pkl') 163 | x_data, y_data = net.get_test_data() 164 | torch_dataset = Data.TensorDataset(x_data, y_data) 165 | loader = Data.DataLoader( 166 | dataset=torch_dataset, 167 | batch_size=100, 168 | shuffle=False, 169 | num_workers=1, 170 | ) 171 | num_success = 0 172 | num_sum = 317 173 | for step, (batch_x, batch_y) in enumerate(loader): 174 | # print(step) 175 | output = net.test(batch_x) 176 | # output = output.detach().numpy() 177 | y = batch_y.detach().numpy() 178 | for index, i in enumerate(output): 179 | i = i.detach().numpy() 180 | i = i.tolist() 181 | j = i.index(max(i)) 182 | print('输出为{}标签为{}'.format(j+1, y[index][0])) 183 | loss = j+1-y[index][0] 184 | if loss == 0.0: 185 | num_success += 1 186 | print('正确率为{}'.format(num_success/num_sum)) 187 | -------------------------------------------------------------------------------- /深度学习实战7-电商产品评论的情感分析/Sentiment.py: -------------------------------------------------------------------------------- 1 | import data_loader 2 | from tensorflow.keras.preprocessing import sequence 3 | from tensorflow.keras.models import Sequential 4 | from tensorflow.keras.layers import Dense, Embedding 5 | from tensorflow.keras.layers import Flatten 6 | from tensorflow.keras.utils import to_categorical 7 | import numpy as np 8 | 9 | x_train,y_train,x_test,y_test =data_loader.load_data() 10 | 11 | #创建评论数据的词库索引 12 | vocalen,word_index = data_loader.createWordIndex(x_train,x_test) 13 | #print(vocalen) 14 | 15 | #获取训练数据每个词的索引 16 | x_train_index =data_loader.word2Index(x_train,word_index) 17 | x_test_index=data_loader.word2Index(x_test,word_index) 18 | 19 | #最大长度的限制 20 | maxlen =25 21 | x_train_index =sequence.pad_sequences(x_train_index,maxlen=maxlen ) 22 | x_test_index =sequence.pad_sequences(x_test_index,maxlen=maxlen) 23 | y_train= to_categorical(y_train) 24 | y_test= to_categorical(y_test) 25 | 26 | model =Sequential() 27 | model.add(Embedding(trainable=False, input_dim= vocalen+1, output_dim=300, input_length=maxlen)) 28 | model.add(Flatten()) 29 | model.add(Dense(256, activation='relu')) 30 | model.add(Dense(256, activation= 'relu')) 31 | model.add(Dense(256, activation='relu')) 32 | model.add(Dense(2, activation= 'sigmoid')) 33 | model.compile(loss='binary_crossentropy', optimizer='adam',metrics=['accuracy']) #二分类问题 34 | 35 | print(x_train_index, 
y_train) 36 | model.fit(x_train_index, y_train,batch_size=512, epochs=200) 37 | score, acc = model.evaluate(x_test_index, y_test) 38 | print('Test score:', score) 39 | print('test accuracy:',acc) 40 | 41 | test = np.array([x_test_index[1000]]) 42 | print(test) 43 | print(test.shape) 44 | 45 | predict = model.predict(test) 46 | print(predict) 47 | print(np.argmax(predict,axis=1)) 48 | 49 | -------------------------------------------------------------------------------- /深度学习实战7-电商产品评论的情感分析/__pycache__/data_loader.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenqiang0601/deep_learning/6381803a6189390d4388804c2ade0bc0c9ca8350/深度学习实战7-电商产品评论的情感分析/__pycache__/data_loader.cpython-38.pyc -------------------------------------------------------------------------------- /深度学习实战7-电商产品评论的情感分析/data_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import keras 3 | import numpy as np 4 | import keras.preprocessing.text as text 5 | import re 6 | import jieba 7 | import random 8 | 9 | 10 | def load_data(): 11 | xs = [] 12 | ys = [] 13 | with open('online_shopping_10_cats.csv', 'r', encoding='utf-8') as f: 14 | line = f.readline() # escape first line"label review" 15 | while line: 16 | line = f.readline() 17 | if not line: 18 | break 19 | contents = line.split(',') 20 | 21 | # if contents[0]=="书籍": 22 | # continue 23 | 24 | label = int(contents[1]) 25 | review = contents[2] 26 | if len(review) > 20: 27 | continue 28 | 29 | xs.append(review) 30 | ys.append(label) 31 | 32 | xs = np.array(xs) 33 | ys = np.array(ys) 34 | 35 | # 打乱数据集 36 | indies = [i for i in range(len(xs))] 37 | random.seed(666) 38 | random.shuffle(indies) 39 | xs = xs[indies] 40 | ys = ys[indies] 41 | 42 | m = len(xs) 43 | cutpoint = int(m * 3 / 5) 44 | x_train = xs[:cutpoint] 45 | y_train = ys[:cutpoint] 46 | 47 | x_test = xs[cutpoint:] 48 | y_test = ys[cutpoint:] 49 | 50 | print(x_train) 51 | print(y_train) 52 | 53 | print('总样本数量:%d' % (len(xs))) 54 | print('训练集数量:%d' % (len(x_train))) 55 | print('测试集数量:%d' % (len(x_test))) 56 | 57 | return x_train, y_train, x_test, y_test 58 | 59 | 60 | load_data() 61 | 62 | 63 | def createWordIndex(x_train, x_test): 64 | x_all = np.concatenate((x_train, x_test), axis=0) 65 | # 建立词索引 66 | tokenizer = text.Tokenizer() 67 | # create word index 68 | word_dic = {} 69 | voca = [] 70 | for sentence in x_all: 71 | # 去掉标点 72 | sentence = re.sub("[\s+\.\!\/_,$%^*(+\"\']+|[+——!,。?、~@#¥%……&*()]+", "", sentence) 73 | # 结巴分词 74 | cut = jieba.cut(sentence) 75 | # cut_list = [ i for i in cut ] 76 | 77 | for word in cut: 78 | if not (word in word_dic): 79 | word_dic[word] = 0 80 | else: 81 | word_dic[word] += 1 82 | voca.append(word) 83 | word_dic = sorted(word_dic.items(), key=lambda kv: kv[1], reverse=True) 84 | 85 | voca = [v[0] for v in word_dic] 86 | 87 | tokenizer.fit_on_texts(voca) 88 | print("voca:" + str(len(voca))) 89 | return len(voca), tokenizer.word_index 90 | 91 | 92 | def word2Index(words, word_index): 93 | vecs = [] 94 | for sentence in words: 95 | # 去掉标点 96 | sentence = re.sub("[\s+\.\!\/_,$%^*(+\"\']+|[+——!,。?、~@#¥%……&*()]+", "", sentence) 97 | # 结巴分词 98 | cut = jieba.cut(sentence) 99 | # cut_list = [ i for i in cut ] 100 | index = [] 101 | 102 | for word in cut: 103 | if word in word_index: 104 | index.append(float(word_index[word])) 105 | 106 | # if len(index)>25: 107 | # index = index[0:25] 108 | vecs.append(np.array(index)) 109 | 110 | return 
np.array(vecs) 111 | -------------------------------------------------------------------------------- /深度学习实战8-生活照片转化漫画照片应用/img2cartoon.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | from modelscope.outputs import OutputKeys 3 | from modelscope.pipelines import pipeline 4 | from modelscope.utils.constant import Tasks 5 | from PIL import Image 6 | 7 | # 目前支持单人的漫画生成 8 | img_cartoon = pipeline(Tasks.image_portrait_stylization, 9 | model='damo/cv_unet_person-image-cartoon_compound-models', 10 | device='cpu') 11 | # 图像本地路径 12 | img_path = 'new_image.png' 13 | 14 | # img = Image.open(img_path) # 调整图片大小为800x640 15 | # new_size = (480, 270) 16 | # img = img.resize(new_size) # 保存修改后的图片 17 | # img.save("new_image.png") 18 | 19 | result = img_cartoon(img_path) 20 | cv2.imwrite('result21.png', result[OutputKeys.OUTPUT_IMG]) 21 | print('完成!') -------------------------------------------------------------------------------- /深度学习实战8-生活照片转化漫画照片应用/input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenqiang0601/deep_learning/6381803a6189390d4388804c2ade0bc0c9ca8350/深度学习实战8-生活照片转化漫画照片应用/input.png -------------------------------------------------------------------------------- /深度学习实战8-生活照片转化漫画照片应用/new_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenqiang0601/deep_learning/6381803a6189390d4388804c2ade0bc0c9ca8350/深度学习实战8-生活照片转化漫画照片应用/new_image.png -------------------------------------------------------------------------------- /深度学习实战8-生活照片转化漫画照片应用/result0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenqiang0601/deep_learning/6381803a6189390d4388804c2ade0bc0c9ca8350/深度学习实战8-生活照片转化漫画照片应用/result0.png -------------------------------------------------------------------------------- /深度学习实战8-生活照片转化漫画照片应用/result2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenqiang0601/deep_learning/6381803a6189390d4388804c2ade0bc0c9ca8350/深度学习实战8-生活照片转化漫画照片应用/result2.png -------------------------------------------------------------------------------- /深度学习实战8-生活照片转化漫画照片应用/result21.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenqiang0601/deep_learning/6381803a6189390d4388804c2ade0bc0c9ca8350/深度学习实战8-生活照片转化漫画照片应用/result21.png -------------------------------------------------------------------------------- /深度学习实战9-文本生成图像-本地电脑实现text2img/stable_diffusion_tf/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenqiang0601/deep_learning/6381803a6189390d4388804c2ade0bc0c9ca8350/深度学习实战9-文本生成图像-本地电脑实现text2img/stable_diffusion_tf/__init__.py -------------------------------------------------------------------------------- /深度学习实战9-文本生成图像-本地电脑实现text2img/stable_diffusion_tf/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenqiang0601/deep_learning/6381803a6189390d4388804c2ade0bc0c9ca8350/深度学习实战9-文本生成图像-本地电脑实现text2img/stable_diffusion_tf/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- 
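
The entries that follow belong to the `stable_diffusion_tf` package of 深度学习实战9 (text2img): the `__pycache__/*.pyc` files are just compiled caches of the `.py` modules dumped below (autoencoder_kl.py, clip_encoder.py, the CLIP tokenizer, layers.py). The package is normally driven through the `text2img.py` script further down, but it can also be called directly for quick experiments. The sketch below is not additional documentation; it simply mirrors the constructor and `generate()` call already used in `text2img.py`, and it assumes the pretrained weights are fetched or loaded inside `StableDiffusion` itself (stable_diffusion.py is not shown in this dump).

```python
# Minimal sketch: calling stable_diffusion_tf directly, mirroring the calls in text2img.py below.
from PIL import Image
from stable_diffusion_tf.stable_diffusion import StableDiffusion

# 256x512 output, no XLA compilation (the same settings text2img.py passes)
generator = StableDiffusion(img_height=256, img_width=512, jit_compile=False)

img = generator.generate(
    "Romantic lavender and sunset",    # prompt (text2img.py's default)
    num_steps=50,                      # DDIM sampling steps
    unconditional_guidance_scale=7.5,  # classifier-free guidance scale
    temperature=1,
    batch_size=1,
    seed=42,                           # fixed seed for reproducible output
)
Image.fromarray(img[0]).save("output.png")
```

Running it needs the same dependencies the dumped modules import themselves: TensorFlow, tensorflow_addons, ftfy, regex and Pillow.
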
/深度学习实战9-文本生成图像-本地电脑实现text2img/stable_diffusion_tf/__pycache__/autoencoder_kl.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenqiang0601/deep_learning/6381803a6189390d4388804c2ade0bc0c9ca8350/深度学习实战9-文本生成图像-本地电脑实现text2img/stable_diffusion_tf/__pycache__/autoencoder_kl.cpython-38.pyc -------------------------------------------------------------------------------- /深度学习实战9-文本生成图像-本地电脑实现text2img/stable_diffusion_tf/__pycache__/clip_encoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenqiang0601/deep_learning/6381803a6189390d4388804c2ade0bc0c9ca8350/深度学习实战9-文本生成图像-本地电脑实现text2img/stable_diffusion_tf/__pycache__/clip_encoder.cpython-38.pyc -------------------------------------------------------------------------------- /深度学习实战9-文本生成图像-本地电脑实现text2img/stable_diffusion_tf/__pycache__/constants.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenqiang0601/deep_learning/6381803a6189390d4388804c2ade0bc0c9ca8350/深度学习实战9-文本生成图像-本地电脑实现text2img/stable_diffusion_tf/__pycache__/constants.cpython-38.pyc -------------------------------------------------------------------------------- /深度学习实战9-文本生成图像-本地电脑实现text2img/stable_diffusion_tf/__pycache__/diffusion_model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenqiang0601/deep_learning/6381803a6189390d4388804c2ade0bc0c9ca8350/深度学习实战9-文本生成图像-本地电脑实现text2img/stable_diffusion_tf/__pycache__/diffusion_model.cpython-38.pyc -------------------------------------------------------------------------------- /深度学习实战9-文本生成图像-本地电脑实现text2img/stable_diffusion_tf/__pycache__/layers.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenqiang0601/deep_learning/6381803a6189390d4388804c2ade0bc0c9ca8350/深度学习实战9-文本生成图像-本地电脑实现text2img/stable_diffusion_tf/__pycache__/layers.cpython-38.pyc -------------------------------------------------------------------------------- /深度学习实战9-文本生成图像-本地电脑实现text2img/stable_diffusion_tf/__pycache__/stable_diffusion.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenqiang0601/deep_learning/6381803a6189390d4388804c2ade0bc0c9ca8350/深度学习实战9-文本生成图像-本地电脑实现text2img/stable_diffusion_tf/__pycache__/stable_diffusion.cpython-38.pyc -------------------------------------------------------------------------------- /深度学习实战9-文本生成图像-本地电脑实现text2img/stable_diffusion_tf/autoencoder_kl.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow import keras 3 | import tensorflow_addons as tfa 4 | 5 | from .layers import apply_seq, PaddedConv2D 6 | 7 | 8 | class AttentionBlock(keras.layers.Layer): 9 | def __init__(self, channels): 10 | super().__init__() 11 | self.norm = tfa.layers.GroupNormalization(epsilon=1e-5) 12 | self.q = PaddedConv2D(channels, 1) 13 | self.k = PaddedConv2D(channels, 1) 14 | self.v = PaddedConv2D(channels, 1) 15 | self.proj_out = PaddedConv2D(channels, 1) 16 | 17 | def call(self, x): 18 | h_ = self.norm(x) 19 | q, k, v = self.q(h_), self.k(h_), self.v(h_) 20 | 21 | # Compute attention 22 | b, h, w, c = q.shape 23 | q = tf.reshape(q, (-1, h * w, c)) # b,hw,c 24 | k = 
keras.layers.Permute((3, 1, 2))(k) 25 | k = tf.reshape(k, (-1, c, h * w)) # b,c,hw 26 | w_ = q @ k 27 | w_ = w_ * (c ** (-0.5)) 28 | w_ = keras.activations.softmax(w_) 29 | 30 | # Attend to values 31 | v = keras.layers.Permute((3, 1, 2))(v) 32 | v = tf.reshape(v, (-1, c, h * w)) 33 | w_ = keras.layers.Permute((2, 1))(w_) 34 | h_ = v @ w_ 35 | h_ = keras.layers.Permute((2, 1))(h_) 36 | h_ = tf.reshape(h_, (-1, h, w, c)) 37 | return x + self.proj_out(h_) 38 | 39 | 40 | class ResnetBlock(keras.layers.Layer): 41 | def __init__(self, in_channels, out_channels): 42 | super().__init__() 43 | self.norm1 = tfa.layers.GroupNormalization(epsilon=1e-5) 44 | self.conv1 = PaddedConv2D(out_channels, 3, padding=1) 45 | self.norm2 = tfa.layers.GroupNormalization(epsilon=1e-5) 46 | self.conv2 = PaddedConv2D(out_channels, 3, padding=1) 47 | self.nin_shortcut = ( 48 | PaddedConv2D(out_channels, 1) 49 | if in_channels != out_channels 50 | else lambda x: x 51 | ) 52 | 53 | def call(self, x): 54 | h = self.conv1(keras.activations.swish(self.norm1(x))) 55 | h = self.conv2(keras.activations.swish(self.norm2(h))) 56 | return self.nin_shortcut(x) + h 57 | 58 | 59 | class Decoder(keras.Sequential): 60 | def __init__(self): 61 | super().__init__( 62 | [ 63 | keras.layers.Lambda(lambda x: 1 / 0.18215 * x), 64 | PaddedConv2D(4, 1), 65 | PaddedConv2D(512, 3, padding=1), 66 | ResnetBlock(512, 512), 67 | AttentionBlock(512), 68 | ResnetBlock(512, 512), 69 | ResnetBlock(512, 512), 70 | ResnetBlock(512, 512), 71 | ResnetBlock(512, 512), 72 | keras.layers.UpSampling2D(size=(2, 2)), 73 | PaddedConv2D(512, 3, padding=1), 74 | ResnetBlock(512, 512), 75 | ResnetBlock(512, 512), 76 | ResnetBlock(512, 512), 77 | keras.layers.UpSampling2D(size=(2, 2)), 78 | PaddedConv2D(512, 3, padding=1), 79 | ResnetBlock(512, 256), 80 | ResnetBlock(256, 256), 81 | ResnetBlock(256, 256), 82 | keras.layers.UpSampling2D(size=(2, 2)), 83 | PaddedConv2D(256, 3, padding=1), 84 | ResnetBlock(256, 128), 85 | ResnetBlock(128, 128), 86 | ResnetBlock(128, 128), 87 | tfa.layers.GroupNormalization(epsilon=1e-5), 88 | keras.layers.Activation("swish"), 89 | PaddedConv2D(3, 3, padding=1), 90 | ] 91 | ) 92 | 93 | 94 | class Encoder(keras.Sequential): 95 | def __init__(self): 96 | super().__init__( 97 | [ 98 | PaddedConv2D(128, 3, padding=1 ), 99 | ResnetBlock(128,128), 100 | ResnetBlock(128, 128), 101 | PaddedConv2D(128 , 3 , padding=(0,1), stride=2), 102 | 103 | ResnetBlock(128,256), 104 | ResnetBlock(256, 256), 105 | PaddedConv2D(256 , 3 , padding=(0,1), stride=2), 106 | 107 | ResnetBlock(256,512), 108 | ResnetBlock(512, 512), 109 | PaddedConv2D(512 , 3 , padding=(0,1), stride=2), 110 | 111 | ResnetBlock(512,512), 112 | ResnetBlock(512, 512), 113 | 114 | ResnetBlock(512, 512), 115 | AttentionBlock(512), 116 | ResnetBlock(512, 512), 117 | 118 | tfa.layers.GroupNormalization(epsilon=1e-5) , 119 | keras.layers.Activation("swish"), 120 | PaddedConv2D(8, 3, padding=1 ), 121 | PaddedConv2D(8, 1 ), 122 | keras.layers.Lambda(lambda x : x[... 
, :4] * 0.18215) 123 | ] 124 | ) 125 | 126 | -------------------------------------------------------------------------------- /深度学习实战9-文本生成图像-本地电脑实现text2img/stable_diffusion_tf/clip_encoder.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow import keras 3 | import tensorflow_addons as tfa 4 | import numpy as np 5 | 6 | from .layers import quick_gelu 7 | 8 | 9 | class CLIPAttention(keras.layers.Layer): 10 | def __init__(self): 11 | super().__init__() 12 | self.embed_dim = 768 13 | self.num_heads = 12 14 | self.head_dim = self.embed_dim // self.num_heads 15 | self.scale = self.head_dim**-0.5 16 | self.q_proj = keras.layers.Dense(self.embed_dim) 17 | self.k_proj = keras.layers.Dense(self.embed_dim) 18 | self.v_proj = keras.layers.Dense(self.embed_dim) 19 | self.out_proj = keras.layers.Dense(self.embed_dim) 20 | 21 | def _shape(self, tensor, seq_len: int, bsz: int): 22 | a = tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)) 23 | return keras.layers.Permute((2, 1, 3))(a) # bs , n_head , seq_len , head_dim 24 | 25 | def call(self, inputs): 26 | hidden_states, causal_attention_mask = inputs 27 | bsz, tgt_len, embed_dim = hidden_states.shape 28 | query_states = self.q_proj(hidden_states) * self.scale 29 | key_states = self._shape(self.k_proj(hidden_states), tgt_len, -1) 30 | value_states = self._shape(self.v_proj(hidden_states), tgt_len, -1) 31 | 32 | proj_shape = (-1, tgt_len, self.head_dim) 33 | query_states = self._shape(query_states, tgt_len, -1) 34 | query_states = tf.reshape(query_states, proj_shape) 35 | key_states = tf.reshape(key_states, proj_shape) 36 | 37 | src_len = tgt_len 38 | value_states = tf.reshape(value_states, proj_shape) 39 | attn_weights = query_states @ keras.layers.Permute((2, 1))(key_states) 40 | 41 | attn_weights = tf.reshape(attn_weights, (-1, self.num_heads, tgt_len, src_len)) 42 | attn_weights = attn_weights + causal_attention_mask 43 | attn_weights = tf.reshape(attn_weights, (-1, tgt_len, src_len)) 44 | 45 | attn_weights = tf.nn.softmax(attn_weights) 46 | attn_output = attn_weights @ value_states 47 | 48 | attn_output = tf.reshape( 49 | attn_output, (-1, self.num_heads, tgt_len, self.head_dim) 50 | ) 51 | attn_output = keras.layers.Permute((2, 1, 3))(attn_output) 52 | attn_output = tf.reshape(attn_output, (-1, tgt_len, embed_dim)) 53 | 54 | return self.out_proj(attn_output) 55 | 56 | 57 | class CLIPEncoderLayer(keras.layers.Layer): 58 | def __init__(self): 59 | super().__init__() 60 | self.layer_norm1 = keras.layers.LayerNormalization(epsilon=1e-5) 61 | self.self_attn = CLIPAttention() 62 | self.layer_norm2 = keras.layers.LayerNormalization(epsilon=1e-5) 63 | self.fc1 = keras.layers.Dense(3072) 64 | self.fc2 = keras.layers.Dense(768) 65 | 66 | def call(self, inputs): 67 | hidden_states, causal_attention_mask = inputs 68 | residual = hidden_states 69 | 70 | hidden_states = self.layer_norm1(hidden_states) 71 | hidden_states = self.self_attn([hidden_states, causal_attention_mask]) 72 | hidden_states = residual + hidden_states 73 | 74 | residual = hidden_states 75 | hidden_states = self.layer_norm2(hidden_states) 76 | 77 | hidden_states = self.fc1(hidden_states) 78 | hidden_states = quick_gelu(hidden_states) 79 | hidden_states = self.fc2(hidden_states) 80 | 81 | return residual + hidden_states 82 | 83 | 84 | class CLIPEncoder(keras.layers.Layer): 85 | def __init__(self): 86 | super().__init__() 87 | self.layers = [CLIPEncoderLayer() for i in range(12)] 88 | 89 | def 
call(self, inputs): 90 | [hidden_states, causal_attention_mask] = inputs 91 | for l in self.layers: 92 | hidden_states = l([hidden_states, causal_attention_mask]) 93 | return hidden_states 94 | 95 | 96 | class CLIPTextEmbeddings(keras.layers.Layer): 97 | def __init__(self, n_words=77): 98 | super().__init__() 99 | self.token_embedding_layer = keras.layers.Embedding( 100 | 49408, 768, name="token_embedding" 101 | ) 102 | self.position_embedding_layer = keras.layers.Embedding( 103 | n_words, 768, name="position_embedding" 104 | ) 105 | 106 | def call(self, inputs): 107 | input_ids, position_ids = inputs 108 | word_embeddings = self.token_embedding_layer(input_ids) 109 | position_embeddings = self.position_embedding_layer(position_ids) 110 | return word_embeddings + position_embeddings 111 | 112 | 113 | class CLIPTextTransformer(keras.models.Model): 114 | def __init__(self, n_words=77): 115 | super().__init__() 116 | self.embeddings = CLIPTextEmbeddings(n_words=n_words) 117 | self.encoder = CLIPEncoder() 118 | self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5) 119 | self.causal_attention_mask = tf.constant( 120 | np.triu(np.ones((1, 1, 77, 77), dtype="float32") * -np.inf, k=1) 121 | ) 122 | 123 | def call(self, inputs): 124 | input_ids, position_ids = inputs 125 | x = self.embeddings([input_ids, position_ids]) 126 | x = self.encoder([x, self.causal_attention_mask]) 127 | return self.final_layer_norm(x) 128 | -------------------------------------------------------------------------------- /深度学习实战9-文本生成图像-本地电脑实现text2img/stable_diffusion_tf/clip_tokenizer/__init__.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | import tensorflow as tf 10 | from tensorflow import keras 11 | 12 | 13 | @lru_cache() 14 | def default_bpe(): 15 | p = os.path.join( 16 | os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz" 17 | ) 18 | if os.path.exists(p): 19 | return p 20 | else: 21 | return keras.utils.get_file( 22 | "bpe_simple_vocab_16e6.txt.gz", 23 | "https://github.com/openai/CLIP/blob/main/clip/bpe_simple_vocab_16e6.txt.gz?raw=true", 24 | ) 25 | 26 | 27 | @lru_cache() 28 | def bytes_to_unicode(): 29 | """ 30 | Returns list of utf-8 byte and a corresponding list of unicode strings. 31 | The reversible bpe codes work on unicode strings. 32 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 33 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 34 | This is a signficant percentage of your normal, say, 32K bpe vocab. 35 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 36 | And avoids mapping to whitespace/control characters the bpe code barfs on. 37 | """ 38 | bs = ( 39 | list(range(ord("!"), ord("~") + 1)) 40 | + list(range(ord("¡"), ord("¬") + 1)) 41 | + list(range(ord("®"), ord("ÿ") + 1)) 42 | ) 43 | cs = bs[:] 44 | n = 0 45 | for b in range(2**8): 46 | if b not in bs: 47 | bs.append(b) 48 | cs.append(2**8 + n) 49 | n += 1 50 | cs = [chr(n) for n in cs] 51 | return dict(zip(bs, cs)) 52 | 53 | 54 | def get_pairs(word): 55 | """Return set of symbol pairs in a word. 56 | Word is represented as tuple of symbols (symbols being variable-length strings). 
57 | """ 58 | pairs = set() 59 | prev_char = word[0] 60 | for char in word[1:]: 61 | pairs.add((prev_char, char)) 62 | prev_char = char 63 | return pairs 64 | 65 | 66 | def basic_clean(text): 67 | text = ftfy.fix_text(text) 68 | text = html.unescape(html.unescape(text)) 69 | return text.strip() 70 | 71 | 72 | def whitespace_clean(text): 73 | text = re.sub(r"\s+", " ", text) 74 | text = text.strip() 75 | return text 76 | 77 | 78 | class SimpleTokenizer(object): 79 | def __init__(self, bpe_path: str = default_bpe()): 80 | self.byte_encoder = bytes_to_unicode() 81 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 82 | merges = gzip.open(bpe_path).read().decode("utf-8").split("\n") 83 | merges = merges[1 : 49152 - 256 - 2 + 1] 84 | merges = [tuple(merge.split()) for merge in merges] 85 | vocab = list(bytes_to_unicode().values()) 86 | vocab = vocab + [v + "" for v in vocab] 87 | for merge in merges: 88 | vocab.append("".join(merge)) 89 | vocab.extend(["<|startoftext|>", "<|endoftext|>"]) 90 | self.encoder = dict(zip(vocab, range(len(vocab)))) 91 | self.decoder = {v: k for k, v in self.encoder.items()} 92 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 93 | self.cache = { 94 | "<|startoftext|>": "<|startoftext|>", 95 | "<|endoftext|>": "<|endoftext|>", 96 | } 97 | self.pat = re.compile( 98 | r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", 99 | re.IGNORECASE, 100 | ) 101 | 102 | def bpe(self, token): 103 | if token in self.cache: 104 | return self.cache[token] 105 | word = tuple(token[:-1]) + (token[-1] + "",) 106 | pairs = get_pairs(word) 107 | 108 | if not pairs: 109 | return token + "" 110 | 111 | while True: 112 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) 113 | if bigram not in self.bpe_ranks: 114 | break 115 | first, second = bigram 116 | new_word = [] 117 | i = 0 118 | while i < len(word): 119 | try: 120 | j = word.index(first, i) 121 | new_word.extend(word[i:j]) 122 | i = j 123 | except: 124 | new_word.extend(word[i:]) 125 | break 126 | 127 | if word[i] == first and i < len(word) - 1 and word[i + 1] == second: 128 | new_word.append(first + second) 129 | i += 2 130 | else: 131 | new_word.append(word[i]) 132 | i += 1 133 | new_word = tuple(new_word) 134 | word = new_word 135 | if len(word) == 1: 136 | break 137 | else: 138 | pairs = get_pairs(word) 139 | word = " ".join(word) 140 | self.cache[token] = word 141 | return word 142 | 143 | def encode(self, text): 144 | bpe_tokens = [] 145 | text = whitespace_clean(basic_clean(text)).lower() 146 | for token in re.findall(self.pat, text): 147 | token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) 148 | bpe_tokens.extend( 149 | self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ") 150 | ) 151 | return [49406] + bpe_tokens + [49407] 152 | 153 | def decode(self, tokens): 154 | text = "".join([self.decoder[token] for token in tokens]) 155 | text = ( 156 | bytearray([self.byte_decoder[c] for c in text]) 157 | .decode("utf-8", errors="replace") 158 | .replace("", " ") 159 | ) 160 | return text 161 | -------------------------------------------------------------------------------- /深度学习实战9-文本生成图像-本地电脑实现text2img/stable_diffusion_tf/clip_tokenizer/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/shenqiang0601/deep_learning/6381803a6189390d4388804c2ade0bc0c9ca8350/深度学习实战9-文本生成图像-本地电脑实现text2img/stable_diffusion_tf/clip_tokenizer/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /深度学习实战9-文本生成图像-本地电脑实现text2img/stable_diffusion_tf/clip_tokenizer/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenqiang0601/deep_learning/6381803a6189390d4388804c2ade0bc0c9ca8350/深度学习实战9-文本生成图像-本地电脑实现text2img/stable_diffusion_tf/clip_tokenizer/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /深度学习实战9-文本生成图像-本地电脑实现text2img/stable_diffusion_tf/layers.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow import keras 3 | 4 | 5 | class PaddedConv2D(keras.layers.Layer): 6 | def __init__(self, channels, kernel_size, padding=0, stride=1): 7 | super().__init__() 8 | self.padding2d = keras.layers.ZeroPadding2D((padding, padding)) 9 | self.conv2d = keras.layers.Conv2D( 10 | channels, kernel_size, strides=(stride, stride) 11 | ) 12 | 13 | def call(self, x): 14 | x = self.padding2d(x) 15 | return self.conv2d(x) 16 | 17 | 18 | class GEGLU(keras.layers.Layer): 19 | def __init__(self, dim_out): 20 | super().__init__() 21 | self.proj = keras.layers.Dense(dim_out * 2) 22 | self.dim_out = dim_out 23 | 24 | def call(self, x): 25 | xp = self.proj(x) 26 | x, gate = xp[..., : self.dim_out], xp[..., self.dim_out :] 27 | return x * gelu(gate) 28 | 29 | 30 | def gelu(x): 31 | tanh_res = keras.activations.tanh(x * 0.7978845608 * (1 + 0.044715 * (x**2))) 32 | return 0.5 * x * (1 + tanh_res) 33 | 34 | 35 | def quick_gelu(x): 36 | return x * tf.sigmoid(x * 1.702) 37 | 38 | 39 | def apply_seq(x, layers): 40 | for l in layers: 41 | x = l(x) 42 | return x 43 | 44 | 45 | def td_dot(a, b): 46 | aa = tf.reshape(a, (-1, a.shape[2], a.shape[3])) 47 | bb = tf.reshape(b, (-1, b.shape[2], b.shape[3])) 48 | cc = keras.backend.batch_dot(aa, bb) 49 | return tf.reshape(cc, (-1, a.shape[1], cc.shape[1], cc.shape[2])) 50 | -------------------------------------------------------------------------------- /深度学习实战9-文本生成图像-本地电脑实现text2img/text2img.py: -------------------------------------------------------------------------------- 1 | from tensorflow import keras 2 | from stable_diffusion_tf.stable_diffusion import StableDiffusion 3 | import argparse 4 | from PIL import Image 5 | import os 6 | os.environ["CUDA_VISIBLE_DEVICES"] = '2,3' 7 | parser = argparse.ArgumentParser() 8 | 9 | parser.add_argument( 10 | "--prompt", 11 | type=str, 12 | nargs="?", 13 | default="Romantic lavender and sunset", 14 | help="the prompt to render", 15 | ) 16 | 17 | parser.add_argument( 18 | "--output", 19 | type=str, 20 | nargs="?", 21 | default="output.png", 22 | help="where to save the output image", 23 | ) 24 | 25 | parser.add_argument( 26 | "--H", 27 | type=int, 28 | default=256, 29 | help="image height, in pixels", 30 | ) 31 | 32 | parser.add_argument( 33 | "--W", 34 | type=int, 35 | default=512, 36 | help="image width, in pixels", 37 | ) 38 | 39 | parser.add_argument( 40 | "--scale", 41 | type=float, 42 | default=7.5, 43 | help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))", 44 | ) 45 | 46 | parser.add_argument( 47 | "--steps", type=int, default=50, help="number of ddim sampling steps" 48 | ) 49 
| 50 | parser.add_argument( 51 | "--seed", 52 | type=int, 53 | help="optionally specify a seed integer for reproducible results", 54 | ) 55 | 56 | parser.add_argument( 57 | "--mp", 58 | default=False, 59 | action="store_true", 60 | help="Enable mixed precision (fp16 computation)", 61 | ) 62 | 63 | args = parser.parse_args() 64 | 65 | if args.mp: 66 | print("Using mixed precision.") 67 | keras.mixed_precision.set_global_policy("mixed_float16") 68 | 69 | generator = StableDiffusion(img_height=args.H, img_width=args.W, jit_compile=False) 70 | img = generator.generate( 71 | args.prompt, 72 | num_steps=args.steps, 73 | unconditional_guidance_scale=args.scale, 74 | temperature=1, 75 | batch_size=1, 76 | seed=args.seed, 77 | ) 78 | Image.fromarray(img[0]).save(args.output) 79 | print(f"saved at {args.output}") 80 | -------------------------------------------------------------------------------- /深度学习实战9-文本生成图像-本地电脑实现text2img/text2img2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cv2 3 | from modelscope.pipelines import pipeline 4 | from modelscope.utils.constant import Tasks 5 | 6 | task = Tasks.text_to_image_synthesis 7 | model_id = 'modelscope/small-stable-diffusion-v0' 8 | # 基础调用 9 | pipe = pipeline(task=task, model=model_id, model_revision='v1.0.2') 10 | output = pipe({'text': 'an apple'}) --------------------------------------------------------------------------------
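
Both text-to-image entry points above are self-contained scripts. `text2img.py` is run from the command line, and its flags map one-to-one onto the argparse options defined in the file, e.g. `python text2img.py --prompt "an apple tree" --W 512 --H 256 --steps 50 --seed 42 --output apple.png`. `text2img2.py`, as dumped here, ends right after the ModelScope pipeline call without writing the generated image to disk; the sketch below is one way to do that. The output key is an assumption: `img2cartoon.py` in this repository reads a single image through `OutputKeys.OUTPUT_IMG`, and ModelScope text-to-image pipelines commonly return a list of images under `OutputKeys.OUTPUT_IMGS`, so check the actual keys of `output` before relying on this.

```python
# Hedged sketch: save the image produced by the pipeline call in text2img2.py.
# OutputKeys.OUTPUT_IMGS is an assumed key (some models expose OUTPUT_IMG instead);
# inspect output.keys() to confirm before use.
import cv2
from modelscope.outputs import OutputKeys

imgs = output[OutputKeys.OUTPUT_IMGS]  # list of generated images as numpy arrays
cv2.imwrite('an_apple.png', imgs[0])   # cv2.imwrite expects BGR channel order
```
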