├── README.md ├── hands-dirty-nlp ├── Chapter 2. 序列标注 │ ├── output.png │ └── 序列标注.ipynb ├── Chapter 5. 生成任务 │ └── 生成任务.ipynb └── Chapter 6. 阅读理解 │ └── 阅读理解.ipynb └── 文本表示 └── 初识预训练模型elmo ├── 代码 ├── Andersen Fairy Tales.txt ├── ELMo.ipynb └── README.md └── 课件 ├── README.md └── 初识预训练模型:elmo.pdf /README.md: -------------------------------------------------------------------------------- 1 | # hands-dirty-nlp 2 | - 课程前置要求:默认已完成<必修课程> 3 | - 定位及目的:本课程面对具有一定机器学习基础,但尚未入门的NLPer或经验尚浅的NLPer,尽力避免陷入繁琐枯燥的公式讲解中,力求用代码展示每个模型背后的设计思想,同时也会带大家梳理每个模块下的技术演变,做到既知树木也知森林。 4 | 5 | ## 课程目录 6 | ### 文本表示 7 | - 任务定义:文本表示是把现实中的文本数据转化为计算机能够运算的数值向量,这样就可以针对文本信息进行计算,进而来完成各种NLP任务。本章将介绍文本表示的演进过程,基于此大家也可对NLP的发展历程有一个基本了解。 8 | - 离散表示:one-hot、bag-of-word、tf-idf 9 | - 稠密表示-词向量:word2vec、glove 10 | - 预训练系列: 11 | - 初识预训练模型:elmo 12 | - 自编码:bert、ernie等 13 | - 自回归:gpt等 14 | - Prompt 15 | - 发展脉络 16 | 17 | ### 序列标注 18 | - 任务定义:序列标注的涵盖范围很广泛,可用于解决一系列对token进行分类的问题,如分词、命名实体识别、词性标注等。本章以命名实体识别为切入点,介绍序列标注任务的前世今生。 19 | - NER 20 | - 任务简介:标注方式、评价指标 21 | - 传统方法:词典匹配、CRF、HMM 22 | - 深度学习方法:LSTM + CRF、BERT+CRF 23 | - 融入词汇的方法:lattice、FLAT、LEBERT(位置编码在NER的重要性) 24 | - 解码方式:softmax、CRF、span-pointer、片段排列 25 | - flat-NER 26 | - nested-NER 27 | - 发展脉络 28 | 29 | ### 分类任务 30 | - 任务定义:分类任务是NLP中应用最多的任务,在工业界很多实际业务问题都可以抽象成分类任务。本章以意图识别为切入点,介绍分类任务的经典方法。 31 | - 意图识别: 32 | - 任务简介:评价指标 33 | - 经典深度学习方法:fasttext、text-cnn、lstm 34 | - 基于预训练模型的方法:bert 35 | - 解码: 36 | - 单标签 37 | - 多标签 38 | - 发展脉络 39 | 40 | ### 文本匹配 41 | - 任务定义:文本匹配是为了判断两个文本之间的相关关系,不同的场景下对”相关的“的定义是不同的,例如:qq匹配(query-query):对话任务中,判断两个问句是否相似;qa匹配(query-answer):对话任务中,判断用户问句与回复是否对应;qt匹配(query-title):搜索任务中,判断用户的输入和文章标题是否相关。本小节以qq切入点,介绍匹配任务方法的演进。 42 | - qq匹配 43 | - 任务简介 44 | - 双塔:优缺点 45 | - dssm 46 | - bert双塔 47 | - 交互:优缺点 48 | - ESIM 49 | - simbert 50 | - 发展脉络 51 | 52 | ### 生成任务 53 | - 任务定义: 文本生成是NLP中较为复杂的任务,本章只聚焦于文本到文本的生成,例如生成式摘要、机器翻译、对话生成等等。本文以其中的对话生成和机器翻译为切入点,介绍其中的经典方法。 54 | - 对话生成 55 | - 任务简介 56 | - 方法:bert-base-s2s、gpt 57 | - 机器翻译 58 | - 任务简介 59 | - 方法:rnn-base-s2s、bert-base-s2s 60 | - 
发展脉络
61 | 
62 | ### 阅读理解
63 | - 任务定义:相对于其他NLP任务,阅读理解侧重于深入理解文档中的语义信息,一般来说包含完形填空、多项选择、答案抽取、自由问答这4个任务。本章以其中的多项选择与答案抽取作为切入点,介绍其中的经典方法。
64 | - 多项选择
65 |   - 任务简介
66 |   - 方法
67 | - 答案抽取
68 |   - 任务简介
69 |   - 方法
70 | - 发展脉络
71 | 
72 | ## 贡献者名单
73 | | 姓名 | 个人简介 | 个人主页 |
74 | | :----| :---- | :---- |
75 | | 慎独 | NLP算法工程师,项目负责人 | - |
76 | | 西瓜骑士 | NLP算法工程师,项目负责人 | - |
77 | | 芙蕖 | NLP算法工程师 | - |
78 | 
79 | 
80 | ## 关注我们
81 |
82 |

扫描下方二维码关注公众号:Datawhale

83 | 84 |
85 | 86 | ## LICENSE 87 | 知识共享许可协议
本作品采用知识共享署名-非商业性使用-相同方式共享 4.0 国际许可协议进行许可。
88 | 
--------------------------------------------------------------------------------
/hands-dirty-nlp/Chapter 2. 序列标注/output.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datawhalechina/hands-dirty-nlp/1218362ac4956169b358a9b462fc890a4e130df8/hands-dirty-nlp/Chapter 2. 序列标注/output.png
--------------------------------------------------------------------------------
/hands-dirty-nlp/Chapter 2. 序列标注/序列标注.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "996a90f8",
6 | "metadata": {},
7 | "source": [
8 | "# **二、序列标注**\n",
9 | "\n",
10 | "## 2.1 任务定义\n",
11 | "\n",
12 | "序列标注(sequence labeling),就是对一个序列中的每一个token进行分类。它是许多自然语言处理问题的前驱,如情感分析、信息检索、推荐和过滤等等。同时在自然语言处理中,许多的任务可以转化为“将输入的语言序列转化为标注序列”来解决问题。\n",
13 | "\n",
14 | "输入输出:序列标注问题的输入是一个观测序列,输出是一个标记序列或状态序列。问题的目标在于学习一个模型,使它能够对观测序列给出标记序列作为预测。\n",
15 | "\n",
16 | "标注方式:序列标注的方法中有多种标注方式:BIO、BIOES、IOB、BILOU、BMEWO,其中前三种最为常见。各种标注方法大同小异。下面列举一些常见的标签方案:\n",
17 | "\n",
18 | " 标签方案中通常都使用一些简短的英文字符[串]来编码。\n",
19 | "\n",
20 | " 标签是打在token上的。\n",
21 | "\n",
22 | " 对于英文,token可以是一个单词(e.g. awesome),也可以是一个字符(e.g. a)。\n",
23 | "\n",
24 | " 对于中文,token可以是一个词语(分词后的结果),也可以是单个汉字字符。\n",
25 | "\n",
26 | " 为便于说明,以下都将token视作等同于字符。\n",
27 | "\n",
28 | " 标签列表如下:\n",
29 | "\n",
30 | "- B,即Begin,表示开始\n",
31 | "- I,即Intermediate,表示中间\n",
32 | "- E,即End,表示结尾\n",
33 | "- S,即Single,表示单个字符\n",
34 | "- O,即Other,表示其他,用于标记无关字符\n",
35 | "\n",
36 | "1. BIO\n",
37 | "\n",
38 | "- B stands for 'beginning' (signifies beginning of a Named Entity, i.e. NE)\n",
39 | "- I stands for 'inside' (signifies that the word is inside an NE)\n",
40 | "- O stands for 'outside' (signifies that the word is just a regular word outside of an NE) \n",
41 | "\n",
42 | "2. 
BIOES\n",
43 | "\n",
44 | "- B stands for 'beginning' (signifies beginning of an NE)\n",
45 | "- I stands for 'inside' (signifies that the word is inside an NE)\n",
46 | "- O stands for 'outside' (signifies that the word is just a regular word outside of an NE)\n",
47 | "- E stands for 'end' (signifies that the word is the end of an NE)\n",
48 | "- S stands for 'singleton' (signifies that the single word is an NE)\n",
49 | "\n",
50 | "3. IOB (即IOB-1)\n",
51 | "\n",
52 | " IOB与BIO字母对应的含义相同,其不同点是IOB中,标签B仅用于两个连续的同类型命名实体的边界区分,不用于命名实体的起始位置,这里举个例子:\n",
53 | "\n",
54 | " 词序列:(word)(word)(word)(word)(word)(word)\n",
55 | "\n",
56 | " IOB标注:(I-loc)(I-loc)(B-loc)(I-loc)(O)(O)\n",
57 | "\n",
58 | " BIO标注:(B-loc)(I-loc)(B-loc)(I-loc)(O)(O)\n",
59 | "\n",
60 | " The IOB scheme is similar to the BIO scheme, however, here the tag B- is only used to start a segment when the previous token belongs to the same class but is not part of the same segment.\n",
61 | "\n",
62 | " 因为IOB的整体效果不好,所以出现了IOB-2,约定了所有命名实体均以B tag开头。这样IOB-2就与BIO的标注方式等价了。\n",
63 | "\n",
64 | "评价指标:常见的序列标注算法的模型效果评估指标有准确率(accuracy)、查准率(precision)、召回率(recall)、F1值等,计算的公式如下:\n",
65 | "\n",
66 | "- 准确率:accuracy = 预测对的元素个数 / 总的元素个数\n",
67 | "- 查准率:precision = 预测正确的实体个数 / 预测的实体总个数\n",
68 | "- 召回率:recall = 预测正确的实体个数 / 标注的实体总个数\n",
69 | "- F1值:F1 = 2 * 查准率 * 召回率 / (查准率 + 召回率)\n",
70 | "\n",
71 | "## 2.2 NER\n",
72 | "\n",
73 | "命名实体识别(Named Entity Recognition,简称NER),是指识别文本中具有特定意义的实体,主要包括人名、地名、机构名、专有名词等,属于序列标注任务的范畴。\n",
74 | "\n",
75 | "### 2.2.1 传统方法\n",
76 | "\n",
77 | "基于规则的方法,其思想在于:在观察特定领域文本以及实体出现的语法构成和模式之后,设计特定的实体提取规则以完成提取。\n",
78 | "\n",
79 | "实体词表、关系词或属性词触发词表、正则表达式是基于词典规则方法的三大核心部件,主要有2种方式:\n",
80 | "\n",
81 | "1. 基于实体词表的匹配识别\n",
82 | "\n",
83 | "基于实体词表的匹配识别是使用最广泛的一种实体识别方法,虽然实体词表只能实现对目标文本的有限匹配,但见效十分快速。\n",
84 | "\n",
85 | "一般来说,在进行领域实体识别时,每个特定领域都有专属的实体词典,如医药行业的药名、科室名、手术名,汽车行业的车型、车系、品牌名称,金融行业中的公司词典、行业词典,招聘领域的职位词典等,这些词典都可以用来进行实体识别。\n",
86 | "\n",
87 | "对于有歧义的词汇,可以先进行分词,比如采用最大匹配法,在分词的基础上再进行NER任务。\n",
88 | "\n",
89 | "2. 
基于规则模板的匹配识别\n",
90 | "\n",
91 | "规则模板可以实现对实体词表识别的扩展,其中的核心在于规则模板的设计,在此之前需要分析实体词或者属性值的构词规则,包括基于字符构词规则的识别以及基于词性组合规则的识别两种。其中,基于字符构词规则的识别采用正则表达式进行提取。例如:\n",
92 | "\n",
93 | "Email的表现形式通常为“xxxx@xxx.com”,可以利用“^\\w+([-+.]\\w+)*@\\w+([-.]\\w+)*\\.\\w+([-.]\\w+)*$”来匹配Email地址;\n",
94 | "\n",
95 | "借助“\\d{4}[年-]\\d{1,2}[月-]\\d{1,2}日?”的正则模板表达式来提取日期。\n",
96 | "\n",
97 | "### 2.2.2 CRF\n",
98 | "\n",
99 | "CRF,英文全称为conditional random field,中文名为条件随机场,是给定一组输入随机变量条件下另一组输出随机变量的条件概率分布模型,其特点是假设输出随机变量构成马尔可夫(Markov)随机场。较为简单的条件随机场是定义在线性链上的条件随机场,称为线性链条件随机场(linear chain conditional random field)。线性链条件随机场可以用于序列标注等问题,这时,在条件概率模型P(Y|X)中,Y是输出变量,表示标记序列(或状态序列),X是输入变量,表示需要标注的观测序列。学习时利用训练数据集通过极大似然估计或正则化的极大似然估计得到条件概率模型P(Y|X);预测时,对于给定的输入序列x,求出条件概率P(y|x)最大的输出序列y*。\n",
100 | "\n",
101 | "### 2.2.3 HMM\n",
102 | "\n",
103 | "隐马尔可夫模型(Hidden Markov Model,HMM),是一个统计模型。隐马尔可夫模型有三种应用场景,我们做命名实体识别只用到其中的一种——**求观察序列的背后最可能的标注序列**。\n",
104 | "\n",
105 | "HMM中,有5个基本元素:{N,M,A,B,π},我结合序列标注任务对这5个基本元素做一个介绍:\n",
106 | "\n",
107 | "- N:状态的有限集合。在这里,是指每一个词语背后的标注。\n",
108 | "- M:观察值的有限集合。在这里,是指每一个词语本身。\n",
109 | "- A:状态转移概率矩阵。在这里,是指某一个标注转移到下一个标注的概率。\n",
110 | "- B:观测概率矩阵,也就是发射概率矩阵。在这里,是指在某个标注下,生成某个词的概率。\n",
111 | "- π:初始状态概率向量。在这里,是指每一个标注的初始化概率。\n",
112 | "\n",
113 | "而以上的这些元素,都是可以从训练语料集中统计出来的。最后,我们根据这些统计值,应用维特比(viterbi)算法,就可以算出词语序列背后的标注序列了。\n",
114 | "\n",
115 | "### 2.2.4 深度学习方法\n",
116 | "\n",
117 | "#### LSTM+CRF\n",
118 | "\n",
119 | "对于序列标注问题,一个基于深度学习的方法便是BI-LSTM:简单的做法是将输入序列经过一个embedding层转化为向量序列,输入前向、后向两个LSTM单元,将每个时刻的正向、反向输出拼接,经过一个全连接层映射为一个维度等于输出标签数量的向量,使用Softmax将输出归一化作为每种标签的概率。\n",
120 | "\n",
121 | "![](https://github.com/datawhalechina/hands-dirty-nlp/blob/main/hands-dirty-nlp/Chapter%202.%20%E5%BA%8F%E5%88%97%E6%A0%87%E6%B3%A8/output.png)\n",
122 | "\n",
123 | "首先我们要知道为什么使用LSTM+CRF。序列标注问题本质上是分类问题,又具有序列特征,所以LSTM很适合做序列标注,确实,我们可以直接利用LSTM进行序列标注。但是这样的做法有一个问题:每个时刻的输出没有考虑上一时刻的输出。我们在利用LSTM进行序列建模的时候只考虑了输入序列的信息,即单词信息,但是没有考虑标签信息,即输出标签信息。这样会导致一个问题,以“我 喜欢 
跑步”为例,LSTM输出“喜欢”的标签是“动词”,而“跑步”的标签可能也是“动词”。但是实际上,“名词”标签更为合适,因为“跑步”这里是一项运动。也就是“动词”+“名词”这个规则并没有被LSTM模型捕捉到。也就是说这样使用LSTM无法对标签转移关系进行建模。而标签转移关系对序列标注任务来说是很重要的,所以就在LSTM的基础上引入一个标签转移矩阵对标签转移关系进行建模。这就和CRF很像了。我们知道,CRF有两类特征函数,一类是针对观测序列与状态的对应关系(如“我”一般是“名词”),一类是针对状态间关系(如“动词”后一般跟“名词”)。在LSTM+CRF模型中,前一类特征函数的输出由LSTM的输出替代,后一类特征函数就变成了标签转移矩阵。\n", 124 | "\n", 125 | "#### Bert+CRF\n", 126 | "\n", 127 | "BERT-CRF与BiLSTM-CRF模型较为相似,其本质上还是一个CRF模型。BERT模型+FC layer(全连接层)已经可以解决序列标注问题,以词性标注为例,BERT的encoding vector通过FC layer映射到标签集合后,单个token的output vector再经过Softmax处理,每一维度的数值就表示该token的词性为某一词性的概率。基于此数据便可计算loss并训练模型。但根据Bi-LSTM+CRF 模型的启发,我们在BERT+FC layer 的基础上增加CRF layer。 CRF是一种经典的概率图模型,具体数学原理不在此处展开。要声明的是,CRF层可以加入一些约束来保证最终的预测结果是有效的。这些约束可以在训练数据时被CRF层自动学习得到。具体的约束条件我们会在后面提及。有了这些有用的约束,错误的预测序列会大大减小。\n", 128 | "\n", 129 | "参考文献:\n", 130 | "\n", 131 | "https://zhuanlan.zhihu.com/p/268579769\n", 132 | "\n", 133 | "https://zhuanlan.zhihu.com/p/147537898\n", 134 | "\n", 135 | "http://nathanlvzs.github.io/Several-Tagging-Schemes-for-Sequential-Tagging.html" 136 | ] 137 | } 138 | ], 139 | "metadata": { 140 | "kernelspec": { 141 | "display_name": "Python 3 (ipykernel)", 142 | "language": "python", 143 | "name": "python3" 144 | }, 145 | "language_info": { 146 | "codemirror_mode": { 147 | "name": "ipython", 148 | "version": 3 149 | }, 150 | "file_extension": ".py", 151 | "mimetype": "text/x-python", 152 | "name": "python", 153 | "nbconvert_exporter": "python", 154 | "pygments_lexer": "ipython3", 155 | "version": "3.10.9" 156 | } 157 | }, 158 | "nbformat": 4, 159 | "nbformat_minor": 5 160 | } 161 | -------------------------------------------------------------------------------- /hands-dirty-nlp/Chapter 5. 
生成任务/生成任务.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "1c3cf8bf", 6 | "metadata": {}, 7 | "source": [ 8 | "# **五、生成任务**\n", 9 | "\n", 10 | "任务定义: 文本生成是NLP中较为复杂的任务,本章只聚焦于文本到文本的生成,例如生成式摘要、机器翻译、对话生成等等。本文以其中的对话生成和机器翻译为切入点,介绍其中的经典方法。\n", 11 | "\n", 12 | "## 5.1 对话生成 \n", 13 | "\n", 14 | "### 5.1.1 任务简介\n", 15 | "\n", 16 | "基于生成的方式将对话生成问题看作是一种“源到目标”的映射问题,直接从大量的训练数据中学习从输入信息到最终输出之间的映射关系。 \n", 17 | "\n", 18 | "\n", 19 | "\n", 20 | "### 5.1.2 方法:bert-base-s2s、gpt\n", 21 | "\n", 22 | "代码案例:https://zhuanlan.zhihu.com/p/170358507\n", 23 | "\n", 24 | "\n", 25 | "\n", 26 | "## 5.2 机器翻译 \n", 27 | "\n", 28 | "### 5.2.1 机器翻译原理\n", 29 | "\n", 30 | "机器翻译就是将一个语言的句子翻译成另一个语言的句子,主要可以分为三个步骤:**「预处理、翻译模型、后处理」**。\n", 31 | "\n", 32 | "预处理是对源语言的句子进行规范化处理,把过长的句子通过标点符号分成几个短句子,过滤一些语气词和与意思无关的文字,将一些数字和表达不规范的地方,归整成符合规范的句子,等等。\n", 33 | "\n", 34 | "翻译模块是将输入的字符单元、序列翻译成目标语言序列的过程,这是机器翻译中最关键最核心的地方。纵观机器翻译发展的历史,翻译模块可以分为基于规则的翻译、基于统计的翻译和基于神经网络的翻译三大类。现如今基于神经网络的机器翻译已经成为了主流方法,效果也远远超过了前两类方法。\n", 35 | "\n", 36 | "后处理模块是将翻译结果进行大小写的转化、建模单元进行拼接,特殊符号进行处理,使得翻译结果更加符合人们的阅读习惯。\n", 37 | "\n", 38 | "\n", 39 | "\n", 40 | "### 5.2.2 任务简介\n", 41 | "\n", 42 | "机器翻译和神经网络有着千丝万缕的关系。神经网络是一种方法,而机器翻译是其中最大的目标应用场景之一,很多神经网络技术的发展,最早就是从做机器翻译任务开始的。\n", 43 | "\n", 44 | "机器翻译任务从机器学习的角度看,是一种生成任务,因为它要输出的结果不是类别编号,而是一串字符序列。在神经网络和深度学习中,做机器翻译任务的模型,也称为Seq2Seq(sequence to sequence)模型。\n", 45 | "\n", 46 | "Seq2Seq模型是一种经典的深度学习模型,通常采用Encoder-Decoder框架。这个框架最早出自2014年的一篇论文《Learning Phrase Representations using RNN Encoder-Decoder forStatistical Machine Translation》,没错,这又是一篇研究机器翻译的经典论文。在这篇论文的基础上,谷歌最终于2016年推出了基于神经网络的机器翻译并大获成功,而这篇论文所提出来的Encoder-Decoder框架,甚至超越了机器翻译领域。\n", 47 | "\n", 48 | "做NLP的同学应该都很熟悉Encoder-Decoder框架,现在做CV的同学也开始在熟悉这个框架。最近深度学习领域有一条重磅消息,在NLP领域称霸多年的Transformer模型,现在正在CV领域大杀特杀。NLP和CV一直是机器学习最热门的两个研究领域,不过长期以来一直有点生殖隔离的意思,现在让Transformer一拳打穿了次元壁,所以好几位大牛都在预测机器学习的大一统模型也许正在呼之欲出。\n", 49 | "\n", 50 | 
"不过,Transformer最早是用来做什么的?没错,还是做机器翻译。2017年5月,知名的深度学习研究团队FAIR搞出了个新模型ConvS2S,用CNN+Attention来做机器翻译。论文一出,圈内哗然,毕竟在大家的认知中,CNN模型一向是用来处理图像,也就是做CV的,做文本做机器翻译这块,当时主要还是用RNN及其派生的LSTM等模型。RNN有个很大的缺点,就是没法并行训练,非常耗时,用起来很鸡肋,现在FAIR用并行性好得多的CNN搞出了新模型,一下有种众望所归的感觉。结果另一家知名的研究团队Google Brain不干了,直接走力大飞砖的路线出了一篇爆款论文,叫《Attention Is All You Need》,相信大家都有所耳闻。Google Brain没有明说,不过我总觉得这个霸气侧漏的标题显然多少有点暗指FAIR画蛇添足的意思:机器翻译还要什么CNN?直接Attention就完事了。在这篇论文里诞生了一款模型,现在我们都耳熟能详了,这就是Transformer。\n",
51 | "\n",
52 | "\n",
53 | "\n",
54 | "### 5.2.3 方法:rnn-base-s2s、bert-base-s2s\n",
55 | "\n",
56 | "\n",
57 | "\n",
58 | "### 5.2.4 发展脉络\n"
59 | ]
60 | }
61 | ],
62 | "metadata": {
63 | "kernelspec": {
64 | "display_name": "Python 3 (ipykernel)",
65 | "language": "python",
66 | "name": "python3"
67 | },
68 | "language_info": {
69 | "codemirror_mode": {
70 | "name": "ipython",
71 | "version": 3
72 | },
73 | "file_extension": ".py",
74 | "mimetype": "text/x-python",
75 | "name": "python",
76 | "nbconvert_exporter": "python",
77 | "pygments_lexer": "ipython3",
78 | "version": "3.10.9"
79 | }
80 | },
81 | "nbformat": 4,
82 | "nbformat_minor": 5
83 | }
84 | 
--------------------------------------------------------------------------------
/hands-dirty-nlp/Chapter 6. 
阅读理解/阅读理解.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "5002adc3",
6 | "metadata": {},
7 | "source": [
8 | "# **六、阅读理解**\n",
9 | "\n",
10 | "机器阅读理解任务的定义:\n",
11 | "\n",
12 | "\n",
13 | "\n",
14 | "## 6.1 任务的来源\n",
15 | "\n",
16 | "增强机器对于数据的理解能力是人工智能发展过程中至关重要的一个环节。1950年艾伦·图灵提出的图灵测试,则是将对话作为了评判人工智能能力的一个基准。在现阶段,为了实现表现优秀的对话系统(或更简单的问答系统),机器的阅读理解能力是至关重要的。\n",
17 | "\n",
18 | "机器阅读理解属于自然语言理解(natural language understanding)的范畴,该任务可以定义为“机器通过交互从文本中抽取和构建文章语义的过程”。\n",
19 | "\n",
20 | "就像人类做阅读理解题目一样,只有在充分理解问题以及上下文的基础上,机器才能给出更准确的答案。对于机器来说,可以从以下几种任务形式来具体定义,并将其应用到实际场景中。\n",
21 | "\n",
22 | "\n",
23 | "\n",
24 | "## 6.2 任务形式\n",
25 | "\n",
26 | "对于机器阅读理解任务来说,最通用的任务形式可以定义为“给出一个query和一系列candidate context,机器需要在充分理解问题和上下文的基础上,给出最终的输出答案answer”,即,“input:query + context,output:answer”。更细化一点来说,则可以分为以下几种任务格式:\n",
27 | "\n",
28 | "1. 多项选择式阅读理解\n",
29 | "\n",
30 | "对于多项选择式阅读理解任务来说,模型需要从给出的多个答案中选出正确的答案。这就好比人类在做选择题时的场景。给出一个问题和多个选项,需要选出正确(唯一或不唯一)的答案。这可以衡量人类在某阶段的学习水平,同样,这也可以用来衡量机器的学习水平。\n",
31 | "\n",
32 | "2. 抽取式阅读理解\n",
33 | "\n",
34 | "抽取式阅读理解任务的前提是,假定问题所对应的正确答案出现在给出的context中,这样,对于模型来说,则需要根据问题,给出正确答案在context中的start和end的位置,最后经过处理模型会返回一个span作为最终的答案。如果context中没有包含可能是答案的部分,则可以通过预先设置的方式,让模型给出没有答案的反馈。\n",
35 | "\n",
36 | "3. 完形填空式阅读理解\n",
37 | "\n",
38 | "类比人类在做英语考试完形填空任务的场景:给出一篇文章,从中随机选择一部分token破坏掉(就是删除并用“__”代替),人类需要根据上下文从ABCD四个选项中选择出最有可能属于原始文档的一个答案。对于机器阅读理解来说,则是需要从vocabulary中自由地选择出若干个符合上下文的token,phrase或span。\n",
39 | "\n",
40 | "4. 
自由生成式阅读理解\n",
41 | "\n",
42 | "在该种任务形式下,对于模型的输出没有严格的要求,需要模型根据问题和上下文生成正确的答案,这可以通过seq2seq的范式来解决。\n",
43 | "\n",
44 | "\n",
45 | "\n",
46 | "## 6.3 机器阅读理解的发展\n",
47 | "\n",
48 | "机器阅读理解任务的发展可以追溯到二十世纪七十年代。第一个机器阅读理解系统是1977年公布的QUALM系统,该系统建立在人为手写规则的基础上。\n",
49 | "\n",
50 | "而二十多年之后,一个包含120个故事的阅读理解数据集于1999年公布,同时还有一个基于规则的模型Deep Read。虽然该模型已经具有了一定的精确度,但它仍旧是建立在人类专家预先制定的规则上的,而这需要耗费大量的精力和财力。\n",
51 | "\n",
52 | "从2015年开始,得益于CV领域深度学习的快速发展,机器阅读理解任务也开始建立在深度学习的基础上。就像其他的深度学习模型一样,模型不再需要人类专家手写制定的规则或通过统计学习方法选择出的特征,而仅仅依靠输入输出和精妙的模型结构设计就可以达到end to end的训练效果。这大大推进了机器阅读理解任务的发展。\n",
53 | "\n",
54 | "自2018年BERT模型推出后,机器阅读理解能力在此前的数据集和评测指标上已经超过人类,机器阅读理解开始了新的发展。\n",
55 | "\n",
56 | "\n",
57 | "\n",
58 | "## 6.4 机器阅读理解模型\n",
59 | "\n",
60 | "### 6.4.1 模型的输入和输出\n",
61 | "\n",
62 | "从整体来看,机器阅读理解任务可以看作:给定一段文本和需要回答的问题,模型需要根据对文本的理解,按照预先定义的格式(多项选择、完形填空、抽取式、生成式)返回最终的答案。\n",
63 | "\n",
64 | "拆开来看,则可以分为以下几个步骤:首先,模型将输入的文本和问题进行编码,并获得融合上下文的向量表示;之后,为了更好地理解文本和问题之间的关系,模型需要对得到的向量表示(文本和问题)进行融合;最后,模型需要根据预先定义的格式,通过特定的输出层,返回最终的答案。\n",
65 | "\n",
66 | "\n",
67 | "\n",
68 | "### 6.4.2 模型的各个层次组件介绍\n",
69 | "\n",
70 | "从整体上来看,机器阅读理解模型可以分为encoding layer、interaction layer以及output layer三个部分。下面分别对这三个部分进行介绍:\n",
71 | "\n",
72 | "1. encoding layer\n",
73 | "\n",
74 | "模型的encoding layer需要将输入的文本和问题编码为向量表示。一般来说,可以建立一个vocabulary,通过查表的方式得到embedding表示,之后,再融入诸如POS、character、context、position、dependency tree的对应表示进行数据增强。在BERT模型之后,一种获得上下文向量表示的方式就是直接使用BERT模型的输出。也可以使用BI-LSTM模型来获取句子在两个方向上的表示并进行拼接得到。\n",
75 | "\n",
76 | "2. interaction layer\n",
77 | "\n",
78 | "interaction layer的作用是融合文本信息和问题信息,简单来说就是,对于同一段文本,不同的问题所需要的文本信息并不相同,所以需要通过interaction的方式来对文本信息和问题信息进行增强。一般情况下,可以分别先对文本向量和问题向量进行self-attention处理,之后再互相进行cross-attention来达到信息融合的效果。\n",
79 | "\n",
80 | "3. 
output layer\n", 81 | "\n", 82 | "一般情况下,如果需要进行抽取式的阅读理解任务,则可以在最后拼接一个指针网络,给出start和end位置的概率,如果是自由生成式的网络,则可以使用seq2seq架构中的decoder模型,逐步生成答案直到生成终止符或达到长度上限。\n", 83 | "\n", 84 | "\n", 85 | "\n", 86 | "### 6.4.3 典型模型介绍\n", 87 | "\n", 88 | "这里选取fusion net模型进行介绍。\n", 89 | "\n", 90 | "先介绍两个概念:history of word和fully-aware attention\n", 91 | "\n", 92 | "对于深度学习模型来说,较浅层的网络倾向于抽取表面的信息,而较深层的网络则倾向于抽取语义信息。但是对于阅读理解模型来说,字面信息和语义信息同样重要。这里,我们将从第一层到当前层所有的输入拼接之后得到的向量称为history of word。\n", 93 | "\n", 94 | "可以发现,通过拼接向量的方法固然可以捕获所需要的各个维度的信息,但是在实际应用中,随着模型层数的加深,拼接后得到的向量的维度会大大增加,而这会带来巨大的计算负荷。\n", 95 | "\n", 96 | "接下来就引入fully-aware attention的概念,简单来说,就是将history of word作为输入,然后,使用得到的注意力权重在一个特定的层的输出上进行计算,最终得到需要的注意力向量。在实际应用中,fully-aware attention可以使用多次来捕获更多的语义信息。 \n", 97 | "\n", 98 | "\\#这里先空上\n", 99 | "\n", 100 | "\n", 101 | "\n", 102 | "## 6.5 机器阅读理解的应用\n", 103 | "\n", 104 | "大数据时代产生了巨量的文本数据,但是,通过人为分析的方式来对得到的文本信息进行处理是不现实的,这会耗费大量的人力和财力。机器阅读理解任务则可以在这方面帮助人类更高效的处理相关领域的信息。比如,在法律卷宗方面,可以使用机器阅读理解的技术来分析案件的种种因素,这可以帮助法律从业人员更方便的处理文本卷宗,将更多的精力放在对案件更本质的分析上。在说明书领域,也可以使用机器阅读理解技术来回答用户提出的问题。虽然现阶段机器阅读理解技术在实际应用中,距离期望的效果还有很大的一段差距,但是,机器阅读理解技术在当下仍可以帮助减轻一部分人员负担。比如,机器阅读理解去处理一部分相对而言比较容易回答的问题,其余的问题则留给人类去完成,这种人机协同的方式可以充分利用现有的技术来提高实际场景中的业务效率。" 105 | ] 106 | } 107 | ], 108 | "metadata": { 109 | "kernelspec": { 110 | "display_name": "Python 3 (ipykernel)", 111 | "language": "python", 112 | "name": "python3" 113 | }, 114 | "language_info": { 115 | "codemirror_mode": { 116 | "name": "ipython", 117 | "version": 3 118 | }, 119 | "file_extension": ".py", 120 | "mimetype": "text/x-python", 121 | "name": "python", 122 | "nbconvert_exporter": "python", 123 | "pygments_lexer": "ipython3", 124 | "version": "3.10.9" 125 | } 126 | }, 127 | "nbformat": 4, 128 | "nbformat_minor": 5 129 | } 130 | -------------------------------------------------------------------------------- /文本表示/初识预训练模型elmo/代码/Andersen Fairy Tales.txt: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/datawhalechina/hands-dirty-nlp/1218362ac4956169b358a9b462fc890a4e130df8/文本表示/初识预训练模型elmo/代码/Andersen Fairy Tales.txt -------------------------------------------------------------------------------- /文本表示/初识预训练模型elmo/代码/ELMo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "7af2c24d", 6 | "metadata": {}, 7 | "source": [ 8 | "# 数据预处理\n", 9 | "- 将原始txt处理成每行的json格式" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 2, 15 | "id": "0f47d4b8", 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "import json" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 3, 25 | "id": "1de41397", 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "#统一全角转半角\n", 30 | "def strQ2B(ustring):\n", 31 | " cur_list = []\n", 32 | " for s in ustring:\n", 33 | " rstring = \"\"\n", 34 | " for uchar in s:\n", 35 | " inside_code = ord(uchar)\n", 36 | " if inside_code == 12288: # 全角空格直接转换\n", 37 | " inside_code = 32\n", 38 | " elif (inside_code >= 65281 and inside_code <= 65374): # 全角字符(除空格)根据关系转化\n", 39 | " inside_code -= 65248\n", 40 | " rstring += chr(inside_code)\n", 41 | " cur_list.append(rstring)\n", 42 | " return ''.join(cur_list)\n", 43 | "\n", 44 | "\n", 45 | "#转换特殊字符词组及标点符号\n", 46 | "trans_punctuations = {'don\\'t':\"do not\",\n", 47 | " '\"':'',\n", 48 | " ';':''\n", 49 | " }\n", 50 | "def process_data(strs):\n", 51 | " for key in trans_punctuations:\n", 52 | " strs = strs.replace(key, trans_punctuations[key])\n", 53 | " return strQ2B(strs)\n", 54 | "\n", 55 | "\n", 56 | "# 读取原始数据\n", 57 | "raw_data = []\n", 58 | "with open('Andersen Fairy Tales.txt', 'r') as f:\n", 59 | " for x in f:\n", 60 | " x = x.strip().lower()\n", 61 | " if x: raw_data.append(process_data(x))\n", 62 | "\n", 63 | "#保留长度大于1的句子\n", 64 | "raw_data = [x for x in raw_data if len(x.split(' '))>1]\n", 65 | "\n", 66 | "\n", 
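上面的预处理单元里,strQ2B 依靠 Unicode 码位的固定偏移完成全角到半角的转换:全角空格(U+3000,即 12288)直接映射为半角空格(U+0020),其余全角字符(U+FF01–U+FF5E,即 65281–65374)与对应半角字符相差 65248。下面是一个独立的最小示例(函数名 to_halfwidth 为本示例自拟,逻辑与上文 strQ2B 相同):

```python
def to_halfwidth(text: str) -> str:
    """把字符串中的全角字符逐个转换为对应的半角字符。"""
    out = []
    for ch in text:
        code = ord(ch)
        if code == 0x3000:                # 全角空格 -> 半角空格
            code = 0x20
        elif 0xFF01 <= code <= 0xFF5E:    # 其余全角字符按固定偏移 65248 转换
            code -= 0xFEE0                # 0xFEE0 == 65248
        out.append(chr(code))
    return "".join(out)

# 全角的 "ABC 1!" 会被还原成半角
print(to_halfwidth("\uff21\uff22\uff23\u3000\uff11\uff01"))  # -> ABC 1!
```

半角字符不在上述码位区间内,转换对它们是幂等的,因此可以放心对整份语料重复调用。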
67 | "#保存数据\n", 68 | "with open(\"./corpus.json\",\"w\") as f:\n", 69 | " json.dump(raw_data, f, indent=4)" 70 | ] 71 | }, 72 | { 73 | "cell_type": "markdown", 74 | "id": "afe8b1ad", 75 | "metadata": {}, 76 | "source": [ 77 | "# 模型构建" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 4, 83 | "id": "3b3ba03b", 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "import json\n", 88 | "import os\n", 89 | "from collections import Counter\n", 90 | "from torch.utils.data import Dataset, DataLoader\n", 91 | "from tqdm.auto import tqdm\n", 92 | "import torch\n", 93 | "import torch.nn as nn\n", 94 | "from torch import optim\n", 95 | "import torch.nn.functional as F\n", 96 | "from torch.nn.functional import cross_entropy\n", 97 | "from torch.autograd import Variable" 98 | ] 99 | }, 100 | { 101 | "cell_type": "markdown", 102 | "id": "80f708d1", 103 | "metadata": {}, 104 | "source": [ 105 | "## 配置文件" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 5, 111 | "id": "7f7fb4cd", 112 | "metadata": {}, 113 | "outputs": [], 114 | "source": [ 115 | "# 配置文件\n", 116 | "config ={\n", 117 | " \"elmo\": {\n", 118 | " \"activation\": \"relu\",\n", 119 | " \"filters\": [[1, 32], [2, 32], [3, 64], [4, 128], [5, 256], [6, 512], [7, 1024]],\n", 120 | " \"n_highway\": 2, \n", 121 | " \"word_dim\": 300,\n", 122 | " \"char_dim\": 50,\n", 123 | " \"max_char_token\": 50,\n", 124 | " \"min_count\":5,\n", 125 | " \"max_length\":256,\n", 126 | " \"output_dim\":150,\n", 127 | " \"units\":256,\n", 128 | " \"n_layers\":2,\n", 129 | " },\n", 130 | " \"batch_size\":32,\n", 131 | " \"epochs\":50,\n", 132 | " \"lr\":0.00001,\n", 133 | "}\n", 134 | "\n", 135 | "# 保存路径\n", 136 | "model_save_path=\"./elmo_model\"" 137 | ] 138 | }, 139 | { 140 | "cell_type": "markdown", 141 | "id": "c908fc59", 142 | "metadata": {}, 143 | "source": [ 144 | "## 数据集构建" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": 6, 150 | "id": "867a490f", 
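上面配置中的 min_count(这里取 5)决定了词表的截断阈值:统计语料词频后,出现次数低于 min_count 的词不进词表,编码时统一落到 OOV 上。下面用一个独立的小例子演示这种按词频截断的词表构建方式(其中用 <oov>/<pad> 两个特殊符号占据 0/1 号位是本示例的假设写法):

```python
from collections import Counter

def build_vocab(sentences, min_count=2):
    """统计词频,过滤低频词,返回 word -> id 映射;0/1 预留给特殊符号。"""
    counter = Counter()
    for s in sentences:
        counter.update(s.split())
    lexicon = {"<oov>": 0, "<pad>": 1}
    # most_common() 按词频从高到低排列;低于 min_count 的词被丢弃
    for word, cnt in counter.most_common():
        if cnt < min_count:
            break
        lexicon[word] = len(lexicon)
    return lexicon

vocab = build_vocab(["a b a", "a c b"], min_count=2)
# a 出现 3 次、b 出现 2 次被保留,c 只出现 1 次被过滤
```

min_count 越大,词表越小、OOV 越多;在小语料上通常要把它调低,否则大部分词都会被映射到 OOV。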
151 | "metadata": {},
152 | "outputs": [],
153 | "source": [
154 | "# 读取语料\n",
155 | "with open(\"./corpus.json\") as f:\n",
156 | "    corpus = json.load(f)\n",
157 | "    corpus = corpus[:1000] # 测试"
158 | ]
159 | },
160 | {
161 | "cell_type": "code",
162 | "execution_count": 7,
163 | "id": "853721b5",
164 | "metadata": {},
165 | "outputs": [
166 | {
167 | "name": "stdout",
168 | "output_type": "stream",
169 | "text": [
170 | "device: cuda\n"
171 | ]
172 | }
173 | ],
174 | "source": [
175 | "# 检测是否有可用GPU\n",
176 | "if torch.cuda.is_available():\n",
177 | "    device = torch.device('cuda')\n",
178 | "else:\n",
179 | "    device = torch.device('cpu')\n",
180 | "print('device: ' + str(device))"
181 | ]
182 | },
183 | {
184 | "cell_type": "code",
185 | "execution_count": 8,
186 | "id": "1efec278",
187 | "metadata": {},
188 | "outputs": [],
189 | "source": [
190 | "#分词器\n",
191 | "class Tokenizer:\n",
192 | "    def __init__(self, word2id, ch2id):\n",
193 | "        self.word2id = word2id\n",
194 | "        self.ch2id = ch2id\n",
195 | "        self.id2word = {i: word for word, i in word2id.items()}\n",
196 | "        self.id2ch = {i: char for char, i in ch2id.items()}\n",
197 | "\n",
198 | "    def tokenize(self, text, max_length=512, max_char=50):\n",
199 | "        oov_id, pad_id = self.word2id.get(\"<oov>\"), self.word2id.get(\"<pad>\")\n",
200 | "        w = torch.LongTensor(max_length).fill_(pad_id)\n",
201 | "        words = text.lower().split()\n",
202 | "        for i, wi in enumerate(words[:max_length]):\n",
203 | "            w[i] = self.word2id.get(wi, oov_id)\n",
204 | "        oov_id, pad_id = self.ch2id.get(\"<oov>\"), self.ch2id.get(\"<pad>\")\n",
205 | "        c = torch.LongTensor(max_length, max_char).fill_(pad_id)\n",
206 | "        for i, wi in enumerate(words[:max_length]):\n",
207 | "            for j, wij in enumerate(wi[:max_char]):\n",
208 | "                c[i][j] = self.ch2id.get(wij, oov_id)\n",
209 | "        return w, c, len(words[:max_length])\n",
210 | "\n",
211 | "    def save(self, path):\n",
212 | "        try:\n",
213 | "            os.mkdir(path)\n",
214 | "        except FileExistsError:\n",
215 | "            pass\n",
216 | "        tok = {\n",
217 | "            \"word2id\": self.word2id,\n",
218 | "            \"ch2id\": self.ch2id\n",
219 | "        }\n",
220 | "        with open(f\"{path}/tokenizer.json\", \"w\") as f:\n",
221 | "            json.dump(tok, f, indent=4)\n",
222 | "\n",
223 | "\n",
224 | "# 从语料中构建\n",
225 | "def from_corpus(corpus, min_count=5):\n",
226 | "    word_count = Counter()\n",
227 | "    for sentence in corpus:\n",
228 | "        word_count.update(sentence.split())\n",
229 | "    word_count = list(word_count.items())\n",
230 | "    word_count.sort(key=lambda x: x[1], reverse=True)\n",
231 | "    for i, (word, count) in enumerate(word_count):\n",
232 | "        if count < min_count:\n",
233 | "            break\n",
234 | "    vocab = word_count[:i]\n",
235 | "    vocab = [v[0] for v in vocab]\n",
236 | "    word_lexicon = {}\n",
237 | "    for special_word in ['<oov>', '<pad>']:\n",
238 | "        if special_word not in word_lexicon:\n",
239 | "            word_lexicon[special_word] = len(word_lexicon)\n",
240 | "    for word in vocab:\n",
241 | "        if word not in word_lexicon:\n",
242 | "            word_lexicon[word] = len(word_lexicon)\n",
243 | "    char_lexicon = {}\n",
244 | "    for special_char in ['<oov>', '<pad>']:\n",
245 | "        if special_char not in char_lexicon:\n",
246 | "            char_lexicon[special_char] = len(char_lexicon)\n",
247 | "    for sentence in corpus:\n",
248 | "        for word in sentence.split():\n",
249 | "            for ch in word:\n",
250 | "                if ch not in char_lexicon:\n",
251 | "                    char_lexicon[ch] = len(char_lexicon)\n",
252 | "    return Tokenizer(word_lexicon, char_lexicon)\n",
253 | "\n",
254 | "\n",
255 | "# 从checkpoint中构建\n",
256 | "def from_file(path):\n",
257 | "    with open(f\"{path}/tokenizer.json\") as f:\n",
258 | "        d = json.load(f)\n",
259 | "    return Tokenizer(d[\"word2id\"], d[\"ch2id\"])"
260 | ]
261 | },
262 | {
263 | "cell_type": "code",
264 | "execution_count": 9,
265 | "id": "1f77f6f0",
266 | "metadata": {},
267 | "outputs": [],
268 | "source": [
269 | "# 初始化分词器\n",
270 | "tokenizer = from_corpus(corpus, config[\"elmo\"][\"min_count\"])"
271 | ]
272 | },
273 | {
274 | "cell_type": "code",
275 | "execution_count": 10,
276 | "id": 
"a5c60f5a", 277 | "metadata": {}, 278 | "outputs": [], 279 | "source": [ 280 | "# ELMO数据集生成器\n", 281 | "class ELMoDataSet(Dataset):\n", 282 | " def __init__(self,corpus,tokenizer):\n", 283 | " self.corpus=corpus\n", 284 | " self.tokenizer=tokenizer\n", 285 | " \n", 286 | " def __getitem__(self, idx):\n", 287 | " text = self.corpus[idx]\n", 288 | " w,c,i= self.tokenizer.tokenize(text,max_length=config[\"elmo\"][\"max_length\"],max_char=config[\"elmo\"][\"max_char_token\"])\n", 289 | " return w,c,i\n", 290 | " \n", 291 | " def __len__(self):\n", 292 | " return len(self.corpus)\n", 293 | "\n", 294 | "# 初始化数据集生成器\n", 295 | "data = ELMoDataSet(corpus,tokenizer)" 296 | ] 297 | }, 298 | { 299 | "cell_type": "code", 300 | "execution_count": 11, 301 | "id": "b24ad733", 302 | "metadata": {}, 303 | "outputs": [], 304 | "source": [ 305 | "# 初始化Pytorch框架的数据生成器\n", 306 | "data_loader = DataLoader(data, batch_size=config[\"batch_size\"])" 307 | ] 308 | }, 309 | { 310 | "cell_type": "markdown", 311 | "id": "5a2c87cc", 312 | "metadata": {}, 313 | "source": [ 314 | "## 模型初始化" 315 | ] 316 | }, 317 | { 318 | "cell_type": "code", 319 | "execution_count": 12, 320 | "id": "414187bb", 321 | "metadata": {}, 322 | "outputs": [], 323 | "source": [ 324 | "# Based upon https://gist.github.com/Redchards/65f1a6f758a1a5c5efb56f83933c3f6e\n", 325 | "# Original Paper https://arxiv.org/abs/1505.00387\n", 326 | "# 我们用残差网络替代HighWay\n", 327 | "class HighWay(nn.Module):\n", 328 | " def __init__(self, input_dim, num_layers=1,activation= nn.functional.relu):\n", 329 | " super(HighWay, self).__init__()\n", 330 | " self._input_dim = input_dim\n", 331 | " self._layers = nn.ModuleList([nn.Linear(input_dim, input_dim * 2) for _ in range(num_layers)])\n", 332 | " self._activation = activation\n", 333 | " for layer in self._layers:\n", 334 | " layer.bias[input_dim:].data.fill_(1)\n", 335 | " \n", 336 | " def forward(self, inputs):\n", 337 | " current_input = inputs\n", 338 | " for layer in self._layers:\n", 339 
| " projected_input = layer(current_input)\n", 340 | " linear_part = current_input\n", 341 | " nonlinear_part = projected_input[:, (0 * self._input_dim):(1 * self._input_dim)]\n", 342 | " gate = projected_input[:, (1 * self._input_dim):(2 * self._input_dim)]\n", 343 | " nonlinear_part = self._activation(nonlinear_part)\n", 344 | " gate = torch.sigmoid(gate)\n", 345 | " current_input = gate * linear_part + (1 - gate) * nonlinear_part\n", 346 | " return current_input" 347 | ] 348 | }, 349 | { 350 | "cell_type": "code", 351 | "execution_count": 15, 352 | "id": "90cf8e3d", 353 | "metadata": {}, 354 | "outputs": [], 355 | "source": [ 356 | "class ELMo(nn.Module):\n", 357 | " def __init__(self,tokenizer,config):\n", 358 | " super(ELMo, self).__init__()\n", 359 | " self.config=config\n", 360 | " self.tokenizer = tokenizer\n", 361 | " self.word_embedding = nn.Embedding(len(tokenizer.word2id),config[\"elmo\"][\"word_dim\"],padding_idx=tokenizer.word2id.get(\"<pad>\"))\n", 362 | " self.char_embedding = nn.Embedding(len(tokenizer.ch2id),config[\"elmo\"][\"char_dim\"],padding_idx=tokenizer.ch2id.get(\"<pad>\"))\n", 363 | " self.output_dim = config[\"elmo\"][\"output_dim\"]\n", 364 | " activation = config[\"elmo\"][\"activation\"]\n", 365 | " if activation==\"relu\":\n", 366 | " self.act = nn.ReLU()\n", 367 | " elif activation==\"tanh\":\n", 368 | " self.act=nn.Tanh()\n", 369 | " self.emb_dim = config[\"elmo\"][\"word_dim\"]\n", 370 | " self.convolutions = []\n", 371 | " filters = config[\"elmo\"][\"filters\"]\n", 372 | " char_dim = config[\"elmo\"][\"char_dim\"]\n", 373 | " for i, (width, num) in enumerate(filters):\n", 374 | " conv = nn.Conv1d(in_channels=char_dim,\n", 375 | " out_channels=num,\n", 376 | " kernel_size=width,\n", 377 | " bias=True\n", 378 | " )\n", 379 | " self.convolutions.append(conv)\n", 380 | " self.convolutions = nn.ModuleList(self.convolutions)\n", 381 | " self.n_filters = sum(f[1] for f in filters)\n", 382 | " self.n_highway = config[\"elmo\"][\"n_highway\"]\n", 
383 | " self.highways = HighWay(self.n_filters, self.n_highway, activation=self.act)\n", 384 | " self.emb_dim += self.n_filters\n", 385 | " self.projection = nn.Linear(self.emb_dim, self.output_dim, bias=True)\n", 386 | " self.f=[nn.LSTM(input_size = config[\"elmo\"][\"output_dim\"], hidden_size = config[\"elmo\"][\"units\"], batch_first=True)]\n", 387 | " self.b=[nn.LSTM(input_size = config[\"elmo\"][\"output_dim\"], hidden_size = config[\"elmo\"][\"units\"], batch_first=True)]\n", 388 | " for _ in range(config[\"elmo\"][\"n_layers\"]-1):\n", 389 | " self.f.append(nn.LSTM(input_size = config[\"elmo\"][\"units\"], hidden_size = config[\"elmo\"][\"units\"], batch_first=True))\n", 390 | " self.b.append(nn.LSTM(input_size = config[\"elmo\"][\"units\"], hidden_size = config[\"elmo\"][\"units\"], batch_first=True))\n", 391 | " self.f = nn.ModuleList(self.f)\n", 392 | " self.b = nn.ModuleList(self.b)\n", 393 | " self.ln = nn.Linear(in_features=config[\"elmo\"][\"units\"], out_features=len(tokenizer.word2id))\n", 394 | " \n", 395 | " def forward(self, word_inp, chars_inp):\n", 396 | " embs = []\n", 397 | " batch_size, seq_len = word_inp.size(0), word_inp.size(1)\n", 398 | " word_emb = self.word_embedding(Variable(word_inp))\n", 399 | " embs.append(word_emb)\n", 400 | " chars_inp = chars_inp.view(batch_size * seq_len, -1)\n", 401 | " char_emb = self.char_embedding(Variable(chars_inp))\n", 402 | " char_emb = char_emb.transpose(1, 2)\n", 403 | " convs = []\n", 404 | " for i in range(len(self.convolutions)):\n", 405 | " convolved = self.convolutions[i](char_emb)\n", 406 | " convolved, _ = torch.max(convolved, dim=-1)\n", 407 | " convolved = self.act(convolved)\n", 408 | " convs.append(convolved)\n", 409 | " char_emb = torch.cat(convs, dim=-1)\n", 410 | " char_emb = self.highways(char_emb)\n", 411 | " embs.append(char_emb.view(batch_size, -1, self.n_filters))\n", 412 | " token_embedding = torch.cat(embs, dim=2)\n", 413 | " embeddings = self.projection(token_embedding)\n", 414 
| "        fs = [embeddings] \n", 415 | "        bs = [embeddings]\n", 416 | "        for fl,bl in zip(self.f,self.b):\n", 417 | "            o_f,_ = fl(fs[-1])\n", 418 | "            fs.append(o_f)\n", 419 | "            o_b,_ = bl(torch.flip(bs[-1],dims=[1,]))\n", 420 | "            bs.append(torch.flip(o_b,dims=(1,)))\n", 421 | "        return fs,bs\n", 422 | "    \n", 423 | "    def save_model(self,path):\n", 424 | "        try:\n", 425 | "            os.mkdir(path)\n", 426 | "        except FileExistsError:\n", 427 | "            pass\n", 428 | "        torch.save(self.state_dict(),f'{path}/model.pt')\n", 429 | "        with open(f\"{path}/config.json\",\"w\") as f:\n", 430 | "            json.dump(self.config,f,indent=4)\n", 431 | "        self.tokenizer.save(path)\n", 432 | "    \n", 433 | "    @classmethod\n", 434 | "    def from_checkpoint(cls,path,device):\n", 435 | "        with open(f\"{path}/config.json\") as f:\n", 436 | "            config = json.load(f)\n", 437 | "        tokenizer = Tokenizer.from_file(path)\n", 438 | "        model = cls(tokenizer,config)\n", 439 | "        model.load_state_dict(torch.load(f'{path}/model.pt',map_location=device))\n", 440 | "        return model" 441 | ] 442 | }, 443 | { 444 | "cell_type": "code", 445 | "execution_count": 16, 446 | "id": "4805db12", 447 | "metadata": {}, 448 | "outputs": [ 449 | { 450 | "data": { 451 | "text/plain": [ 452 | "ELMo(\n", 453 | "  (word_embedding): Embedding(1122, 300, padding_idx=1)\n", 454 | "  (char_embedding): Embedding(39, 50, padding_idx=1)\n", 455 | "  (act): ReLU()\n", 456 | "  (convolutions): ModuleList(\n", 457 | "    (0): Conv1d(50, 32, kernel_size=(1,), stride=(1,))\n", 458 | "    (1): Conv1d(50, 32, kernel_size=(2,), stride=(1,))\n", 459 | "    (2): Conv1d(50, 64, kernel_size=(3,), stride=(1,))\n", 460 | "    (3): Conv1d(50, 128, kernel_size=(4,), stride=(1,))\n", 461 | "    (4): Conv1d(50, 256, kernel_size=(5,), stride=(1,))\n", 462 | "    (5): Conv1d(50, 512, kernel_size=(6,), stride=(1,))\n", 463 | "    (6): Conv1d(50, 1024, kernel_size=(7,), stride=(1,))\n", 464 | "  )\n", 465 | "  (highways): HighWay(\n", 466 | "    (_layers): ModuleList(\n", 467 | "      (0): Linear(in_features=2048, out_features=4096, 
bias=True)\n", 468 | "      (1): Linear(in_features=2048, out_features=4096, bias=True)\n", 469 | "    )\n", 470 | "    (_activation): ReLU()\n", 471 | "  )\n", 472 | "  (projection): Linear(in_features=2348, out_features=150, bias=True)\n", 473 | "  (f): ModuleList(\n", 474 | "    (0): LSTM(150, 256, batch_first=True)\n", 475 | "    (1): LSTM(256, 256, batch_first=True)\n", 476 | "  )\n", 477 | "  (b): ModuleList(\n", 478 | "    (0): LSTM(150, 256, batch_first=True)\n", 479 | "    (1): LSTM(256, 256, batch_first=True)\n", 480 | "  )\n", 481 | "  (ln): Linear(in_features=256, out_features=1122, bias=True)\n", 482 | ")" 483 | ] 484 | }, 485 | "execution_count": 16, 486 | "metadata": {}, 487 | "output_type": "execute_result" 488 | } 489 | ], 490 | "source": [ 491 | "model = ELMo(tokenizer,config)\n", 492 | "model.to(device)" 493 | ] 494 | }, 495 | { 496 | "cell_type": "markdown", 497 | "id": "e8ad1d0b", 498 | "metadata": { 499 | "collapsed": true 500 | }, 501 | "source": [ 502 | "## Training" 503 | ] 504 | }, 505 | { 506 | "cell_type": "code", 507 | "execution_count": 18, 508 | "id": "b1277113", 509 | "metadata": { 510 | "scrolled": true 511 | }, 512 | "outputs": [ 513 | { 514 | "name": "stdout", 515 | "output_type": "stream", 516 | "text": [ 517 | "Epoch: 1\n" 518 | ] 519 | }, 520 | { 521 | "data": { 522 | "application/vnd.jupyter.widget-view+json": { 523 | "model_id": "8149a2ccb4dd408398918be4c046df51", 524 | "version_major": 2, 525 | "version_minor": 0 526 | }, 527 | "text/plain": [ 528 | "HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))" 529 | ] 530 | }, 531 | "metadata": {}, 532 | "output_type": "display_data" 533 | }, 534 | { 535 | "name": "stdout", 536 | "output_type": "stream", 537 | "text": [ 538 | "\n", 539 | "total_loss: 3546.50244140625\n", 540 | "Epoch: 2\n" 541 | ] 542 | }, 543 | { 544 | "data": { 545 | "application/vnd.jupyter.widget-view+json": { 546 | "model_id": "269103df89af41c3920ea3527295ae89", 547 | "version_major": 2, 548 | "version_minor": 0 549 | }, 550 | 
"text/plain": [ 551 | "HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))" 552 | ] 553 | }, 554 | "metadata": {}, 555 | "output_type": "display_data" 556 | }, 557 | { 558 | "name": "stdout", 559 | "output_type": "stream", 560 | "text": [ 561 | "\n", 562 | "total_loss: 3521.490478515625\n", 563 | "Epoch: 3\n" 564 | ] 565 | }, 566 | { 567 | "data": { 568 | "application/vnd.jupyter.widget-view+json": { 569 | "model_id": "9f7846e7af9f4d43a5745954bcf2fb5e", 570 | "version_major": 2, 571 | "version_minor": 0 572 | }, 573 | "text/plain": [ 574 | "HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))" 575 | ] 576 | }, 577 | "metadata": {}, 578 | "output_type": "display_data" 579 | }, 580 | { 581 | "name": "stdout", 582 | "output_type": "stream", 583 | "text": [ 584 | "\n", 585 | "total_loss: 3480.632568359375\n", 586 | "Epoch: 4\n" 587 | ] 588 | }, 589 | { 590 | "data": { 591 | "application/vnd.jupyter.widget-view+json": { 592 | "model_id": "2046b6ecbc0a41f9a5b1c824bd1820eb", 593 | "version_major": 2, 594 | "version_minor": 0 595 | }, 596 | "text/plain": [ 597 | "HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))" 598 | ] 599 | }, 600 | "metadata": {}, 601 | "output_type": "display_data" 602 | }, 603 | { 604 | "name": "stdout", 605 | "output_type": "stream", 606 | "text": [ 607 | "\n", 608 | "total_loss: 3414.54736328125\n", 609 | "Epoch: 5\n" 610 | ] 611 | }, 612 | { 613 | "data": { 614 | "application/vnd.jupyter.widget-view+json": { 615 | "model_id": "6179a63d480c492a844270e8dfcf58da", 616 | "version_major": 2, 617 | "version_minor": 0 618 | }, 619 | "text/plain": [ 620 | "HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))" 621 | ] 622 | }, 623 | "metadata": {}, 624 | "output_type": "display_data" 625 | }, 626 | { 627 | "name": "stdout", 628 | "output_type": "stream", 629 | "text": [ 630 | "\n", 631 | "total_loss: 3278.797607421875\n", 632 | "Epoch: 6\n" 633 | ] 634 | }, 635 | { 636 | "data": { 637 | 
"application/vnd.jupyter.widget-view+json": { 638 | "model_id": "74b18065f2c14dc9bcc4280ad081f93c", 639 | "version_major": 2, 640 | "version_minor": 0 641 | }, 642 | "text/plain": [ 643 | "HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))" 644 | ] 645 | }, 646 | "metadata": {}, 647 | "output_type": "display_data" 648 | }, 649 | { 650 | "name": "stdout", 651 | "output_type": "stream", 652 | "text": [ 653 | "\n", 654 | "total_loss: 2848.858154296875\n", 655 | "Epoch: 7\n" 656 | ] 657 | }, 658 | { 659 | "data": { 660 | "application/vnd.jupyter.widget-view+json": { 661 | "model_id": "b0f6bf57a60145718d24a6a4a3cded16", 662 | "version_major": 2, 663 | "version_minor": 0 664 | }, 665 | "text/plain": [ 666 | "HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))" 667 | ] 668 | }, 669 | "metadata": {}, 670 | "output_type": "display_data" 671 | }, 672 | { 673 | "name": "stdout", 674 | "output_type": "stream", 675 | "text": [ 676 | "\n", 677 | "total_loss: 2211.5849609375\n", 678 | "Epoch: 8\n" 679 | ] 680 | }, 681 | { 682 | "data": { 683 | "application/vnd.jupyter.widget-view+json": { 684 | "model_id": "f37146aa0032426bb6dbdabf255860f8", 685 | "version_major": 2, 686 | "version_minor": 0 687 | }, 688 | "text/plain": [ 689 | "HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))" 690 | ] 691 | }, 692 | "metadata": {}, 693 | "output_type": "display_data" 694 | }, 695 | { 696 | "name": "stdout", 697 | "output_type": "stream", 698 | "text": [ 699 | "\n", 700 | "total_loss: 1741.710205078125\n", 701 | "Epoch: 9\n" 702 | ] 703 | }, 704 | { 705 | "data": { 706 | "application/vnd.jupyter.widget-view+json": { 707 | "model_id": "11e92d641d3b440485be364d32240165", 708 | "version_major": 2, 709 | "version_minor": 0 710 | }, 711 | "text/plain": [ 712 | "HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))" 713 | ] 714 | }, 715 | "metadata": {}, 716 | "output_type": "display_data" 717 | }, 718 | { 719 | "name": "stdout", 720 | 
"output_type": "stream", 721 | "text": [ 722 | "\n", 723 | "total_loss: 1446.1951904296875\n", 724 | "Epoch: 10\n" 725 | ] 726 | }, 727 | { 728 | "data": { 729 | "application/vnd.jupyter.widget-view+json": { 730 | "model_id": "b4fa5675d68a45e1ac500a0cb0784e85", 731 | "version_major": 2, 732 | "version_minor": 0 733 | }, 734 | "text/plain": [ 735 | "HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))" 736 | ] 737 | }, 738 | "metadata": {}, 739 | "output_type": "display_data" 740 | }, 741 | { 742 | "name": "stdout", 743 | "output_type": "stream", 744 | "text": [ 745 | "\n", 746 | "total_loss: 1268.07275390625\n", 747 | "Epoch: 11\n" 748 | ] 749 | }, 750 | { 751 | "data": { 752 | "application/vnd.jupyter.widget-view+json": { 753 | "model_id": "79c6f8166bd34d918756ac2ad1c4a9d9", 754 | "version_major": 2, 755 | "version_minor": 0 756 | }, 757 | "text/plain": [ 758 | "HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))" 759 | ] 760 | }, 761 | "metadata": {}, 762 | "output_type": "display_data" 763 | }, 764 | { 765 | "name": "stdout", 766 | "output_type": "stream", 767 | "text": [ 768 | "\n", 769 | "total_loss: 1163.1085205078125\n", 770 | "Epoch: 12\n" 771 | ] 772 | }, 773 | { 774 | "data": { 775 | "application/vnd.jupyter.widget-view+json": { 776 | "model_id": "d4f547d4a674441e81359db4de5c6931", 777 | "version_major": 2, 778 | "version_minor": 0 779 | }, 780 | "text/plain": [ 781 | "HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))" 782 | ] 783 | }, 784 | "metadata": {}, 785 | "output_type": "display_data" 786 | }, 787 | { 788 | "name": "stdout", 789 | "output_type": "stream", 790 | "text": [ 791 | "\n", 792 | "total_loss: 1097.5172119140625\n", 793 | "Epoch: 13\n" 794 | ] 795 | }, 796 | { 797 | "data": { 798 | "application/vnd.jupyter.widget-view+json": { 799 | "model_id": "b3a3296c1003467db97931acab637c4e", 800 | "version_major": 2, 801 | "version_minor": 0 802 | }, 803 | "text/plain": [ 804 | 
"HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))" 805 | ] 806 | }, 807 | "metadata": {}, 808 | "output_type": "display_data" 809 | }, 810 | { 811 | "name": "stdout", 812 | "output_type": "stream", 813 | "text": [ 814 | "\n", 815 | "total_loss: 1054.2308349609375\n", 816 | "Epoch: 14\n" 817 | ] 818 | }, 819 | { 820 | "data": { 821 | "application/vnd.jupyter.widget-view+json": { 822 | "model_id": "25e85c000be74db8b460704fbc24f8a8", 823 | "version_major": 2, 824 | "version_minor": 0 825 | }, 826 | "text/plain": [ 827 | "HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))" 828 | ] 829 | }, 830 | "metadata": {}, 831 | "output_type": "display_data" 832 | }, 833 | { 834 | "name": "stdout", 835 | "output_type": "stream", 836 | "text": [ 837 | "\n", 838 | "total_loss: 1023.152587890625\n", 839 | "Epoch: 15\n" 840 | ] 841 | }, 842 | { 843 | "data": { 844 | "application/vnd.jupyter.widget-view+json": { 845 | "model_id": "9be97e8d1f8f43e3a8a922719ae0b0d4", 846 | "version_major": 2, 847 | "version_minor": 0 848 | }, 849 | "text/plain": [ 850 | "HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))" 851 | ] 852 | }, 853 | "metadata": {}, 854 | "output_type": "display_data" 855 | }, 856 | { 857 | "name": "stdout", 858 | "output_type": "stream", 859 | "text": [ 860 | "\n", 861 | "total_loss: 999.1707763671875\n", 862 | "Epoch: 16\n" 863 | ] 864 | }, 865 | { 866 | "data": { 867 | "application/vnd.jupyter.widget-view+json": { 868 | "model_id": "3cefb96da62646f08b46a0423fa872cf", 869 | "version_major": 2, 870 | "version_minor": 0 871 | }, 872 | "text/plain": [ 873 | "HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))" 874 | ] 875 | }, 876 | "metadata": {}, 877 | "output_type": "display_data" 878 | }, 879 | { 880 | "name": "stdout", 881 | "output_type": "stream", 882 | "text": [ 883 | "\n", 884 | "total_loss: 980.1627197265625\n", 885 | "Epoch: 17\n" 886 | ] 887 | }, 888 | { 889 | "data": { 890 | 
"application/vnd.jupyter.widget-view+json": { 891 | "model_id": "9d3ce85d29c448dfbd6325d6a7eba605", 892 | "version_major": 2, 893 | "version_minor": 0 894 | }, 895 | "text/plain": [ 896 | "HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))" 897 | ] 898 | }, 899 | "metadata": {}, 900 | "output_type": "display_data" 901 | }, 902 | { 903 | "name": "stdout", 904 | "output_type": "stream", 905 | "text": [ 906 | "\n", 907 | "total_loss: 964.4329833984375\n", 908 | "Epoch: 18\n" 909 | ] 910 | }, 911 | { 912 | "data": { 913 | "application/vnd.jupyter.widget-view+json": { 914 | "model_id": "fa2996b3ce85413399e426be07d72e7a", 915 | "version_major": 2, 916 | "version_minor": 0 917 | }, 918 | "text/plain": [ 919 | "HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))" 920 | ] 921 | }, 922 | "metadata": {}, 923 | "output_type": "display_data" 924 | }, 925 | { 926 | "name": "stdout", 927 | "output_type": "stream", 928 | "text": [ 929 | "\n", 930 | "total_loss: 951.1318969726562\n", 931 | "Epoch: 19\n" 932 | ] 933 | }, 934 | { 935 | "data": { 936 | "application/vnd.jupyter.widget-view+json": { 937 | "model_id": "0293ff98543842ce8e7db4e92e507c04", 938 | "version_major": 2, 939 | "version_minor": 0 940 | }, 941 | "text/plain": [ 942 | "HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))" 943 | ] 944 | }, 945 | "metadata": {}, 946 | "output_type": "display_data" 947 | }, 948 | { 949 | "name": "stdout", 950 | "output_type": "stream", 951 | "text": [ 952 | "\n", 953 | "total_loss: 939.527587890625\n", 954 | "Epoch: 20\n" 955 | ] 956 | }, 957 | { 958 | "data": { 959 | "application/vnd.jupyter.widget-view+json": { 960 | "model_id": "346fcab8245c4457b6fc23382377a3cd", 961 | "version_major": 2, 962 | "version_minor": 0 963 | }, 964 | "text/plain": [ 965 | "HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))" 966 | ] 967 | }, 968 | "metadata": {}, 969 | "output_type": "display_data" 970 | }, 971 | { 972 | "name": "stdout", 973 
| "output_type": "stream", 974 | "text": [ 975 | "\n" 976 | ] 977 | }, 978 | { 979 | "ename": "KeyboardInterrupt", 980 | "evalue": "", 981 | "output_type": "error", 982 | "traceback": [ 983 | "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", 984 | "\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", 985 | "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[0;32m 21\u001b[0m \u001b[0mbll\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mnn\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfunctional\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlog_softmax\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mbl\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mdim\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msqueeze\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 22\u001b[0m \u001b[0mloss\u001b[0m\u001b[1;33m+=\u001b[0m\u001b[0mloss_function\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfll\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mw\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mk\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m+\u001b[0m\u001b[0mloss_function\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mbll\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mw\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mk\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 23\u001b[1;33m \u001b[0mloss\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 24\u001b[0m 
\u001b[0mopt\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 25\u001b[0m \u001b[0mopt\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 986 | "\u001b[1;32mc:\\users\\kangb\\appdata\\local\\conda\\conda\\envs\\py36\\lib\\site-packages\\torch\\tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[1;34m(self, gradient, retain_graph, create_graph)\u001b[0m\n\u001b[0;32m 183\u001b[0m \u001b[0mproducts\u001b[0m\u001b[1;33m.\u001b[0m \u001b[0mDefaults\u001b[0m \u001b[0mto\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[1;32mFalse\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[1;33m.\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 184\u001b[0m \"\"\"\n\u001b[1;32m--> 185\u001b[1;33m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 186\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 187\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 987 | "\u001b[1;32mc:\\users\\kangb\\appdata\\local\\conda\\conda\\envs\\py36\\lib\\site-packages\\torch\\autograd\\__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[1;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables)\u001b[0m\n\u001b[0;32m 125\u001b[0m 
Variable._execution_engine.run_backward(\n\u001b[0;32m    126\u001b[0m         \u001b[0mtensors\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mgrad_tensors\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 127\u001b[1;33m         allow_unreachable=True)  # allow_unreachable flag\n\u001b[0m\u001b[0;32m    128\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    129\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", 988 | "\u001b[1;31mKeyboardInterrupt\u001b[0m: " 989 | ] 990 | } 991 | ], 992 | "source": [ 993 | "opt = optim.Adam(model.parameters(),lr = config[\"lr\"])\n", 994 | "loss_function = torch.nn.NLLLoss()\n", 995 | "for epoch in range(config[\"epochs\"]):\n", 996 | "    total_loss = 0\n", 997 | "    print(f\"Epoch: {epoch+1}\")\n", 998 | "    for batch in tqdm(data_loader):\n", 999 | "        # total_loss accumulates over the whole epoch; do not reset it per batch\n", 1000 | "        w, c, i = batch\n", 1001 | "        w = w.to(device)\n", 1002 | "        c = c.to(device)\n", 1003 | "        f, b = model(w,c)\n", 1004 | "        f, b = f[-1], b[-1]\n", 1005 | "        k_max=torch.max(i)\n", 1006 | "        loss = 0\n", 1007 | "        for k in range(1,k_max):\n", 1008 | "            fpass=f[:,k-1,:]\n", 1009 | "            bpass=b[:,k,:]  # backward state has seen tokens k..end, so it must predict token k-1\n", 1010 | "            fl = model.ln(fpass).squeeze()\n", 1011 | "            bl = model.ln(bpass).squeeze()\n", 1012 | "            fll = torch.nn.functional.log_softmax(fl,dim=1).squeeze()\n", 1013 | "            bll = torch.nn.functional.log_softmax(bl,dim=1).squeeze()\n", 1014 | "            loss+=loss_function(fll,w[:,k])+loss_function(bll,w[:,k-1])\n", 1015 | "        loss.backward()\n", 1016 | "        opt.step()\n", 1017 | "        opt.zero_grad()\n", 1018 | "        model.zero_grad()\n", 1019 | "        total_loss += loss.detach().item()\n", 1020 | "    model.save_model(model_save_path)\n", 1021 | "    print('total_loss:',total_loss)" 1022 | ] 1023 | }, 1024 | { 1025 | "cell_type": "code", 1026 | "execution_count": null, 1027 | "id": "c9120a89", 1028 | "metadata": { 1029 | "collapsed": true 1030 | }, 1031 | "outputs": [], 1032 | "source": 
[] 1033 | } 1034 | ], 1035 | "metadata": { 1036 | "kernelspec": { 1037 | "display_name": "Python 3", 1038 | "language": "python", 1039 | "name": "python3" 1040 | }, 1041 | "language_info": { 1042 | "codemirror_mode": { 1043 | "name": "ipython", 1044 | "version": 3 1045 | }, 1046 | "file_extension": ".py", 1047 | "mimetype": "text/x-python", 1048 | "name": "python", 1049 | "nbconvert_exporter": "python", 1050 | "pygments_lexer": "ipython3", 1051 | "version": "3.6.10" 1052 | }, 1053 | "toc": { 1054 | "base_numbering": 1, 1055 | "nav_menu": {}, 1056 | "number_sections": true, 1057 | "sideBar": true, 1058 | "skip_h1_title": false, 1059 | "title_cell": "Table of Contents", 1060 | "title_sidebar": "Contents", 1061 | "toc_cell": false, 1062 | "toc_position": {}, 1063 | "toc_section_display": true, 1064 | "toc_window_display": true 1065 | } 1066 | }, 1067 | "nbformat": 4, 1068 | "nbformat_minor": 5 1069 | } 1070 | -------------------------------------------------------------------------------- /文本表示/初识预训练模型elmo/代码/README.md: -------------------------------------------------------------------------------- 1 | ## Code and related corpora 2 | -------------------------------------------------------------------------------- /文本表示/初识预训练模型elmo/课件/README.md: -------------------------------------------------------------------------------- 1 | ## PDF format for now; a markdown version will follow 2 | -------------------------------------------------------------------------------- /文本表示/初识预训练模型elmo/课件/初识预训练模型:elmo.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datawhalechina/hands-dirty-nlp/1218362ac4956169b358a9b462fc890a4e130df8/文本表示/初识预训练模型elmo/课件/初识预训练模型:elmo.pdf --------------------------------------------------------------------------------
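The notebook's training cell sums a forward and a backward language-model term at each position. As a minimal, self-contained sketch of that objective — a NumPy stand-in for the notebook's PyTorch tensors, with `bidirectional_lm_loss` as an illustrative helper name, not something defined in the notebook — each directional state is paired with a target token *outside* its visible context (the forward state predicts the next token, the backward state the previous one):

```python
import numpy as np

def bidirectional_lm_loss(f_out, b_out, targets):
    """Toy version of the summed forward/backward next-token NLL.

    f_out, b_out: (batch, seq_len, vocab) logits from the two directions,
    with b_out already flipped back into left-to-right order.
    targets: (batch, seq_len) integer token ids.
    """
    batch, seq_len, _ = f_out.shape
    loss = 0.0
    for k in range(1, seq_len):
        pairs = [
            (f_out[:, k - 1, :], targets[:, k]),      # forward: context 0..k-1 predicts token k
            (b_out[:, k, :], targets[:, k - 1]),      # backward: context k..end predicts token k-1
        ]
        for logits, gold in pairs:
            # numerically stable log-softmax over the vocabulary axis
            z = logits - logits.max(axis=1, keepdims=True)
            logp = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
            # NLL of the gold token, averaged over the batch (like nn.NLLLoss)
            loss += -logp[np.arange(batch), gold].mean()
    return loss
```

With uniform (all-zero) logits every token has probability 1/vocab, so the loss is simply (number of predicted positions) × log(vocab), which makes the pairing easy to sanity-check by hand.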