├── README.md ├── notebooks ├── Bokeh画图技能.ipynb ├── Coggle202405意图识别 │ ├── 01-对话意图识别.ipynb │ ├── 02-正则关键词.ipynb │ ├── 03-TFIDF文本分类.ipynb │ ├── 04-词向量训练与使用.ipynb │ ├── 05-FastText文本分类.ipynb │ ├── 06-LSTM文本分类.ipynb │ ├── 07-TextCNN文本分类.ipynb │ ├── 08-BERT文本分类.ipynb │ ├── 09-BERT高效微调.ipynb │ ├── 10-T5加载与使用.ipynb │ ├── 11-T5微调文本分类.ipynb │ ├── 12-Qwen大模型加载与使用.ipynb │ └── 13-Qwen大模型微调文本分类.ipynb ├── Pandas画图技能.ipynb ├── Prophet教程.ipynb ├── Seaborn画图案例.ipynb ├── Zilliz-Milvus-Cloud.ipynb ├── llm │ ├── ChatGPT-01-API基础.ipynb │ ├── ChatGPT-02-API进阶.ipynb │ ├── ChatGPT-03-Prompt特征工程.ipynb │ ├── ChatGPT-04-langchain.ipynb │ ├── ChatGPT-05-Embedding模型.ipynb │ ├── ChatGPT-07-Function Call.ipynb │ ├── Coggle202401-RAG打卡.ipynb │ ├── GLM-01-基础API.ipynb │ ├── RAG-BM25-Rerank.ipynb │ ├── glm-info-extraction-agent.ipynb │ └── xfyun_api.ipynb ├── m3u8在线视频下载.ipynb ├── nlp │ ├── KeyBERT教程.ipynb │ ├── NLP-Pytorch基础.ipynb │ ├── README.md │ ├── bert-choice-example.ipynb │ ├── bert-cls-example.ipynb │ ├── bert-mlm-example.ipynb │ ├── bert-ner-example.ipynb │ ├── bert-nsp-example.ipynb │ ├── bert-prompt-cls.ipynb │ ├── bert-qa-example.ipynb │ ├── transformer基础.ipynb │ ├── 文本分类-BILSTM模型.ipynb │ ├── 文本分类-FastText模型-中文分类.ipynb │ ├── 文本分类-FastText模型-中文进阶.ipynb │ ├── 文本分类-FastText模型-英文分类.ipynb │ ├── 文本分类-TextCNN模型.ipynb │ └── 无监督句子编码.ipynb ├── numexpr教程.ipynb ├── sklearn迭代训练.ipynb ├── thirty-days-of-ml-2403.ipynb └── 时间序列库.ipynb └── streamlit ├── chatgpt_demo.py ├── glm_chat2pandas.py └── glm_demo.py /README.md: -------------------------------------------------------------------------------- 1 |

数据科学教程案例

2 | 3 | ## 目录 4 | 5 |
6 | 基础教程 7 |
8 | 9 | 10 | 11 | 14 | 17 | 18 | 19 | 26 | 33 | 34 | 35 | 36 | 37 |
12 | 可视化 13 | 15 | 机器学习 16 |
20 | 25 | 27 | 32 |
38 | 39 | 40 |
41 | 进阶教程 42 |
43 | 44 | 45 | 46 | 49 | 52 | 55 | 56 | 57 | 76 | 93 | 98 | 99 | 100 | 101 | 102 |
47 | NLP基础 48 | 50 | 文本分类(意图识别) 51 | 53 | 关键词提取与文本聚类 54 |
58 | 75 | 77 | 92 | 94 | 97 |
103 | 104 | 105 | 106 | 107 | ### TODO 108 | 109 | - 【TODO】sklearn模型加速 110 | - 【TODO】XGBoost/LightGBM/CatBoost 111 | - 分类任务/回归任务/排序任务 112 | - 自定义评价函数和目标函数 113 | - 自定义数据集训练 114 | - 【TODO】时序数据划分方法 115 | - 【TODO】时序特征工程 116 | - 【TODO】Pytorch线性回归 117 | - 【TODO】Pytorch搭建CNN模型 118 | - 【TODO】Pytorch自定义数据集 119 | - 【TODO】Paddle线性回归 120 | - 【TODO】Paddle搭建CNN模型 121 | - 【TODO】Paddle自定义数据集 122 | - 【TODO】TextRank中文关键词识别 + 文本摘要 123 | - 【TODO】Rake中文关键词识别 + 文本摘要 124 | - 【TODO】图像检索:颜色直方图 125 | - 【TODO】图像检索:局部SIFT关键点 + 词袋编码/VLAD/FV 126 | - 【TODO】图像检索:卷积特征/Vit特征 127 | - 【TODO】图像细粒度检索:电商商品识别 128 | - 【TODO】图像自编码器 129 | - 【TODO】图像变分自编码器 130 | - 【TODO】图像MAE自监督训练 131 | 132 | ## 其他代码 133 | 134 | - [m3u8在线视频下载](https://github.com/coggle-club/notebooks/blob/main/notebooks/m3u8%E5%9C%A8%E7%BA%BF%E8%A7%86%E9%A2%91%E4%B8%8B%E8%BD%BD.ipynb) 135 | 136 | ## 环境说明 137 | 138 | - 代码使用Py3 Notebook编写,如无特别标注,深度学习框架均为PyTorch。 139 | - 部分代码所需的数据集需要额外下载,如有需要请关注下方公众号咨询。 140 | - 部分代码需要GPU运行,推荐显存11GB或以上。 141 | 142 | ## 关于我们 143 | 144 | ![](https://coggle.club/assets/img/coggle_qrcode.jpg) 145 | 146 | - 竞赛日历:https://coggle.club/ 147 | - 知乎专栏:https://zhuanlan.zhihu.com/DataAI 148 | - GitHub主页:https://github.com/coggle-club 149 | 150 | -------------------------------------------------------------------------------- /notebooks/Coggle202405意图识别/02-正则关键词.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 3, 6 | "id": "2b0453b9-e9da-4e1c-855c-cab94c9064ce", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "name": "stderr", 11 | "output_type": "stream", 12 | "text": [ 13 | "Building prefix dict from the default dictionary ...\n", 14 | "Loading model from cache /tmp/jieba.cache\n", 15 | "Loading model cost 0.651 seconds.\n", 16 | "Prefix dict has been built successfully.\n" 17 | ] 18 | } 19 | ], 20 | "source": [ 21 | "import pandas as pd\n", 22 | "import jieba\n", 23 | "import matplotlib.pyplot as plt\n", 24 | "\n", 25 | 
"# 读取数据集,这里是直接联网读取,也可以通过下载文件,再读取\n", 26 | "data_dir = 'https://mirror.coggle.club/dataset/coggle-competition/'\n", 27 | "train_data = pd.read_csv(data_dir + 'intent-classify/train.csv', sep='\\t', header=None)\n", 28 | "test_data = pd.read_csv(data_dir + 'intent-classify/test.csv', sep='\\t', header=None)\n", 29 | "\n", 30 | "train_text = '。'.join(list(train_data[0]))\n", 31 | "train_words = jieba.lcut(train_text)\n", 32 | "\n", 33 | "cn_stopwords = ' '.join(pd.read_csv('https://mirror.coggle.club/stopwords/baidu_stopwords.txt', header=None)[0])\n", 34 | "train_words = [x for x in train_words if x not in cn_stopwords]\n", 35 | "train_words = [x for x in train_words if len(x) > 1]\n", 36 | "train_words = [x for x in train_words if not x.isdigit()]" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 4, 42 | "id": "6557e4fe-5e10-4959-8ec7-d3fd155635e8", 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "from collections import Counter\n", 47 | "train_words_freq = Counter(train_words)\n", 48 | "train_words = [x for x in train_words if train_words_freq[x] >= 5]" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 5, 54 | "id": "1e78e280-ff02-4df4-9d20-092d2a8fa642", 55 | "metadata": { 56 | "scrolled": true 57 | }, 58 | "outputs": [], 59 | "source": [ 60 | "train_word_prior = {}\n", 61 | "for row in train_data.iloc[:].itertuples():\n", 62 | " text, label = row[1], row[2]\n", 63 | " words = jieba.lcut(text)\n", 64 | " words = [x for x in words if x in train_words]\n", 65 | "\n", 66 | " if len(words) == 0:\n", 67 | " continue\n", 68 | "\n", 69 | " for word in words:\n", 70 | " if word not in train_word_prior:\n", 71 | " train_word_prior[word] = {\"total\": 0}\n", 72 | "\n", 73 | " if label not in train_word_prior[word]:\n", 74 | " train_word_prior[word][label] = 0\n", 75 | "\n", 76 | " train_word_prior[word][label] += 1\n", 77 | " train_word_prior[word][\"total\"] += 1" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 
| "execution_count": 6, 83 | "id": "0d317a33-42da-456c-8015-e0b57d35138d", 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "train_word_prior = pd.DataFrame(train_word_prior).T\n", 88 | "train_word_prior.fillna(0, inplace=True)" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 7, 94 | "id": "f18d195b-8ac9-4b60-9fa2-14ce00777e2d", 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "for category in train_data[1].unique():\n", 99 | " train_word_prior[category] /= train_word_prior['total']" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": 8, 105 | "id": "a511da43-9d5a-4b53-a538-339cdee6cee4", 106 | "metadata": {}, 107 | "outputs": [ 108 | { 109 | "name": "stderr", 110 | "output_type": "stream", 111 | "text": [ 112 | "/tmp/ipykernel_59295/1433704415.py:2: DeprecationWarning: DataFrameGroupBy.apply operated on the grouping columns. This behavior is deprecated, and in a future version of pandas the grouping columns will be excluded from the operation. 
Either pass `include_groups=False` to exclude the groupings or explicitly select the grouping columns after groupby to silence this warning.\n", 113 | " train_word_prior.groupby('category').apply(lambda x: list(x.index))\n" 114 | ] 115 | }, 116 | { 117 | "data": { 118 | "text/plain": [ 119 | "category\n", 120 | "Alarm-Update [早上, 我定, 下午, 参加, 公司, 闹钟, 活动, 提醒, 创建, 周末, 上午, 取...\n", 121 | "Audio-Play [故事, 小说, 广播剧, 英文版, 岳云鹏, 爆笑, 相声, 有声, 俄语, 第五章, 郭...\n", 122 | "Calendar-Query [昨天, 农历, 我查, 星期, 几号, 告诉, 几月, 查查, 礼拜, 几是, 春节, 母...\n", 123 | "FilmTele-Play [播放, 古装, 爱情, 电视剧, 一个, 推理, 一会, 地方, 导演, 赵丽颖, 麻烦,...\n", 124 | "HomeAppliance-Control [空调, 客厅, 风速, 打开, 烤箱, 儿童房, 调高, 洗衣机, 停止, 工作, 模式,...\n", 125 | "Music-Play [随便, 一首, 专辑, 单曲, 循环, 王菲, 钢琴曲, 随机, 治愈, 日语, 歌曲, ...\n", 126 | "Other [永远, 电话, 笑话, 之间, 老婆, 不好, 漫画, 有人]\n", 127 | "Radio-Listen [河南, 新闻广播, 新闻台, 交通, 广播电台, 经典音乐, 七点, 中央, 电台, 都市...\n", 128 | "TVProgram-Play [播出, 卫视, 广西, 法治, CCTV11, 剧场, 开播, 文化, 结束, 早间, 贵...\n", 129 | "Travel-Query [汽车票, 回家, 深圳, 武汉, 北京, 桂林, 飞机, 起飞, 快点, 三张, 成都, ...\n", 130 | "Video-Play [挑战, 游戏, 视频, 和平, 精英, 花絮, 转播, 比赛, 现场, 世界, 年谍, 第...\n", 131 | "Weather-Query [查询, 海南, 几级, 刮风, 几天, 山西, 明天, 衡水, 气温, 适合, 杭州, 香...\n", 132 | "dtype: object" 133 | ] 134 | }, 135 | "execution_count": 8, 136 | "metadata": {}, 137 | "output_type": "execute_result" 138 | } 139 | ], 140 | "source": [ 141 | "train_word_prior['category'] = train_word_prior.columns[1:][train_word_prior.values[:, 1:].argmax(1)]\n", 142 | "train_word_prior.groupby('category').apply(lambda x: list(x.index))" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": null, 148 | "id": "2e1787ac-fdc5-461f-92bc-d6361cab525d", 149 | "metadata": {}, 150 | "outputs": [], 151 | "source": [] 152 | } 153 | ], 154 | "metadata": { 155 | "kernelspec": { 156 | "display_name": "py3.11", 157 | "language": "python", 158 | "name": "py3.11" 159 | }, 160 | "language_info": { 161 | "codemirror_mode": { 162 | "name": "ipython", 163 | "version": 3 164 | }, 165 | 
"file_extension": ".py", 166 | "mimetype": "text/x-python", 167 | "name": "python", 168 | "nbconvert_exporter": "python", 169 | "pygments_lexer": "ipython3", 170 | "version": "3.11.8" 171 | } 172 | }, 173 | "nbformat": 4, 174 | "nbformat_minor": 5 175 | } 176 | -------------------------------------------------------------------------------- /notebooks/Coggle202405意图识别/03-TFIDF文本分类.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "47ce31dc-0635-4049-821a-b3fd6102ae7b", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import pandas as pd\n", 11 | "import jieba\n", 12 | "import matplotlib.pyplot as plt\n", 13 | "\n", 14 | "# 读取数据集,这里是直接联网读取,也可以通过下载文件,再读取\n", 15 | "data_dir = 'https://mirror.coggle.club/dataset/coggle-competition/'\n", 16 | "train_data = pd.read_csv(data_dir + 'intent-classify/train.csv', sep='\\t', header=None)\n", 17 | "test_data = pd.read_csv(data_dir + 'intent-classify/test.csv', sep='\\t', header=None)\n", 18 | "\n", 19 | "cn_stopwords = pd.read_csv('https://mirror.coggle.club/stopwords/baidu_stopwords.txt', header=None)[0].values" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 2, 25 | "id": "a08e8b45-9256-42d3-b89e-bb76dc94d9c9", 26 | "metadata": {}, 27 | "outputs": [ 28 | { 29 | "name": "stderr", 30 | "output_type": "stream", 31 | "text": [ 32 | "/home/lyz/anaconda3/envs/py311/lib/python3.11/site-packages/sklearn/feature_extraction/text.py:525: UserWarning: The parameter 'token_pattern' will not be used since 'tokenizer' is not None'\n", 33 | " warnings.warn(\n", 34 | "Building prefix dict from the default dictionary ...\n", 35 | "Loading model from cache /tmp/jieba.cache\n", 36 | "Loading model cost 0.626 seconds.\n", 37 | "Prefix dict has been built successfully.\n", 38 | "/home/lyz/anaconda3/envs/py311/lib/python3.11/site-packages/sklearn/feature_extraction/text.py:408: 
UserWarning: Your stop_words may be inconsistent with your preprocessing. Tokenizing the stop words generated tokens [\"'\", 'a', 'ain', 'aren', 'c', 'couldn', 'd', 'didn', 'doesn', 'don', 'hadn', 'hasn', 'haven', 'i', 'isn', 'll', 'm', 'mon', 's', 'shouldn', 't', 've', 'wasn', 'weren', 'won', 'wouldn', '下', '不', '为什', '什', '今', '使', '先', '却', '只', '唷', '啪', '喔', '天', '好', '後', '最', '漫', '然', '特', '特别', '见', '设', '说', '达', '面', '麽', '-'] not in stop_words.\n", 39 | " warnings.warn(\n" 40 | ] 41 | } 42 | ], 43 | "source": [ 44 | "from sklearn.feature_extraction.text import TfidfVectorizer\n", 45 | "\n", 46 | "tfidf = TfidfVectorizer(\n", 47 | " tokenizer=jieba.lcut,\n", 48 | " stop_words=list(cn_stopwords)\n", 49 | ")\n", 50 | "train_tfidf = tfidf.fit_transform(train_data[0])\n", 51 | "test_tfidf = tfidf.transform(test_data[0])" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 3, 57 | "id": "cfcc582a-80ae-458f-bedf-2233a15eeca0", 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "from sklearn.linear_model import LogisticRegression\n", 62 | "from sklearn.neighbors import KNeighborsClassifier\n", 63 | "from sklearn.svm import LinearSVC\n", 64 | "from sklearn.metrics import classification_report\n", 65 | "from sklearn.model_selection import cross_val_predict" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 4, 71 | "id": "e710df9c-7743-45d4-8b40-9b7bec898121", 72 | "metadata": {}, 73 | "outputs": [ 74 | { 75 | "name": "stdout", 76 | "output_type": "stream", 77 | "text": [ 78 | " precision recall f1-score support\n", 79 | "\n", 80 | " Alarm-Update 0.98 0.93 0.96 1264\n", 81 | " Audio-Play 0.74 0.50 0.60 226\n", 82 | " Calendar-Query 0.99 0.95 0.97 1214\n", 83 | " FilmTele-Play 0.70 0.93 0.80 1355\n", 84 | "HomeAppliance-Control 0.94 0.97 0.96 1215\n", 85 | " Music-Play 0.88 0.87 0.87 1304\n", 86 | " Other 0.39 0.07 0.11 214\n", 87 | " Radio-Listen 0.94 0.89 0.91 1285\n", 88 | " TVProgram-Play 0.72 0.45 0.55 
240\n", 89 | " Travel-Query 0.92 0.96 0.94 1220\n", 90 | " Video-Play 0.90 0.87 0.89 1334\n", 91 | " Weather-Query 0.92 0.96 0.94 1229\n", 92 | "\n", 93 | " accuracy 0.89 12100\n", 94 | " macro avg 0.84 0.78 0.79 12100\n", 95 | " weighted avg 0.89 0.89 0.89 12100\n", 96 | "\n" 97 | ] 98 | } 99 | ], 100 | "source": [ 101 | "cv_pred = cross_val_predict(\n", 102 | " LogisticRegression(),\n", 103 | " train_tfidf, train_data[1]\n", 104 | ")\n", 105 | "print(classification_report(train_data[1], cv_pred))" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 5, 111 | "id": "7f89dd4e-1374-419c-9ac1-df2435547997", 112 | "metadata": {}, 113 | "outputs": [ 114 | { 115 | "name": "stderr", 116 | "output_type": "stream", 117 | "text": [ 118 | "/home/lyz/anaconda3/envs/py311/lib/python3.11/site-packages/sklearn/svm/_classes.py:31: FutureWarning: The default value of `dual` will change from `True` to `'auto'` in 1.5. Set the value of `dual` explicitly to suppress the warning.\n", 119 | " warnings.warn(\n", 120 | "/home/lyz/anaconda3/envs/py311/lib/python3.11/site-packages/sklearn/svm/_classes.py:31: FutureWarning: The default value of `dual` will change from `True` to `'auto'` in 1.5. Set the value of `dual` explicitly to suppress the warning.\n", 121 | " warnings.warn(\n", 122 | "/home/lyz/anaconda3/envs/py311/lib/python3.11/site-packages/sklearn/svm/_classes.py:31: FutureWarning: The default value of `dual` will change from `True` to `'auto'` in 1.5. Set the value of `dual` explicitly to suppress the warning.\n", 123 | " warnings.warn(\n", 124 | "/home/lyz/anaconda3/envs/py311/lib/python3.11/site-packages/sklearn/svm/_classes.py:31: FutureWarning: The default value of `dual` will change from `True` to `'auto'` in 1.5. 
Set the value of `dual` explicitly to suppress the warning.\n", 125 | " warnings.warn(\n", 126 | "/home/lyz/anaconda3/envs/py311/lib/python3.11/site-packages/sklearn/svm/_classes.py:31: FutureWarning: The default value of `dual` will change from `True` to `'auto'` in 1.5. Set the value of `dual` explicitly to suppress the warning.\n", 127 | " warnings.warn(\n" 128 | ] 129 | }, 130 | { 131 | "name": "stdout", 132 | "output_type": "stream", 133 | "text": [ 134 | " precision recall f1-score support\n", 135 | "\n", 136 | " Alarm-Update 0.97 0.95 0.96 1264\n", 137 | " Audio-Play 0.64 0.71 0.67 226\n", 138 | " Calendar-Query 0.98 0.97 0.98 1214\n", 139 | " FilmTele-Play 0.81 0.89 0.85 1355\n", 140 | "HomeAppliance-Control 0.97 0.98 0.98 1215\n", 141 | " Music-Play 0.90 0.89 0.89 1304\n", 142 | " Other 0.31 0.25 0.27 214\n", 143 | " Radio-Listen 0.94 0.90 0.92 1285\n", 144 | " TVProgram-Play 0.66 0.62 0.64 240\n", 145 | " Travel-Query 0.95 0.98 0.97 1220\n", 146 | " Video-Play 0.92 0.88 0.90 1334\n", 147 | " Weather-Query 0.96 0.97 0.97 1229\n", 148 | "\n", 149 | " accuracy 0.91 12100\n", 150 | " macro avg 0.83 0.83 0.83 12100\n", 151 | " weighted avg 0.91 0.91 0.91 12100\n", 152 | "\n" 153 | ] 154 | } 155 | ], 156 | "source": [ 157 | "cv_pred = cross_val_predict(\n", 158 | " LinearSVC(),\n", 159 | " train_tfidf, train_data[1]\n", 160 | ")\n", 161 | "print(classification_report(train_data[1], cv_pred))" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": 6, 167 | "id": "7cff5d84-826e-484b-a555-6da50bda667b", 168 | "metadata": {}, 169 | "outputs": [ 170 | { 171 | "name": "stdout", 172 | "output_type": "stream", 173 | "text": [ 174 | " precision recall f1-score support\n", 175 | "\n", 176 | " Alarm-Update 0.83 0.92 0.87 1264\n", 177 | " Audio-Play 0.55 0.63 0.59 226\n", 178 | " Calendar-Query 0.81 0.96 0.88 1214\n", 179 | " FilmTele-Play 0.80 0.79 0.80 1355\n", 180 | "HomeAppliance-Control 0.92 0.97 0.94 1215\n", 181 | " Music-Play 0.83 0.82 0.83 
1304\n", 182 | " Other 0.20 0.25 0.22 214\n", 183 | " Radio-Listen 0.91 0.83 0.87 1285\n", 184 | " TVProgram-Play 0.55 0.39 0.46 240\n", 185 | " Travel-Query 0.94 0.90 0.92 1220\n", 186 | " Video-Play 0.87 0.73 0.79 1334\n", 187 | " Weather-Query 0.92 0.89 0.90 1229\n", 188 | "\n", 189 | " accuracy 0.84 12100\n", 190 | " macro avg 0.76 0.76 0.76 12100\n", 191 | " weighted avg 0.84 0.84 0.84 12100\n", 192 | "\n" 193 | ] 194 | } 195 | ], 196 | "source": [ 197 | "cv_pred = cross_val_predict(\n", 198 | " KNeighborsClassifier(),\n", 199 | " train_tfidf, train_data[1]\n", 200 | ")\n", 201 | "print(classification_report(train_data[1], cv_pred))" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": 7, 207 | "id": "06fe4b6a-e4bb-4116-b6c9-1f22a37c368a", 208 | "metadata": {}, 209 | "outputs": [ 210 | { 211 | "name": "stderr", 212 | "output_type": "stream", 213 | "text": [ 214 | "/home/lyz/anaconda3/envs/py311/lib/python3.11/site-packages/sklearn/svm/_classes.py:31: FutureWarning: The default value of `dual` will change from `True` to `'auto'` in 1.5. 
Set the value of `dual` explicitly to suppress the warning.\n", 215 | " warnings.warn(\n" 216 | ] 217 | } 218 | ], 219 | "source": [ 220 | "model = LinearSVC()\n", 221 | "model.fit(train_tfidf, train_data[1])\n", 222 | "pd.DataFrame({\n", 223 | " 'ID':range(1, len(test_data) + 1),\n", 224 | " \"Target\":model.predict(test_tfidf)\n", 225 | "}).to_csv('nlp_submit.csv', index=None)\n", 226 | "# 可以提交到" 227 | ] 228 | }, 229 | { 230 | "cell_type": "code", 231 | "execution_count": null, 232 | "id": "2e1787ac-fdc5-461f-92bc-d6361cab525d", 233 | "metadata": {}, 234 | "outputs": [], 235 | "source": [] 236 | } 237 | ], 238 | "metadata": { 239 | "kernelspec": { 240 | "display_name": "py3.11", 241 | "language": "python", 242 | "name": "py3.11" 243 | }, 244 | "language_info": { 245 | "codemirror_mode": { 246 | "name": "ipython", 247 | "version": 3 248 | }, 249 | "file_extension": ".py", 250 | "mimetype": "text/x-python", 251 | "name": "python", 252 | "nbconvert_exporter": "python", 253 | "pygments_lexer": "ipython3", 254 | "version": "3.11.8" 255 | } 256 | }, 257 | "nbformat": 4, 258 | "nbformat_minor": 5 259 | } 260 | -------------------------------------------------------------------------------- /notebooks/Coggle202405意图识别/04-词向量训练与使用.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 7, 6 | "id": "e12d4f35-6bf3-45ac-a68f-f5cf8f3f1694", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "from sklearn.linear_model import LogisticRegression\n", 11 | "from sklearn.neighbors import KNeighborsClassifier\n", 12 | "from sklearn.svm import LinearSVC\n", 13 | "from sklearn.metrics import classification_report\n", 14 | "from sklearn.model_selection import cross_val_predict\n", 15 | "\n", 16 | "from gensim.test.utils import common_texts\n", 17 | "from gensim.models import Word2Vec\n", 18 | "model = Word2Vec(sentences=common_texts, vector_size=100, window=5, 
min_count=1, workers=4)" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 8, 24 | "id": "87638379-4e14-45a8-9b24-753bef0c99c8", 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "import pandas as pd\n", 29 | "import numpy as np\n", 30 | "import jieba\n", 31 | "\n", 32 | "data_dir = 'https://mirror.coggle.club/dataset/coggle-competition/'\n", 33 | "train_data = pd.read_csv(data_dir + 'intent-classify/train.csv', sep='\\t', header=None)\n", 34 | "test_data = pd.read_csv(data_dir + 'intent-classify/test.csv', sep='\\t', header=None)\n", 35 | "train_data[0] = train_data[0].apply(jieba.lcut)\n", 36 | "test_data[0] = test_data[0].apply(jieba.lcut)" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 9, 42 | "id": "7fe6746d-c804-4627-bdbf-ef228b7d0b70", 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "model = Word2Vec(\n", 47 | " sentences=list(train_data[0].values[:]) + list(test_data[0].values[:]), \n", 48 | "vector_size=30, window=5, min_count=1, workers=4)" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 10, 54 | "id": "942d78a3-15e2-4857-a13e-90945f5d04a4", 55 | "metadata": { 56 | "scrolled": true 57 | }, 58 | "outputs": [ 59 | { 60 | "data": { 61 | "text/plain": [ 62 | "[('热水器', 0.98893141746521),\n", 63 | " ('关掉', 0.9859920740127563),\n", 64 | " ('加湿器', 0.9854778051376343),\n", 65 | " ('这时候', 0.9815304279327393),\n", 66 | " ('电脑', 0.9793143272399902),\n", 67 | " (',', 0.9790096879005432),\n", 68 | " ('蒸', 0.975031852722168),\n", 69 | " ('洗衣机', 0.9734700918197632),\n", 70 | " ('准备', 0.9727935791015625),\n", 71 | " ('把', 0.9721974730491638)]" 72 | ] 73 | }, 74 | "execution_count": 10, 75 | "metadata": {}, 76 | "output_type": "execute_result" 77 | } 78 | ], 79 | "source": [ 80 | "model.wv.most_similar('打开')" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": 11, 86 | "id": "f1761282-5026-407d-80ee-5ae6944d2d90", 87 | "metadata": {}, 88 | "outputs": [], 
89 | "source": [ 90 | "train_w2v = train_data[0].apply(lambda x: model.wv[x].mean(0))\n", 91 | "test_w2v = test_data[0].apply(lambda x: model.wv[x].mean(0))\n", 92 | "\n", 93 | "train_w2v = np.vstack(train_w2v)\n", 94 | "test_w2v = np.vstack(test_w2v)" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": 12, 100 | "id": "53c2c463-6995-49d6-a6d3-2bb770f33b0c", 101 | "metadata": {}, 102 | "outputs": [ 103 | { 104 | "name": "stderr", 105 | "output_type": "stream", 106 | "text": [ 107 | "/home/lyz/anaconda3/envs/py311/lib/python3.11/site-packages/sklearn/svm/_classes.py:31: FutureWarning: The default value of `dual` will change from `True` to `'auto'` in 1.5. Set the value of `dual` explicitly to suppress the warning.\n", 108 | " warnings.warn(\n", 109 | "/home/lyz/anaconda3/envs/py311/lib/python3.11/site-packages/sklearn/svm/_base.py:1237: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n", 110 | " warnings.warn(\n", 111 | "/home/lyz/anaconda3/envs/py311/lib/python3.11/site-packages/sklearn/svm/_classes.py:31: FutureWarning: The default value of `dual` will change from `True` to `'auto'` in 1.5. Set the value of `dual` explicitly to suppress the warning.\n", 112 | " warnings.warn(\n", 113 | "/home/lyz/anaconda3/envs/py311/lib/python3.11/site-packages/sklearn/svm/_base.py:1237: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n", 114 | " warnings.warn(\n", 115 | "/home/lyz/anaconda3/envs/py311/lib/python3.11/site-packages/sklearn/svm/_classes.py:31: FutureWarning: The default value of `dual` will change from `True` to `'auto'` in 1.5. 
Set the value of `dual` explicitly to suppress the warning.\n", 116 | " warnings.warn(\n", 117 | "/home/lyz/anaconda3/envs/py311/lib/python3.11/site-packages/sklearn/svm/_base.py:1237: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n", 118 | " warnings.warn(\n", 119 | "/home/lyz/anaconda3/envs/py311/lib/python3.11/site-packages/sklearn/svm/_classes.py:31: FutureWarning: The default value of `dual` will change from `True` to `'auto'` in 1.5. Set the value of `dual` explicitly to suppress the warning.\n", 120 | " warnings.warn(\n", 121 | "/home/lyz/anaconda3/envs/py311/lib/python3.11/site-packages/sklearn/svm/_base.py:1237: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n", 122 | " warnings.warn(\n", 123 | "/home/lyz/anaconda3/envs/py311/lib/python3.11/site-packages/sklearn/svm/_classes.py:31: FutureWarning: The default value of `dual` will change from `True` to `'auto'` in 1.5. Set the value of `dual` explicitly to suppress the warning.\n", 124 | " warnings.warn(\n", 125 | "/home/lyz/anaconda3/envs/py311/lib/python3.11/site-packages/sklearn/svm/_base.py:1237: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n", 126 | " warnings.warn(\n" 127 | ] 128 | }, 129 | { 130 | "name": "stdout", 131 | "output_type": "stream", 132 | "text": [ 133 | " precision recall f1-score support\n", 134 | "\n", 135 | " Alarm-Update 0.91 0.93 0.92 1264\n", 136 | " Audio-Play 0.00 0.00 0.00 226\n", 137 | " Calendar-Query 0.93 0.97 0.95 1214\n", 138 | " FilmTele-Play 0.56 0.63 0.59 1355\n", 139 | "HomeAppliance-Control 0.84 0.92 0.88 1215\n", 140 | " Music-Play 0.73 0.78 0.75 1304\n", 141 | " Other 0.10 0.01 0.02 214\n", 142 | " Radio-Listen 0.85 0.84 0.85 1285\n", 143 | " TVProgram-Play 0.69 0.05 0.09 240\n", 144 | " Travel-Query 0.84 0.92 0.88 1220\n", 145 | " Video-Play 0.70 0.73 0.71 1334\n", 146 | " Weather-Query 0.82 0.83 0.83 1229\n", 147 | "\n", 148 | " accuracy 0.79 
12100\n", 149 | " macro avg 0.66 0.63 0.62 12100\n", 150 | " weighted avg 0.76 0.79 0.77 12100\n", 151 | "\n" 152 | ] 153 | } 154 | ], 155 | "source": [ 156 | "cv_pred = cross_val_predict(\n", 157 | " LinearSVC(),\n", 158 | " train_w2v, train_data[1]\n", 159 | ")\n", 160 | "print(classification_report(train_data[1], cv_pred))" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": null, 166 | "id": "2e1787ac-fdc5-461f-92bc-d6361cab525d", 167 | "metadata": {}, 168 | "outputs": [], 169 | "source": [] 170 | } 171 | ], 172 | "metadata": { 173 | "kernelspec": { 174 | "display_name": "py3.11", 175 | "language": "python", 176 | "name": "py3.11" 177 | }, 178 | "language_info": { 179 | "codemirror_mode": { 180 | "name": "ipython", 181 | "version": 3 182 | }, 183 | "file_extension": ".py", 184 | "mimetype": "text/x-python", 185 | "name": "python", 186 | "nbconvert_exporter": "python", 187 | "pygments_lexer": "ipython3", 188 | "version": "3.11.8" 189 | } 190 | }, 191 | "nbformat": 4, 192 | "nbformat_minor": 5 193 | } 194 | -------------------------------------------------------------------------------- /notebooks/Coggle202405意图识别/06-LSTM文本分类.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "08749765-5bf2-4305-a461-1d375fee41bb", 7 | "metadata": { 8 | "tags": [] 9 | }, 10 | "outputs": [], 11 | "source": [ 12 | "import torch\n", 13 | "import jieba\n", 14 | "import pandas as pd\n", 15 | "\n", 16 | "data_dir = 'https://mirror.coggle.club/dataset/coggle-competition/'\n", 17 | "train_data = pd.read_csv(data_dir + 'intent-classify/train.csv', sep='\\t', header=None)\n", 18 | "test_data = pd.read_csv(data_dir + 'intent-classify/test.csv', sep='\\t', header=None)" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 2, 24 | "id": "81a0c43c-3e04-42eb-9c1c-a284e0f21178", 25 | "metadata": { 26 | "tags": [] 27 | }, 28 | 
"outputs": [ 29 | { 30 | "data": { 31 | "text/html": [ 32 | "
\n", 33 | "\n", 46 | "\n", 47 | " \n", 48 | " \n", 49 | " \n", 50 | " \n", 51 | " \n", 52 | " \n", 53 | " \n", 54 | " \n", 55 | " \n", 56 | " \n", 57 | " \n", 58 | " \n", 59 | " \n", 60 | " \n", 61 | " \n", 62 | " \n", 63 | " \n", 64 | " \n", 65 | " \n", 66 | " \n", 67 | " \n", 68 | " \n", 69 | " \n", 70 | " \n", 71 | " \n", 72 | " \n", 73 | " \n", 74 | " \n", 75 | " \n", 76 | " \n", 77 | " \n", 78 | " \n", 79 | " \n", 80 | " \n", 81 | "
01
0还有双鸭山到淮阴的汽车票吗13号的Travel-Query
1从这里怎么回家Travel-Query
2随便播放一首专辑阁楼里的佛里的歌Music-Play
3给看一下墓王之王嘛FilmTele-Play
4我想看挑战两把s686打突变团竞的游戏视频Video-Play
\n", 82 | "
" 83 | ], 84 | "text/plain": [ 85 | " 0 1\n", 86 | "0 还有双鸭山到淮阴的汽车票吗13号的 Travel-Query\n", 87 | "1 从这里怎么回家 Travel-Query\n", 88 | "2 随便播放一首专辑阁楼里的佛里的歌 Music-Play\n", 89 | "3 给看一下墓王之王嘛 FilmTele-Play\n", 90 | "4 我想看挑战两把s686打突变团竞的游戏视频 Video-Play" 91 | ] 92 | }, 93 | "execution_count": 2, 94 | "metadata": {}, 95 | "output_type": "execute_result" 96 | } 97 | ], 98 | "source": [ 99 | "train_data.head()" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": 3, 105 | "id": "d8492bd1-ad05-4e3b-a778-dcecb2573196", 106 | "metadata": { 107 | "tags": [] 108 | }, 109 | "outputs": [], 110 | "source": [ 111 | "train_data = train_data.sample(frac=1.0)" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": 4, 117 | "id": "6f67d9cb-04e9-4ab7-8642-acb0c4f58493", 118 | "metadata": { 119 | "tags": [] 120 | }, 121 | "outputs": [], 122 | "source": [ 123 | "train_data[1], lbl = pd.factorize(train_data[1])" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": 5, 129 | "id": "73d8e7e8-d520-43c9-bffa-a2d89c79a0e7", 130 | "metadata": { 131 | "tags": [] 132 | }, 133 | "outputs": [], 134 | "source": [ 135 | "def coustom_data_iter(texts, labels):\n", 136 | " for x, y in zip(texts, labels):\n", 137 | " yield x, y" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": 6, 143 | "id": "b40c62f9-04f7-4c58-8a28-3b2799e31fae", 144 | "metadata": { 145 | "tags": [] 146 | }, 147 | "outputs": [], 148 | "source": [ 149 | "train_iter = coustom_data_iter(train_data[0].values[:], train_data[1].values[:])" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": 7, 155 | "id": "fdd54feb-ac16-4b37-ad93-c571e6277e63", 156 | "metadata": { 157 | "tags": [] 158 | }, 159 | "outputs": [ 160 | { 161 | "name": "stderr", 162 | "output_type": "stream", 163 | "text": [ 164 | "Building prefix dict from the default dictionary ...\n", 165 | "Loading model from cache /tmp/jieba.cache\n", 166 | "Loading model cost 0.602 
seconds.\n", 167 | "Prefix dict has been built successfully.\n" 168 | ] 169 | } 170 | ], 171 | "source": [ 172 | "from torchtext.data.utils import get_tokenizer\n", 173 | "from torchtext.vocab import build_vocab_from_iterator\n", 174 | "\n", 175 | "tokenizer = jieba.lcut\n", 176 | "\n", 177 | "\n", 178 | "def yield_tokens(data_iter):\n", 179 | " for text, _ in data_iter:\n", 180 | " yield tokenizer(text)\n", 181 | "\n", 182 | "\n", 183 | "vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=[\"\"])\n", 184 | "vocab.set_default_index(vocab[\"\"])" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": 8, 190 | "id": "4b036478-3f3f-48bf-b562-40bc295f6060", 191 | "metadata": { 192 | "scrolled": true, 193 | "tags": [] 194 | }, 195 | "outputs": [ 196 | { 197 | "data": { 198 | "text/plain": [ 199 | "['', '的', '我', '一下', '播放', '是', '吗', '给', '帮', '一个']" 200 | ] 201 | }, 202 | "execution_count": 8, 203 | "metadata": {}, 204 | "output_type": "execute_result" 205 | } 206 | ], 207 | "source": [ 208 | "vocab.get_itos()[:10]" 209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "execution_count": 9, 214 | "id": "b6829460-2e9f-4630-bff9-4513e7c989e5", 215 | "metadata": { 216 | "tags": [] 217 | }, 218 | "outputs": [ 219 | { 220 | "data": { 221 | "text/plain": [ 222 | "[2, 3, 41]" 223 | ] 224 | }, 225 | "execution_count": 9, 226 | "metadata": {}, 227 | "output_type": "execute_result" 228 | } 229 | ], 230 | "source": [ 231 | "vocab(['我', '一下', '今天'])" 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": 10, 237 | "id": "b1c03326-b060-46bd-a35d-71e3debb3075", 238 | "metadata": { 239 | "tags": [] 240 | }, 241 | "outputs": [], 242 | "source": [ 243 | "# from gensim.models import KeyedVectors\n", 244 | "\n", 245 | "# # 需要自行下载,然后修改路径后运行\n", 246 | "# wv_from_text = KeyedVectors.load_word2vec_format('/home/lyz/work/dataset/词向量/tencent-ailab-embedding-zh-d100-v0.2.0-s/tencent-ailab-embedding-zh-d100-v0.2.0-s.txt', 
binary=False)\n", 247 | "\n", 248 | "# pretrained_w2v = []\n", 249 | "# for w in vocab.get_itos():\n", 250 | "# if w in wv_from_text:\n", 251 | "# pretrained_w2v.append(wv_from_text[w])\n", 252 | "# else:\n", 253 | "# pretrained_w2v.append(np.random.rand(100))\n", 254 | " \n", 255 | "# pretrained_w2v = np.vstack(pretrained_w2v)" 256 | ] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "execution_count": 11, 261 | "id": "39b1f63f-36f8-45cb-bacd-69278e2dd790", 262 | "metadata": { 263 | "tags": [] 264 | }, 265 | "outputs": [], 266 | "source": [ 267 | "def text_pipeline(x): return vocab(tokenizer(x))" 268 | ] 269 | }, 270 | { 271 | "cell_type": "code", 272 | "execution_count": 12, 273 | "id": "399e7cc9-bd92-4860-8776-3d6ae991dbf4", 274 | "metadata": { 275 | "tags": [] 276 | }, 277 | "outputs": [], 278 | "source": [ 279 | "processed_text = torch.tensor(text_pipeline('今天我们在这里'), dtype=torch.int64)" 280 | ] 281 | }, 282 | { 283 | "cell_type": "code", 284 | "execution_count": 13, 285 | "id": "0d3373a6-4ffc-42ac-b0d7-c43dd3bc5e22", 286 | "metadata": { 287 | "tags": [] 288 | }, 289 | "outputs": [], 290 | "source": [ 291 | "from torch.nn.utils.rnn import pad_sequence" 292 | ] 293 | }, 294 | { 295 | "cell_type": "code", 296 | "execution_count": 14, 297 | "id": "cce7e02f-ee05-48a0-9aa1-2bcfde8eef30", 298 | "metadata": { 299 | "tags": [] 300 | }, 301 | "outputs": [], 302 | "source": [ 303 | "from torch.utils.data import DataLoader\n", 304 | "import torch.nn.functional as F\n", 305 | "\n", 306 | "# device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 307 | "device = 'cpu'\n", 308 | "\n", 309 | "# 90%的文本长度都在20个单词以内\n", 310 | "def collate_batch(batch, max_len=20):\n", 311 | " label_list, text_list = [], []\n", 312 | " for (_text, _label) in batch:\n", 313 | " label_list.append(_label)\n", 314 | " processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)\n", 315 | " processed_text = F.pad(processed_text, pad=[0, max_len,], mode='constant', 
value=0)\n", 316 | " if len(processed_text) > max_len:\n", 317 | " processed_text = processed_text[:max_len]\n", 318 | "\n", 319 | " text_list.append(processed_text)\n", 320 | " label_list = torch.tensor(label_list, dtype=torch.int64)\n", 321 | " text_list = pad_sequence(text_list).T\n", 322 | " return label_list.to(device), text_list.to(device)" 323 | ] 324 | }, 325 | { 326 | "cell_type": "code", 327 | "execution_count": 15, 328 | "id": "c1f3ca37-4e2b-4ca7-a642-216a3b237a7c", 329 | "metadata": { 330 | "tags": [] 331 | }, 332 | "outputs": [], 333 | "source": [ 334 | "class BILSTM(torch.nn.Module):\n", 335 | " def __init__(self, vocab_size, embedding_dim, hidden_dim, label_size):\n", 336 | " super(BILSTM, self).__init__()\n", 337 | " self.hidden_dim = hidden_dim\n", 338 | " self.embeddings = torch.nn.Embedding(vocab_size, embedding_dim)\n", 339 | " # self.embeddings.weight.data.copy_(torch.from_numpy(pretrained_w2v))\n", 340 | " \n", 341 | " self.lstm = torch.nn.LSTM(input_size=embedding_dim, hidden_size=hidden_dim, bidirectional=True)\n", 342 | " self.hidden2label = torch.nn.Linear(hidden_dim*2, label_size)\n", 343 | " \n", 344 | "\n", 345 | " def forward(self, sentence):\n", 346 | " # torch.Size([16, 20])\n", 347 | " # print(sentence.shape)\n", 348 | " sentence = torch.transpose(sentence, 1, 0)\n", 349 | " \n", 350 | " # torch.Size([20, 16])\n", 351 | " # print(sentence.shape)\n", 352 | " \n", 353 | " # torch.Size([20, 16, 100])\n", 354 | " x = self.embeddings(sentence)\n", 355 | " # print(x.shape)\n", 356 | " \n", 357 | " # 5 seqence length\n", 358 | " # 3 batch size\n", 359 | " # 10 input size\n", 360 | " # input = torch.randn(5, 3, 10)\n", 361 | "\n", 362 | " lstm_out, self.hidden = self.lstm(x)\n", 363 | " # torch.Size([20, 16, 128])\n", 364 | " # print(lstm_out.shape)\n", 365 | " y = self.hidden2label(lstm_out[-1,:,:])\n", 366 | " return y" 367 | ] 368 | }, 369 | { 370 | "cell_type": "code", 371 | "execution_count": 16, 372 | "id": 
"d09e44a3-c003-4909-838f-50e533bff31d", 373 | "metadata": { 374 | "tags": [] 375 | }, 376 | "outputs": [ 377 | { 378 | "data": { 379 | "text/plain": [ 380 | "torch.Size([1, 3, 2])" 381 | ] 382 | }, 383 | "execution_count": 16, 384 | "metadata": {}, 385 | "output_type": "execute_result" 386 | } 387 | ], 388 | "source": [ 389 | "torch.transpose(torch.rand((2,3,1)), 0, 2).shape" 390 | ] 391 | }, 392 | { 393 | "cell_type": "code", 394 | "execution_count": 17, 395 | "id": "bf4c839f-24e5-49f8-b7fa-1f6e5eb17722", 396 | "metadata": { 397 | "tags": [] 398 | }, 399 | "outputs": [], 400 | "source": [ 401 | "import time\n", 402 | "\n", 403 | "def train(dataloader):\n", 404 | " model.train()\n", 405 | " total_acc, total_count = 0, 0\n", 406 | "\n", 407 | " for idx, (label, text) in enumerate(dataloader):\n", 408 | " optimizer.zero_grad()\n", 409 | " predicted_label = model(text)\n", 410 | " loss = criterion(predicted_label, label)\n", 411 | " loss.backward()\n", 412 | " torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)\n", 413 | " optimizer.step()\n", 414 | " total_acc += (predicted_label.argmax(1) == label).sum().item()\n", 415 | " total_count += label.size(0)\n", 416 | "\n", 417 | "def evaluate(dataloader):\n", 418 | " model.eval()\n", 419 | " total_acc, total_count = 0, 0\n", 420 | "\n", 421 | " with torch.no_grad():\n", 422 | " for idx, (label, text) in enumerate(dataloader):\n", 423 | " predicted_label = model(text)\n", 424 | " loss = criterion(predicted_label, label)\n", 425 | " total_acc += (predicted_label.argmax(1) == label).sum().item()\n", 426 | " total_count += label.size(0)\n", 427 | " return total_acc/total_count" 428 | ] 429 | }, 430 | { 431 | "cell_type": "code", 432 | "execution_count": 18, 433 | "id": "322569ce-b658-478c-a881-86eaaff6af90", 434 | "metadata": { 435 | "tags": [] 436 | }, 437 | "outputs": [], 438 | "source": [ 439 | "num_class = len(lbl)\n", 440 | "vocab_size = len(vocab)\n", 441 | "emsize = 100\n", 442 | "model = BILSTM(vocab_size, emsize, 
64, num_class).to(device)" 443 | ] 444 | }, 445 | { 446 | "cell_type": "code", 447 | "execution_count": 19, 448 | "id": "c66d6645-5ef9-4ecd-af59-9bf9e572e3ee", 449 | "metadata": { 450 | "tags": [] 451 | }, 452 | "outputs": [], 453 | "source": [ 454 | "from torch.utils.data.dataset import random_split\n", 455 | "from torchtext.data.functional import to_map_style_dataset\n", 456 | "# Hyperparameters\n", 457 | "EPOCHS = 40 # epoch\n", 458 | "LR = 2 # learning rate\n", 459 | "BATCH_SIZE = 16 # batch size for training\n", 460 | "\n", 461 | "criterion = torch.nn.CrossEntropyLoss()\n", 462 | "optimizer = torch.optim.SGD(model.parameters(), lr=LR)\n", 463 | "scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 5.0, gamma=0.75)\n", 464 | "total_accu = None\n", 465 | "\n", 466 | "train_iter = coustom_data_iter(train_data[0].values[:], train_data[1].values[:])\n", 467 | "train_dataset = to_map_style_dataset(train_iter)\n", 468 | "\n", 469 | "num_train = int(len(train_dataset) * 0.75)\n", 470 | "split_train_, split_valid_ = random_split(train_dataset, [num_train, len(train_dataset) - num_train])\n", 471 | "\n", 472 | "train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE,\n", 473 | " shuffle=True, collate_fn=collate_batch)\n", 474 | "valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE,\n", 475 | " shuffle=True, collate_fn=collate_batch)" 476 | ] 477 | }, 478 | { 479 | "cell_type": "code", 480 | "execution_count": 20, 481 | "id": "9f60cb58-a65c-420b-a487-5bfb9a14f7db", 482 | "metadata": { 483 | "tags": [] 484 | }, 485 | "outputs": [ 486 | { 487 | "data": { 488 | "text/plain": [ 489 | "torch.Size([16, 12])" 490 | ] 491 | }, 492 | "execution_count": 20, 493 | "metadata": {}, 494 | "output_type": "execute_result" 495 | } 496 | ], 497 | "source": [ 498 | "for (label, text) in train_dataloader:\n", 499 | " break\n", 500 | " \n", 501 | "model(text).shape\n", 502 | "# 12, 16, 100\n", 503 | "# sequence length * batch size * embedding dim" 504 | ] 505 | }, 
506 | { 507 | "cell_type": "code", 508 | "execution_count": null, 509 | "id": "9af00bc2-9949-4f85-a785-47a227a21569", 510 | "metadata": { 511 | "tags": [] 512 | }, 513 | "outputs": [], 514 | "source": [ 515 | "for epoch in range(1, EPOCHS + 1):\n", 516 | " epoch_start_time = time.time()\n", 517 | " train(train_dataloader)\n", 518 | " accu_val = evaluate(valid_dataloader)\n", 519 | " if total_accu is not None and total_accu > accu_val:\n", 520 | " scheduler.step()\n", 521 | " else:\n", 522 | " total_accu = accu_val\n", 523 | " \n", 524 | " print('| end of epoch {:3d} | time: {:5.2f}s | '\n", 525 | " 'valid accuracy {:8.3f} '.format(epoch,\n", 526 | " time.time() - epoch_start_time,\n", 527 | " accu_val))" 528 | ] 529 | }, 530 | { 531 | "cell_type": "code", 532 | "execution_count": null, 533 | "id": "50ef9d28-fd43-41a7-a7b0-35171878b5d9", 534 | "metadata": {}, 535 | "outputs": [], 536 | "source": [] 537 | }, 538 | { 539 | "cell_type": "code", 540 | "execution_count": null, 541 | "id": "8980e9ae-fb3f-4a16-8792-565cbd1e4f77", 542 | "metadata": { 543 | "tags": [] 544 | }, 545 | "outputs": [], 546 | "source": [ 547 | "test_iter = coustom_data_iter(test_data[0].values[:], [0] * len(test_data))\n", 548 | "test_dataset = to_map_style_dataset(test_iter)\n", 549 | "test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE,\n", 550 | " shuffle=False, collate_fn=collate_batch)" 551 | ] 552 | }, 553 | { 554 | "cell_type": "code", 555 | "execution_count": null, 556 | "id": "33b9132d-cbdf-4c91-bea8-6cac6b737471", 557 | "metadata": { 558 | "tags": [] 559 | }, 560 | "outputs": [], 561 | "source": [ 562 | "def predict(dataloader):\n", 563 | " model.eval()\n", 564 | "\n", 565 | " test_pred = []\n", 566 | " with torch.no_grad():\n", 567 | " for idx, (label, text) in enumerate(dataloader):\n", 568 | " predicted_label = model(text).argmax(1)\n", 569 | " test_pred += list(predicted_label.cpu().numpy())\n", 570 | " return test_pred" 571 | ] 572 | }, 573 | { 574 | "cell_type": 
"code", 575 | "execution_count": null, 576 | "id": "e7d1273c-9a20-43bf-83aa-a3c5b56881bf", 577 | "metadata": { 578 | "tags": [] 579 | }, 580 | "outputs": [], 581 | "source": [ 582 | "test_pred = predict(test_dataloader)\n", 583 | "test_pred = [lbl[x] for x in test_pred]" 584 | ] 585 | }, 586 | { 587 | "cell_type": "code", 588 | "execution_count": null, 589 | "id": "44907273-d712-47bb-985b-bde269132bea", 590 | "metadata": { 591 | "tags": [] 592 | }, 593 | "outputs": [], 594 | "source": [ 595 | "pd.DataFrame({\n", 596 | " 'ID': range(1, len(test_pred) + 1),\n", 597 | " 'Target': test_pred,\n", 598 | "}).to_csv('nlp_submit.csv', index=None)\n", 599 | "\n", 600 | "# 提交一下吧~" 601 | ] 602 | }, 603 | { 604 | "cell_type": "code", 605 | "execution_count": null, 606 | "id": "15046e8c-1e7b-4b99-83d4-0e4211217f29", 607 | "metadata": {}, 608 | "outputs": [], 609 | "source": [] 610 | } 611 | ], 612 | "metadata": { 613 | "kernelspec": { 614 | "display_name": "py3.11", 615 | "language": "python", 616 | "name": "py3.11" 617 | }, 618 | "language_info": { 619 | "codemirror_mode": { 620 | "name": "ipython", 621 | "version": 3 622 | }, 623 | "file_extension": ".py", 624 | "mimetype": "text/x-python", 625 | "name": "python", 626 | "nbconvert_exporter": "python", 627 | "pygments_lexer": "ipython3", 628 | "version": "3.11.8" 629 | }, 630 | "widgets": { 631 | "application/vnd.jupyter.widget-state+json": { 632 | "state": {}, 633 | "version_major": 2, 634 | "version_minor": 0 635 | } 636 | } 637 | }, 638 | "nbformat": 4, 639 | "nbformat_minor": 5 640 | } 641 | -------------------------------------------------------------------------------- /notebooks/Coggle202405意图识别/07-TextCNN文本分类.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "c856aa61-938b-422e-bd10-a8f86e5b9dc4", 7 | "metadata": { 8 | "tags": [] 9 | }, 10 | "outputs": [], 11 | "source": [ 12 | "import torch\n", 
13 | "import numpy as np\n", 14 | "import jieba\n", 15 | "import pandas as pd\n", 16 | "\n", 17 | "data_dir = 'https://mirror.coggle.club/dataset/coggle-competition/'\n", 18 | "train_data = pd.read_csv(data_dir + 'intent-classify/train.csv', sep='\\t', header=None)\n", 19 | "test_data = pd.read_csv(data_dir + 'intent-classify/test.csv', sep='\\t', header=None)" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 2, 25 | "id": "9a0d4a0d-03fc-4a54-92a9-cd87f94b1170", 26 | "metadata": {}, 27 | "outputs": [ 28 | { 29 | "data": { 30 | "text/plain": [ 31 | "'2.2.2+cu121'" 32 | ] 33 | }, 34 | "execution_count": 2, 35 | "metadata": {}, 36 | "output_type": "execute_result" 37 | } 38 | ], 39 | "source": [ 40 | "torch.__version__" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 3, 46 | "id": "6e599690-b330-437e-896f-c54f8de97601", 47 | "metadata": { 48 | "tags": [] 49 | }, 50 | "outputs": [ 51 | { 52 | "data": { 53 | "text/html": [ 54 | "
\n", 55 | "\n", 68 | "\n", 69 | " \n", 70 | " \n", 71 | " \n", 72 | " \n", 73 | " \n", 74 | " \n", 75 | " \n", 76 | " \n", 77 | " \n", 78 | " \n", 79 | " \n", 80 | " \n", 81 | " \n", 82 | " \n", 83 | " \n", 84 | " \n", 85 | " \n", 86 | " \n", 87 | " \n", 88 | " \n", 89 | " \n", 90 | " \n", 91 | " \n", 92 | " \n", 93 | " \n", 94 | " \n", 95 | " \n", 96 | " \n", 97 | " \n", 98 | " \n", 99 | " \n", 100 | " \n", 101 | " \n", 102 | " \n", 103 | "
01
0还有双鸭山到淮阴的汽车票吗13号的Travel-Query
1从这里怎么回家Travel-Query
2随便播放一首专辑阁楼里的佛里的歌Music-Play
3给看一下墓王之王嘛FilmTele-Play
4我想看挑战两把s686打突变团竞的游戏视频Video-Play
\n", 104 | "
" 105 | ], 106 | "text/plain": [ 107 | " 0 1\n", 108 | "0 还有双鸭山到淮阴的汽车票吗13号的 Travel-Query\n", 109 | "1 从这里怎么回家 Travel-Query\n", 110 | "2 随便播放一首专辑阁楼里的佛里的歌 Music-Play\n", 111 | "3 给看一下墓王之王嘛 FilmTele-Play\n", 112 | "4 我想看挑战两把s686打突变团竞的游戏视频 Video-Play" 113 | ] 114 | }, 115 | "execution_count": 3, 116 | "metadata": {}, 117 | "output_type": "execute_result" 118 | } 119 | ], 120 | "source": [ 121 | "train_data.head()" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": 4, 127 | "id": "a8b7c1b6-4730-4bc9-8637-51c671107e6d", 128 | "metadata": { 129 | "tags": [] 130 | }, 131 | "outputs": [], 132 | "source": [ 133 | "train_data = train_data.sample(frac=1.0)" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": 5, 139 | "id": "c6232746-8cf6-4766-b6e9-b8763c938f90", 140 | "metadata": { 141 | "tags": [] 142 | }, 143 | "outputs": [], 144 | "source": [ 145 | "train_data[1], lbl = pd.factorize(train_data[1])" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": 6, 151 | "id": "a9d63c0c-0e7d-4f57-9a6f-6fd1c23c5642", 152 | "metadata": { 153 | "tags": [] 154 | }, 155 | "outputs": [], 156 | "source": [ 157 | "def coustom_data_iter(texts, labels):\n", 158 | " for x, y in zip(texts, labels):\n", 159 | " yield x, y" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": 7, 165 | "id": "5cdf3cc7-4b70-4b21-923e-65e7c2af5dac", 166 | "metadata": { 167 | "tags": [] 168 | }, 169 | "outputs": [], 170 | "source": [ 171 | "train_iter = coustom_data_iter(train_data[0].values[:], train_data[1].values[:])" 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": 8, 177 | "id": "4ab5ebd1-f28d-4c46-a28e-57d415454c24", 178 | "metadata": { 179 | "tags": [] 180 | }, 181 | "outputs": [ 182 | { 183 | "name": "stderr", 184 | "output_type": "stream", 185 | "text": [ 186 | "Building prefix dict from the default dictionary ...\n", 187 | "Loading model from cache /tmp/jieba.cache\n", 188 | "Loading 
model cost 0.812 seconds.\n", 189 | "Prefix dict has been built successfully.\n" 190 | ] 191 | } 192 | ], 193 | "source": [ 194 | "from torchtext.data.utils import get_tokenizer\n", 195 | "from torchtext.vocab import build_vocab_from_iterator\n", 196 | "\n", 197 | "tokenizer = jieba.lcut\n", 198 | "\n", 199 | "def yield_tokens(data_iter):\n", 200 | " for text, _ in data_iter:\n", 201 | " yield tokenizer(text)\n", 202 | "\n", 203 | "vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=[\"\"])\n", 204 | "vocab.set_default_index(vocab[\"\"])" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": 9, 210 | "id": "36956a32-1053-4639-a543-cd6c09fbdd2c", 211 | "metadata": { 212 | "scrolled": true, 213 | "tags": [] 214 | }, 215 | "outputs": [ 216 | { 217 | "data": { 218 | "text/plain": [ 219 | "['', '的', '我', '一下', '播放', '是', '吗', '给', '帮', '一个']" 220 | ] 221 | }, 222 | "execution_count": 9, 223 | "metadata": {}, 224 | "output_type": "execute_result" 225 | } 226 | ], 227 | "source": [ 228 | "vocab.get_itos()[:10]" 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": 10, 234 | "id": "48580411-289b-4ef8-a4e5-a35040fe2014", 235 | "metadata": { 236 | "tags": [] 237 | }, 238 | "outputs": [ 239 | { 240 | "data": { 241 | "text/plain": [ 242 | "[2, 3, 41]" 243 | ] 244 | }, 245 | "execution_count": 10, 246 | "metadata": {}, 247 | "output_type": "execute_result" 248 | } 249 | ], 250 | "source": [ 251 | "vocab(['我', '一下', '今天'])" 252 | ] 253 | }, 254 | { 255 | "cell_type": "code", 256 | "execution_count": 11, 257 | "id": "b80b35e3-0f63-40a1-9813-78a5113be18e", 258 | "metadata": { 259 | "tags": [] 260 | }, 261 | "outputs": [], 262 | "source": [ 263 | "def text_pipeline(x): return vocab(tokenizer(x))" 264 | ] 265 | }, 266 | { 267 | "cell_type": "code", 268 | "execution_count": 12, 269 | "id": "7683e758-2d0f-44ec-b561-212c1fe91829", 270 | "metadata": { 271 | "tags": [] 272 | }, 273 | "outputs": [], 274 | "source": [ 275 | 
"processed_text = torch.tensor(text_pipeline('今天我们在这里'), dtype=torch.int64)" 276 | ] 277 | }, 278 | { 279 | "cell_type": "code", 280 | "execution_count": 13, 281 | "id": "779340b0-5f93-4054-9752-481392e375a6", 282 | "metadata": { 283 | "tags": [] 284 | }, 285 | "outputs": [], 286 | "source": [ 287 | "from torch.nn.utils.rnn import pad_sequence" 288 | ] 289 | }, 290 | { 291 | "cell_type": "code", 292 | "execution_count": 14, 293 | "id": "8c49f4b3-bf96-430a-8856-d61e0867cff9", 294 | "metadata": { 295 | "tags": [] 296 | }, 297 | "outputs": [], 298 | "source": [ 299 | "from torch.utils.data import DataLoader\n", 300 | "import torch.nn.functional as F\n", 301 | "\n", 302 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 303 | "\n", 304 | "\n", 305 | "def collate_batch(batch, max_len=40):\n", 306 | " label_list, text_list = [], []\n", 307 | " for (_text, _label) in batch:\n", 308 | " label_list.append(_label)\n", 309 | " processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)\n", 310 | " processed_text = F.pad(processed_text, pad=[0, max_len,], mode='constant', value=0)\n", 311 | " if len(processed_text) > max_len:\n", 312 | " processed_text = processed_text[:max_len]\n", 313 | "\n", 314 | " text_list.append(processed_text)\n", 315 | " label_list = torch.tensor(label_list, dtype=torch.int64)\n", 316 | " text_list = pad_sequence(text_list).T\n", 317 | " return label_list.to(device), text_list.to(device)\n", 318 | "\n", 319 | "from torchtext.data.functional import to_map_style_dataset\n", 320 | "train_iter = to_map_style_dataset(train_iter)\n", 321 | "dataloader = DataLoader(train_iter, batch_size=8, shuffle=False, collate_fn=collate_batch)" 322 | ] 323 | }, 324 | { 325 | "cell_type": "code", 326 | "execution_count": 15, 327 | "id": "76c856c5-cb3a-4c39-b063-9f7b7067caaa", 328 | "metadata": { 329 | "tags": [] 330 | }, 331 | "outputs": [], 332 | "source": [ 333 | "from torch import nn\n", 334 | "#%% Text CNN model\n", 335 | 
"class textCNN(nn.Module):\n", 336 | " \n", 337 | " def __init__(self, vocab_size, emb_dim, kernel_wins, num_class=12):\n", 338 | " super(textCNN, self).__init__()\n", 339 | " #load pretrained embedding in embedding layer.\n", 340 | " self.embed = nn.Embedding(vocab_size, emb_dim)\n", 341 | " # self.embed.weight.data.copy_(torch.from_numpy(pretrained_w2v))\n", 342 | " \n", 343 | " self.convs = nn.ModuleList([nn.Conv2d(1, emb_dim, (w, emb_dim)) for w in kernel_wins])\n", 344 | " #Dropout layer\n", 345 | " self.dropout = nn.Dropout(0.6)\n", 346 | " \n", 347 | " #FC layer\n", 348 | " self.fc = nn.Linear(len(kernel_wins)*emb_dim, num_class)\n", 349 | " \n", 350 | " def forward(self, x):\n", 351 | " # torch.Size([16, 40])\n", 352 | " # print(x.shape)\n", 353 | " emb_x = self.embed(x)\n", 354 | " \n", 355 | " # torch.Size([16, 40, 100])\n", 356 | " # print(emb_x.shape)\n", 357 | " \n", 358 | " # batch size * channel * height * width\n", 359 | " # batch size * channel * word count * embedding\n", 360 | " # torch.Size([16, 1, 40, 100])\n", 361 | " emb_x = emb_x.unsqueeze(1)\n", 362 | " # print(emb_x.shape)\n", 363 | "\n", 364 | " con_x = [conv(emb_x) for conv in self.convs]\n", 365 | "\n", 366 | " pool_x = [F.max_pool1d(x.squeeze(-1), x.size()[2]) for x in con_x]\n", 367 | " fc_x = torch.cat(pool_x, dim=1)\n", 368 | " # torch.Size([16, 300, 1])\n", 369 | " # print(fc_x.shape)\n", 370 | " fc_x = fc_x.squeeze(-1)\n", 371 | "\n", 372 | " fc_x = self.dropout(fc_x)\n", 373 | " logit = self.fc(fc_x)\n", 374 | " return logit" 375 | ] 376 | }, 377 | { 378 | "cell_type": "code", 379 | "execution_count": 16, 380 | "id": "f7cb48f6-d567-4b57-b74d-2830cc2b902e", 381 | "metadata": { 382 | "tags": [] 383 | }, 384 | "outputs": [], 385 | "source": [ 386 | "import time\n", 387 | "\n", 388 | "def train(dataloader):\n", 389 | " model.train()\n", 390 | " total_acc, total_count = 0, 0\n", 391 | "\n", 392 | " for idx, (label, text) in enumerate(dataloader):\n", 393 | " optimizer.zero_grad()\n", 
394 | " predicted_label = model(text)\n", 395 | " loss = criterion(predicted_label, label)\n", 396 | " loss.backward()\n", 397 | " torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)\n", 398 | " optimizer.step()\n", 399 | " total_acc += (predicted_label.argmax(1) == label).sum().item()\n", 400 | " total_count += label.size(0)\n", 401 | "\n", 402 | "def evaluate(dataloader):\n", 403 | " model.eval()\n", 404 | " total_acc, total_count = 0, 0\n", 405 | "\n", 406 | " with torch.no_grad():\n", 407 | " for idx, (label, text) in enumerate(dataloader):\n", 408 | " predicted_label = model(text)\n", 409 | " loss = criterion(predicted_label, label)\n", 410 | " total_acc += (predicted_label.argmax(1) == label).sum().item()\n", 411 | " total_count += label.size(0)\n", 412 | " return total_acc/total_count" 413 | ] 414 | }, 415 | { 416 | "cell_type": "code", 417 | "execution_count": 17, 418 | "id": "66f6580a-d4e2-4a6f-bc97-de95c5a139c5", 419 | "metadata": { 420 | "tags": [] 421 | }, 422 | "outputs": [], 423 | "source": [ 424 | "num_class = len(lbl)\n", 425 | "vocab_size = len(vocab)\n", 426 | "emsize = 100\n", 427 | "model = textCNN(vocab_size, emsize, [3, 4 , 5], num_class).to(device)" 428 | ] 429 | }, 430 | { 431 | "cell_type": "code", 432 | "execution_count": 18, 433 | "id": "d3fca9b3-9a4f-4057-a8d4-18b5ee1dd1c7", 434 | "metadata": { 435 | "tags": [] 436 | }, 437 | "outputs": [], 438 | "source": [ 439 | "from torch.utils.data.dataset import random_split\n", 440 | "from torchtext.data.functional import to_map_style_dataset\n", 441 | "# Hyperparameters\n", 442 | "EPOCHS = 20 # epoch\n", 443 | "LR = 0.001 # learning rate\n", 444 | "BATCH_SIZE = 16 # batch size for training\n", 445 | "\n", 446 | "criterion = torch.nn.CrossEntropyLoss()\n", 447 | "optimizer = torch.optim.Adam(model.parameters(), lr=LR)\n", 448 | "scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 5.0, gamma=0.75)\n", 449 | "total_accu = None\n", 450 | "\n", 451 | "train_iter = 
coustom_data_iter(train_data[0].values[:], train_data[1].values[:])\n", 452 | "train_dataset = to_map_style_dataset(train_iter)\n", 453 | "\n", 454 | "num_train = int(len(train_dataset) * 0.75)\n", 455 | "split_train_, split_valid_ = random_split(train_dataset, [num_train, len(train_dataset) - num_train])\n", 456 | "\n", 457 | "train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE,\n", 458 | " shuffle=True, collate_fn=collate_batch)\n", 459 | "valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE,\n", 460 | " shuffle=True, collate_fn=collate_batch)" 461 | ] 462 | }, 463 | { 464 | "cell_type": "code", 465 | "execution_count": null, 466 | "id": "5fee2746-9c22-43ea-a0b1-3a98f71ecc81", 467 | "metadata": { 468 | "tags": [] 469 | }, 470 | "outputs": [], 471 | "source": [ 472 | "for epoch in range(1, EPOCHS + 1):\n", 473 | " epoch_start_time = time.time()\n", 474 | " train(train_dataloader)\n", 475 | " accu_val = evaluate(valid_dataloader)\n", 476 | " if total_accu is not None and total_accu > accu_val:\n", 477 | " scheduler.step()\n", 478 | " else:\n", 479 | " total_accu = accu_val\n", 480 | " \n", 481 | " print('| end of epoch {:3d} | time: {:5.2f}s | '\n", 482 | " 'valid accuracy {:8.3f} '.format(epoch,\n", 483 | " time.time() - epoch_start_time,\n", 484 | " accu_val))" 485 | ] 486 | }, 487 | { 488 | "cell_type": "code", 489 | "execution_count": null, 490 | "id": "413f0821-4d58-4b9f-8b22-aee6ef3ceb41", 491 | "metadata": { 492 | "tags": [] 493 | }, 494 | "outputs": [], 495 | "source": [ 496 | "test_iter = coustom_data_iter(test_data[0].values[:], [0] * len(test_data))\n", 497 | "test_dataset = to_map_style_dataset(test_iter)\n", 498 | "test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE,\n", 499 | " shuffle=False, collate_fn=collate_batch)" 500 | ] 501 | }, 502 | { 503 | "cell_type": "code", 504 | "execution_count": null, 505 | "id": "167373d5-902d-4081-84ba-8287b00548aa", 506 | "metadata": { 507 | "tags": [] 508 | }, 509 | 
"outputs": [], 510 | "source": [ 511 | "def predict(dataloader):\n", 512 | " model.eval()\n", 513 | "\n", 514 | " test_pred = []\n", 515 | " with torch.no_grad():\n", 516 | " for idx, (label, text) in enumerate(dataloader):\n", 517 | " predicted_label = model(text).argmax(1)\n", 518 | " test_pred += list(predicted_label.cpu().numpy())\n", 519 | " return test_pred" 520 | ] 521 | }, 522 | { 523 | "cell_type": "code", 524 | "execution_count": null, 525 | "id": "d74bf13e-566f-4daa-b4b4-fab2b563a8e5", 526 | "metadata": { 527 | "tags": [] 528 | }, 529 | "outputs": [], 530 | "source": [ 531 | "test_pred = predict(test_dataloader)\n", 532 | "test_pred = [lbl[x] for x in test_pred]" 533 | ] 534 | }, 535 | { 536 | "cell_type": "code", 537 | "execution_count": null, 538 | "id": "3499669d-3348-4a5b-9087-62b524cac89b", 539 | "metadata": { 540 | "tags": [] 541 | }, 542 | "outputs": [], 543 | "source": [ 544 | "pd.DataFrame({\n", 545 | " 'ID': range(1, len(test_pred) + 1),\n", 546 | " 'Target': test_pred,\n", 547 | "}).to_csv('nlp_submit.csv', index=None)\n", 548 | "\n", 549 | "# 提交一下吧~" 550 | ] 551 | }, 552 | { 553 | "cell_type": "code", 554 | "execution_count": null, 555 | "id": "deef3ccf-c1de-4b5e-a10b-4a50c1db67b0", 556 | "metadata": {}, 557 | "outputs": [], 558 | "source": [] 559 | } 560 | ], 561 | "metadata": { 562 | "kernelspec": { 563 | "display_name": "py3.11", 564 | "language": "python", 565 | "name": "py3.11" 566 | }, 567 | "language_info": { 568 | "codemirror_mode": { 569 | "name": "ipython", 570 | "version": 3 571 | }, 572 | "file_extension": ".py", 573 | "mimetype": "text/x-python", 574 | "name": "python", 575 | "nbconvert_exporter": "python", 576 | "pygments_lexer": "ipython3", 577 | "version": "3.11.8" 578 | }, 579 | "widgets": { 580 | "application/vnd.jupyter.widget-state+json": { 581 | "state": {}, 582 | "version_major": 2, 583 | "version_minor": 0 584 | } 585 | } 586 | }, 587 | "nbformat": 4, 588 | "nbformat_minor": 5 589 | } 590 | 
-------------------------------------------------------------------------------- /notebooks/Coggle202405意图识别/10-T5加载与使用.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 68, 6 | "id": "efdbb171-ebf1-4465-871a-6add225b8313", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "name": "stdout", 11 | "output_type": "stream", 12 | "text": [ 13 | "['Travel-Query']\n" 14 | ] 15 | } 16 | ], 17 | "source": [ 18 | "import torch\n", 19 | "from transformers import T5Tokenizer, T5Config, T5ForConditionalGeneration\n", 20 | "\n", 21 | "# load tokenizer and model \n", 22 | "pretrained_model = \"/home/lyz/hf-models/IDEA-CCNL/Randeng-T5-784M-MultiTask-Chinese/\"\n", 23 | "\n", 24 | "special_tokens = [\"\".format(i) for i in range(100)]\n", 25 | "tokenizer = T5Tokenizer.from_pretrained(\n", 26 | " pretrained_model,\n", 27 | " do_lower_case=True,\n", 28 | " max_length=512,\n", 29 | " truncation=True,\n", 30 | " additional_special_tokens=special_tokens,\n", 31 | ")\n", 32 | "config = T5Config.from_pretrained(pretrained_model)\n", 33 | "model = T5ForConditionalGeneration.from_pretrained(pretrained_model, config=config)\n", 34 | "model.resize_token_embeddings(len(tokenizer))\n", 35 | "model.eval()\n", 36 | "\n", 37 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 38 | "\n", 39 | "# device = 'cpu'\n", 40 | "model.to(device)\n", 41 | "\n", 42 | "# tokenize\n", 43 | "text = \"意图识别任务:还有双鸭山到淮阴的汽车票吗13号的 这篇文章的类别是什么?Travel-Query/Music-Play/FilmTele-Play/Video-Play/Radio-Listen/HomeAppliance-Control/Weather-Query/Alarm-Update/Calendar-Query/TVProgram-Play/Audio-Play/Other\"\n", 44 | "encode_dict = tokenizer(text, max_length=512, padding='max_length',truncation=True)\n", 45 | "\n", 46 | "inputs = {\n", 47 | " \"input_ids\": torch.tensor([encode_dict['input_ids']]).long().to(device),\n", 48 | " \"attention_mask\": 
torch.tensor([encode_dict['attention_mask']]).long().to(device),\n", 49 | " }\n", 50 | "\n", 51 | "# generate answer\n", 52 | "logits = model.generate(\n", 53 | " input_ids = inputs['input_ids'],\n", 54 | " max_length=100, \n", 55 | " do_sample= True\n", 56 | " # early_stopping=True,\n", 57 | " )\n", 58 | "\n", 59 | "logits=logits[:,1:]\n", 60 | "predict_label = [tokenizer.decode(i,skip_special_tokens=True) for i in logits]\n", 61 | "print(predict_label)" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 69, 67 | "id": "f248c8e5-67a5-4b13-81f9-41e8521ad5b3", 68 | "metadata": {}, 69 | "outputs": [], 70 | "source": [ 71 | "# 读取数据集,这里是直接联网读取,也可以通过下载文件,再读取\n", 72 | "import pandas as pd\n", 73 | "import matplotlib.pyplot as plt\n", 74 | "\n", 75 | "data_dir = 'https://mirror.coggle.club/dataset/coggle-competition/'\n", 76 | "train_data = pd.read_csv(data_dir + 'intent-classify/train.csv', sep='\\t', header=None)\n", 77 | "test_data = pd.read_csv(data_dir + 'intent-classify/test.csv', sep='\\t', header=None)" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 70, 83 | "id": "a6d1823c-e857-48d4-b82c-44ea52206640", 84 | "metadata": {}, 85 | "outputs": [ 86 | { 87 | "data": { 88 | "text/html": [ 89 | "
\n", 90 | "\n", 103 | "\n", 104 | " \n", 105 | " \n", 106 | " \n", 107 | " \n", 108 | " \n", 109 | " \n", 110 | " \n", 111 | " \n", 112 | " \n", 113 | " \n", 114 | " \n", 115 | " \n", 116 | " \n", 117 | " \n", 118 | " \n", 119 | " \n", 120 | " \n", 121 | " \n", 122 | " \n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | "
01
0还有双鸭山到淮阴的汽车票吗13号的Travel-Query
1从这里怎么回家Travel-Query
2随便播放一首专辑阁楼里的佛里的歌Music-Play
3给看一下墓王之王嘛FilmTele-Play
4我想看挑战两把s686打突变团竞的游戏视频Video-Play
.........
12095一千六百五十三加三千一百六十五点六五等于几Calendar-Query
12096稍小点客厅空调风速HomeAppliance-Control
12097黎耀祥陈豪邓萃雯畲诗曼陈法拉敖嘉年杨怡马浚伟等到场出席Radio-Listen
12098百事盖世群星星光演唱会有谁Video-Play
12099下周一视频会议的闹钟帮我开开Alarm-Update
\n", 169 | "

12100 rows × 2 columns

\n", 170 | "
" 171 | ], 172 | "text/plain": [ 173 | " 0 1\n", 174 | "0 还有双鸭山到淮阴的汽车票吗13号的 Travel-Query\n", 175 | "1 从这里怎么回家 Travel-Query\n", 176 | "2 随便播放一首专辑阁楼里的佛里的歌 Music-Play\n", 177 | "3 给看一下墓王之王嘛 FilmTele-Play\n", 178 | "4 我想看挑战两把s686打突变团竞的游戏视频 Video-Play\n", 179 | "... ... ...\n", 180 | "12095 一千六百五十三加三千一百六十五点六五等于几 Calendar-Query\n", 181 | "12096 稍小点客厅空调风速 HomeAppliance-Control\n", 182 | "12097 黎耀祥陈豪邓萃雯畲诗曼陈法拉敖嘉年杨怡马浚伟等到场出席 Radio-Listen\n", 183 | "12098 百事盖世群星星光演唱会有谁 Video-Play\n", 184 | "12099 下周一视频会议的闹钟帮我开开 Alarm-Update\n", 185 | "\n", 186 | "[12100 rows x 2 columns]" 187 | ] 188 | }, 189 | "execution_count": 70, 190 | "metadata": {}, 191 | "output_type": "execute_result" 192 | } 193 | ], 194 | "source": [ 195 | "train_data" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": 71, 201 | "id": "733edea2-647a-48e2-95e9-8e79d88d96fa", 202 | "metadata": {}, 203 | "outputs": [ 204 | { 205 | "data": { 206 | "text/plain": [ 207 | "'Travel-Query/Music-Play/FilmTele-Play/Video-Play/Radio-Listen/HomeAppliance-Control/Weather-Query/Alarm-Update/Calendar-Query/TVProgram-Play/Audio-Play/Other'" 208 | ] 209 | }, 210 | "execution_count": 71, 211 | "metadata": {}, 212 | "output_type": "execute_result" 213 | } 214 | ], 215 | "source": [ 216 | "'/'.join(train_data[1].unique())" 217 | ] 218 | }, 219 | { 220 | "cell_type": "code", 221 | "execution_count": 72, 222 | "id": "0c90b916-a8b5-4768-a636-0f1155e72837", 223 | "metadata": {}, 224 | "outputs": [ 225 | { 226 | "name": "stdout", 227 | "output_type": "stream", 228 | "text": [ 229 | "CPU times: user 181 ms, sys: 3.85 ms, total: 185 ms\n", 230 | "Wall time: 183 ms\n" 231 | ] 232 | }, 233 | { 234 | "data": { 235 | "text/plain": [ 236 | "['Music-Play']" 237 | ] 238 | }, 239 | "execution_count": 72, 240 | "metadata": {}, 241 | "output_type": "execute_result" 242 | } 243 | ], 244 | "source": [ 245 | "%%time\n", 246 | "\n", 247 | "text = \"意图识别任务:【播放周杰伦的歌曲】 
这篇文章的类别是什么?Travel-Query/Music-Play/FilmTele-Play/Video-Play/Radio-Listen/HomeAppliance-Control/Weather-Query/Alarm-Update/Calendar-Query/TVProgram-Play/Audio-Play/Other\"\n", 248 | "encode_dict = tokenizer(text)\n", 249 | "\n", 250 | "inputs = {\n", 251 | " \"input_ids\": torch.tensor([encode_dict['input_ids']]).long().to(device),\n", 252 | " \"attention_mask\": torch.tensor([encode_dict['attention_mask']]).long().to(device),\n", 253 | "}\n", 254 | "\n", 255 | "logits = model.generate(\n", 256 | " input_ids = inputs['input_ids'],\n", 257 | " # attention_mask = inputs['attention_mask'],\n", 258 | " max_length=20, \n", 259 | " do_sample= False\n", 260 | ")\n", 261 | "\n", 262 | "logits=logits[:,1:]\n", 263 | "predict_label = [tokenizer.decode(i,skip_special_tokens=True) for i in logits]\n", 264 | "predict_label" 265 | ] 266 | }, 267 | { 268 | "cell_type": "code", 269 | "execution_count": 73, 270 | "id": "6dfe163f-55dd-4cf2-9afb-3c7563d03f79", 271 | "metadata": {}, 272 | "outputs": [], 273 | "source": [ 274 | "from tqdm import tqdm_notebook" 275 | ] 276 | }, 277 | { 278 | "cell_type": "code", 279 | "execution_count": 74, 280 | "id": "90984788-5411-4cb1-9369-4e7af42b7f65", 281 | "metadata": {}, 282 | "outputs": [ 283 | { 284 | "name": "stderr", 285 | "output_type": "stream", 286 | "text": [ 287 | "/tmp/ipykernel_26227/1619661787.py:2: TqdmDeprecationWarning: This function will be removed in tqdm==5.0.0\n", 288 | "Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`\n", 289 | " for train_text in tqdm_notebook(test_data[0].values):\n" 290 | ] 291 | }, 292 | { 293 | "data": { 294 | "application/vnd.jupyter.widget-view+json": { 295 | "model_id": "64bc6bd790ce4630b53687a053fc8c1d", 296 | "version_major": 2, 297 | "version_minor": 0 298 | }, 299 | "text/plain": [ 300 | " 0%| | 0/3000 [00:00\n", 102 | "\n", 115 | "\n", 116 | " \n", 117 | " \n", 118 | " \n", 119 | " \n", 120 | " \n", 121 | " \n", 122 | " \n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | 
" \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | "
01
0还有双鸭山到淮阴的汽车票吗13号的Travel-Query
1从这里怎么回家Travel-Query
2随便播放一首专辑阁楼里的佛里的歌Music-Play
3给看一下墓王之王嘛FilmTele-Play
4我想看挑战两把s686打突变团竞的游戏视频Video-Play
.........
12095一千六百五十三加三千一百六十五点六五等于几Calendar-Query
12096稍小点客厅空调风速HomeAppliance-Control
12097黎耀祥陈豪邓萃雯畲诗曼陈法拉敖嘉年杨怡马浚伟等到场出席Radio-Listen
12098百事盖世群星星光演唱会有谁Video-Play
12099下周一视频会议的闹钟帮我开开Alarm-Update
\n", 181 | "

12100 rows × 2 columns

\n", 182 | "" 183 | ], 184 | "text/plain": [ 185 | " 0 1\n", 186 | "0 还有双鸭山到淮阴的汽车票吗13号的 Travel-Query\n", 187 | "1 从这里怎么回家 Travel-Query\n", 188 | "2 随便播放一首专辑阁楼里的佛里的歌 Music-Play\n", 189 | "3 给看一下墓王之王嘛 FilmTele-Play\n", 190 | "4 我想看挑战两把s686打突变团竞的游戏视频 Video-Play\n", 191 | "... ... ...\n", 192 | "12095 一千六百五十三加三千一百六十五点六五等于几 Calendar-Query\n", 193 | "12096 稍小点客厅空调风速 HomeAppliance-Control\n", 194 | "12097 黎耀祥陈豪邓萃雯畲诗曼陈法拉敖嘉年杨怡马浚伟等到场出席 Radio-Listen\n", 195 | "12098 百事盖世群星星光演唱会有谁 Video-Play\n", 196 | "12099 下周一视频会议的闹钟帮我开开 Alarm-Update\n", 197 | "\n", 198 | "[12100 rows x 2 columns]" 199 | ] 200 | }, 201 | "execution_count": 3, 202 | "metadata": {}, 203 | "output_type": "execute_result" 204 | } 205 | ], 206 | "source": [ 207 | "train_data" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": 4, 213 | "id": "733edea2-647a-48e2-95e9-8e79d88d96fa", 214 | "metadata": {}, 215 | "outputs": [ 216 | { 217 | "data": { 218 | "text/plain": [ 219 | "'Travel-Query Music-Play FilmTele-Play Video-Play Radio-Listen HomeAppliance-Control Weather-Query Alarm-Update Calendar-Query TVProgram-Play Audio-Play Other'" 220 | ] 221 | }, 222 | "execution_count": 4, 223 | "metadata": {}, 224 | "output_type": "execute_result" 225 | } 226 | ], 227 | "source": [ 228 | "' '.join(train_data[1].unique())" 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": 5, 234 | "id": "0c90b916-a8b5-4768-a636-0f1155e72837", 235 | "metadata": {}, 236 | "outputs": [ 237 | { 238 | "name": "stdout", 239 | "output_type": "stream", 240 | "text": [ 241 | "CPU times: user 359 ms, sys: 0 ns, total: 359 ms\n", 242 | "Wall time: 365 ms\n" 243 | ] 244 | }, 245 | { 246 | "data": { 247 | "text/plain": [ 248 | "'Music-Play'" 249 | ] 250 | }, 251 | "execution_count": 5, 252 | "metadata": {}, 253 | "output_type": "execute_result" 254 | } 255 | ], 256 | "source": [ 257 | "%%time\n", 258 | "\n", 259 | "prompt = '''识别句子的意图:播放周杰伦的歌曲\n", 260 | "待选类别:Travel-Query Music-Play FilmTele-Play 
Video-Play Radio-Listen HomeAppliance-Control Weather-Query Alarm-Update Calendar-Query TVProgram-Play Audio-Play Other\n", 261 | "只需要输出类别(从待选类别中选一个),不要其他输出。\n", 262 | "'''\n", 263 | "messages = [\n", 264 | " {\"role\": \"user\", \"content\": prompt}\n", 265 | "]\n", 266 | "text = tokenizer.apply_chat_template(\n", 267 | " messages,\n", 268 | " tokenize=False,\n", 269 | " add_generation_prompt=True\n", 270 | ")\n", 271 | "model_inputs = tokenizer([text], return_tensors=\"pt\").to(device)\n", 272 | "\n", 273 | "generated_ids = model.generate(\n", 274 | " model_inputs.input_ids,\n", 275 | " max_new_tokens=512,\n", 276 | " \n", 277 | ")\n", 278 | "generated_ids = [\n", 279 | " output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)\n", 280 | "]\n", 281 | "\n", 282 | "response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]\n", 283 | "response" 284 | ] 285 | }, 286 | { 287 | "cell_type": "code", 288 | "execution_count": 6, 289 | "id": "6dfe163f-55dd-4cf2-9afb-3c7563d03f79", 290 | "metadata": {}, 291 | "outputs": [], 292 | "source": [ 293 | "from tqdm import tqdm_notebook" 294 | ] 295 | }, 296 | { 297 | "cell_type": "code", 298 | "execution_count": 7, 299 | "id": "90984788-5411-4cb1-9369-4e7af42b7f65", 300 | "metadata": {}, 301 | "outputs": [ 302 | { 303 | "name": "stderr", 304 | "output_type": "stream", 305 | "text": [ 306 | "/tmp/ipykernel_36021/2977000300.py:2: TqdmDeprecationWarning: This function will be removed in tqdm==5.0.0\n", 307 | "Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`\n", 308 | " for train_text in tqdm_notebook(test_data[0].values):\n" 309 | ] 310 | }, 311 | { 312 | "data": { 313 | "application/vnd.jupyter.widget-view+json": { 314 | "model_id": "3fd75ff95caa4c378a14b06ef2cf0a1e", 315 | "version_major": 2, 316 | "version_minor": 0 317 | }, 318 | "text/plain": [ 319 | " 0%| | 0/3000 [00:00

注册链接:https://xinghuo.xfyun.cn/sparkapi?ch=nt_api_m1ophux

" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "0fdf7354-b1e1-45f1-b482-882183589fdb", 14 | "metadata": {}, 15 | "source": [ 16 | "## 文本对话\n", 17 | "\n", 18 | "- 控制台:https://console.xfyun.cn/services/bm35\n", 19 | "- API文档:https://www.xfyun.cn/doc/spark/Web.html" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 1, 25 | "id": "359d9854-e899-40fb-b7d0-17283ca863ab", 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "# coding: utf-8\n", 30 | "import _thread as thread\n", 31 | "import os\n", 32 | "import time\n", 33 | "import base64\n", 34 | "\n", 35 | "import base64\n", 36 | "import datetime\n", 37 | "import hashlib\n", 38 | "import hmac\n", 39 | "import json\n", 40 | "from urllib.parse import urlparse\n", 41 | "import ssl\n", 42 | "from datetime import datetime\n", 43 | "from time import mktime\n", 44 | "from urllib.parse import urlencode\n", 45 | "from wsgiref.handlers import format_date_time\n", 46 | "\n", 47 | "import websocket\n", 48 | "import openpyxl\n", 49 | "from concurrent.futures import ThreadPoolExecutor, as_completed\n", 50 | "import os\n", 51 | "\n", 52 | "\n", 53 | "class Ws_Param(object):\n", 54 | " # 初始化\n", 55 | " def __init__(self, APPID, APIKey, APISecret, gpt_url):\n", 56 | " self.APPID = APPID\n", 57 | " self.APIKey = APIKey\n", 58 | " self.APISecret = APISecret\n", 59 | " self.host = urlparse(gpt_url).netloc\n", 60 | " self.path = urlparse(gpt_url).path\n", 61 | " self.gpt_url = gpt_url\n", 62 | "\n", 63 | " # 生成url\n", 64 | " def create_url(self):\n", 65 | " # 生成RFC1123格式的时间戳\n", 66 | " now = datetime.now()\n", 67 | " date = format_date_time(mktime(now.timetuple()))\n", 68 | "\n", 69 | " # 拼接字符串\n", 70 | " signature_origin = \"host: \" + self.host + \"\\n\"\n", 71 | " signature_origin += \"date: \" + date + \"\\n\"\n", 72 | " signature_origin += \"GET \" + self.path + \" HTTP/1.1\"\n", 73 | "\n", 74 | " # 进行hmac-sha256进行加密\n", 75 | " signature_sha = hmac.new(self.APISecret.encode('utf-8'), 
signature_origin.encode('utf-8'),\n", 76 | " digestmod=hashlib.sha256).digest()\n", 77 | "\n", 78 | " signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')\n", 79 | "\n", 80 | " authorization_origin = f'api_key=\"{self.APIKey}\", algorithm=\"hmac-sha256\", headers=\"host date request-line\", signature=\"{signature_sha_base64}\"'\n", 81 | "\n", 82 | " authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')\n", 83 | "\n", 84 | " # 将请求的鉴权参数组合为字典\n", 85 | " v = {\n", 86 | " \"authorization\": authorization,\n", 87 | " \"date\": date,\n", 88 | " \"host\": self.host\n", 89 | " }\n", 90 | " # 拼接鉴权参数,生成url\n", 91 | " url = self.gpt_url + '?' + urlencode(v)\n", 92 | " # 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致\n", 93 | " return url\n", 94 | "\n", 95 | "\n", 96 | "\n", 97 | "# 收到websocket错误的处理\n", 98 | "def on_error(ws, error):\n", 99 | " print(\"### error:\", error)\n", 100 | "\n", 101 | "\n", 102 | "# 收到websocket关闭的处理\n", 103 | "def on_close(ws):\n", 104 | " print(\"### closed ###\")\n", 105 | "\n", 106 | "\n", 107 | "# 收到websocket连接建立的处理\n", 108 | "def on_open(ws):\n", 109 | " thread.start_new_thread(run, (ws,))\n", 110 | "\n", 111 | "\n", 112 | "def run(ws, *args):\n", 113 | " data = json.dumps(gen_params(appid=ws.appid, query=ws.query, domain=ws.domain))\n", 114 | " ws.send(data)\n", 115 | "\n", 116 | "\n", 117 | "# 收到websocket消息的处理\n", 118 | "def on_message(ws, message):\n", 119 | " # print(message)\n", 120 | " data = json.loads(message)\n", 121 | " code = data['header']['code']\n", 122 | " if code != 0:\n", 123 | " print(f'请求错误: {code}, {data}')\n", 124 | " ws.close()\n", 125 | " else:\n", 126 | " choices = data[\"payload\"][\"choices\"]\n", 127 | " status = choices[\"status\"]\n", 128 | " content = choices[\"text\"][0][\"content\"]\n", 129 | " print(content,end='')\n", 130 | " if status == 2:\n", 131 | " print(\"#### 关闭会话\")\n", 132 | " ws.close()\n", 133 | "\n", 134 | "\n", 
135 | "def gen_params(appid, query, domain):\n", 136 | " \"\"\"\n", 137 | " 通过appid和用户的提问来生成请参数\n", 138 | " \"\"\"\n", 139 | "\n", 140 | " data = {\n", 141 | " \"header\": {\n", 142 | " \"app_id\": appid,\n", 143 | " \"uid\": \"1234\",\n", 144 | " # \"patch_id\": [] #接入微调模型,对应服务发布后的resourceid\n", 145 | " },\n", 146 | " \"parameter\": {\n", 147 | " \"chat\": {\n", 148 | " \"domain\": domain,\n", 149 | " \"temperature\": 0.5,\n", 150 | " \"max_tokens\": 4096,\n", 151 | " \"auditing\": \"default\",\n", 152 | " }\n", 153 | " },\n", 154 | " \"payload\": {\n", 155 | " \"message\": {\n", 156 | " \"text\": [{\"role\": \"user\", \"content\": query}]\n", 157 | " }\n", 158 | " }\n", 159 | " }\n", 160 | " return data\n", 161 | "\n", 162 | "\n", 163 | "def main(appid, api_secret, api_key, gpt_url, domain, query):\n", 164 | " wsParam = Ws_Param(appid, api_key, api_secret, gpt_url)\n", 165 | " websocket.enableTrace(False)\n", 166 | " wsUrl = wsParam.create_url()\n", 167 | "\n", 168 | " ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open)\n", 169 | " ws.appid = appid\n", 170 | " ws.query = query\n", 171 | " ws.domain = domain\n", 172 | " ws.run_forever(sslopt={\"cert_reqs\": ssl.CERT_NONE})" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": 2, 178 | "id": "fbb25706-7ffc-467a-a006-4a92b2df7244", 179 | "metadata": {}, 180 | "outputs": [ 181 | { 182 | "name": "stdout", 183 | "output_type": "stream", 184 | "text": [ 185 | "很高兴听到你今天很开心!如果你有任何问题或者想聊天,请随时告诉我!#### 关闭会话\n" 186 | ] 187 | } 188 | ], 189 | "source": [ 190 | "main(\n", 191 | " appid=\"XXX\",\n", 192 | " api_secret=\"XXX\",\n", 193 | " api_key=\"XXX\",\n", 194 | " #appid、api_secret、api_key三个服务认证信息请前往开放平台控制台查看(https://console.xfyun.cn/services/bm35)\n", 195 | " gpt_url=\"wss://spark-api.xf-yun.com/v3.5/chat\",\n", 196 | " # Spark_url = \"ws://spark-api.xf-yun.com/v3.1/chat\" # v3.0环境的地址\n", 197 | " # Spark_url = 
\"ws://spark-api.xf-yun.com/v2.1/chat\" # v2.0环境的地址\n", 198 | " # Spark_url = \"ws://spark-api.xf-yun.com/v1.1/chat\" # v1.5环境的地址\n", 199 | " domain=\"generalv3.5\",\n", 200 | " # domain = \"generalv3\" # v3.0版本\n", 201 | " # domain = \"generalv2\" # v2.0版本\n", 202 | " # domain = \"general\" # v2.0版本\n", 203 | " query=\"我今天很开心。\"\n", 204 | ")" 205 | ] 206 | }, 207 | { 208 | "cell_type": "markdown", 209 | "id": "cebd079b-0df6-4a2d-95ef-31a2e791ffa6", 210 | "metadata": {}, 211 | "source": [ 212 | "## 图片生成\n", 213 | "\n", 214 | "- 控制台:https://console.xfyun.cn/services/tti\n", 215 | "- API文档:https://www.xfyun.cn/doc/spark/ImageGeneration.html" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": 3, 221 | "id": "8046f01b-8e71-4d8d-919f-ee37f2006747", 222 | "metadata": {}, 223 | "outputs": [], 224 | "source": [ 225 | "# encoding: UTF-8\n", 226 | "import time\n", 227 | "\n", 228 | "import requests\n", 229 | "from datetime import datetime\n", 230 | "from wsgiref.handlers import format_date_time\n", 231 | "from time import mktime\n", 232 | "import hashlib\n", 233 | "import base64\n", 234 | "import hmac\n", 235 | "from urllib.parse import urlencode\n", 236 | "import json\n", 237 | "from PIL import Image\n", 238 | "from io import BytesIO\n", 239 | "\n", 240 | "class AssembleHeaderException(Exception):\n", 241 | " def __init__(self, msg):\n", 242 | " self.message = msg\n", 243 | "\n", 244 | "\n", 245 | "class Url:\n", 246 | " def __init__(this, host, path, schema):\n", 247 | " this.host = host\n", 248 | " this.path = path\n", 249 | " this.schema = schema\n", 250 | " pass\n", 251 | "\n", 252 | "\n", 253 | "# calculate sha256 and encode to base64\n", 254 | "def sha256base64(data):\n", 255 | " sha256 = hashlib.sha256()\n", 256 | " sha256.update(data)\n", 257 | " digest = base64.b64encode(sha256.digest()).decode(encoding='utf-8')\n", 258 | " return digest\n", 259 | "\n", 260 | "\n", 261 | "def parse_url(requset_url):\n", 262 | " stidx = 
requset_url.index(\"://\")\n", 263 | " host = requset_url[stidx + 3:]\n", 264 | " schema = requset_url[:stidx + 3]\n", 265 | " edidx = host.index(\"/\")\n", 266 | " if edidx <= 0:\n", 267 | " raise AssembleHeaderException(\"invalid request url:\" + requset_url)\n", 268 | " path = host[edidx:]\n", 269 | " host = host[:edidx]\n", 270 | " u = Url(host, path, schema)\n", 271 | " return u\n", 272 | "\n", 273 | "\n", 274 | "# 生成鉴权url\n", 275 | "def assemble_ws_auth_url(requset_url, method=\"GET\", api_key=\"\", api_secret=\"\"):\n", 276 | " u = parse_url(requset_url)\n", 277 | " host = u.host\n", 278 | " path = u.path\n", 279 | " now = datetime.now()\n", 280 | " date = format_date_time(mktime(now.timetuple()))\n", 281 | " # print(date)\n", 282 | " # date = \"Thu, 12 Dec 2019 01:57:27 GMT\"\n", 283 | " signature_origin = \"host: {}\\ndate: {}\\n{} {} HTTP/1.1\".format(host, date, method, path)\n", 284 | " # print(signature_origin)\n", 285 | " signature_sha = hmac.new(api_secret.encode('utf-8'), signature_origin.encode('utf-8'),\n", 286 | " digestmod=hashlib.sha256).digest()\n", 287 | " signature_sha = base64.b64encode(signature_sha).decode(encoding='utf-8')\n", 288 | " authorization_origin = \"api_key=\\\"%s\\\", algorithm=\\\"%s\\\", headers=\\\"%s\\\", signature=\\\"%s\\\"\" % (\n", 289 | " api_key, \"hmac-sha256\", \"host date request-line\", signature_sha)\n", 290 | " authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')\n", 291 | " # print(authorization_origin)\n", 292 | " values = {\n", 293 | " \"host\": host,\n", 294 | " \"date\": date,\n", 295 | " \"authorization\": authorization\n", 296 | " }\n", 297 | "\n", 298 | " return requset_url + \"?\" + urlencode(values)\n", 299 | "\n", 300 | "# 生成请求body体\n", 301 | "def getBody(appid,text):\n", 302 | " body= {\n", 303 | " \"header\": {\n", 304 | " \"app_id\": appid,\n", 305 | " \"uid\":\"123456789\"\n", 306 | " },\n", 307 | " \"parameter\": {\n", 308 | " \"chat\": {\n", 309 | " 
\"domain\": \"general\",\n", 310 | " \"temperature\":0.5,\n", 311 | " \"max_tokens\":4096\n", 312 | " }\n", 313 | " },\n", 314 | " \"payload\": {\n", 315 | " \"message\":{\n", 316 | " \"text\":[\n", 317 | " {\n", 318 | " \"role\":\"user\",\n", 319 | " \"content\":text\n", 320 | " }\n", 321 | " ]\n", 322 | " }\n", 323 | " }\n", 324 | " }\n", 325 | " return body\n", 326 | "\n", 327 | "# 发起请求并返回结果\n", 328 | "def main(text,appid,apikey,apisecret):\n", 329 | " host = 'http://spark-api.cn-huabei-1.xf-yun.com/v2.1/tti'\n", 330 | " url = assemble_ws_auth_url(host,method='POST',api_key=apikey,api_secret=apisecret)\n", 331 | " content = getBody(appid,text)\n", 332 | " print(time.time())\n", 333 | " response = requests.post(url,json=content,headers={'content-type': \"application/json\"}).text\n", 334 | " print(time.time())\n", 335 | " return response\n", 336 | "\n", 337 | "#将base64 的图片数据存在本地\n", 338 | "def base64_to_image(base64_data, save_path):\n", 339 | " # 解码base64数据\n", 340 | " img_data = base64.b64decode(base64_data)\n", 341 | "\n", 342 | " # 将解码后的数据转换为图片\n", 343 | " img = Image.open(BytesIO(img_data))\n", 344 | "\n", 345 | " # 保存图片到本地\n", 346 | " img.save(save_path)\n", 347 | "\n", 348 | "\n", 349 | "# 解析并保存到指定位置\n", 350 | "def parser_Message(message):\n", 351 | " data = json.loads(message)\n", 352 | " # print(\"data\" + str(message))\n", 353 | " code = data['header']['code']\n", 354 | " if code != 0:\n", 355 | " print(f'请求错误: {code}, {data}')\n", 356 | " else:\n", 357 | " text = data[\"payload\"][\"choices\"][\"text\"]\n", 358 | " imageContent = text[0]\n", 359 | " # if('image' == imageContent[\"content_type\"]):\n", 360 | " imageBase = imageContent[\"content\"]\n", 361 | " imageName = data['header']['sid']\n", 362 | " savePath = f\"{imageName}.jpg\"\n", 363 | " base64_to_image(imageBase,savePath)\n", 364 | " print(\"图片保存路径:\" + savePath)" 365 | ] 366 | }, 367 | { 368 | "cell_type": "code", 369 | "execution_count": 4, 370 | "id": 
"83758e97-c649-44bd-ab0a-9e087b9e576d", 371 | "metadata": {}, 372 | "outputs": [ 373 | { 374 | "name": "stdout", 375 | "output_type": "stream", 376 | "text": [ 377 | "1710491962.4746332\n", 378 | "1710491966.1379104\n", 379 | "图片保存路径:cht000b3678@dx18e41440cdeb8f3550.jpg\n" 380 | ] 381 | } 382 | ], 383 | "source": [ 384 | "#运行前请配置以下鉴权三要素,获取途径:https://console.xfyun.cn/services/tti\n", 385 | "APPID ='XXX'\n", 386 | "APISecret = 'XXX'\n", 387 | "APIKEY = 'XXX'\n", 388 | "desc = '''生成一张图:一个男生在打篮球'''\n", 389 | "res = main(desc,appid=APPID,apikey=APIKEY,apisecret=APISecret)\n", 390 | "parser_Message(res)" 391 | ] 392 | }, 393 | { 394 | "cell_type": "markdown", 395 | "id": "2ac6be3e-14d6-4393-ba45-30ab268f26d1", 396 | "metadata": {}, 397 | "source": [ 398 | "## 图片理解\n", 399 | "\n", 400 | "- 控制台:https://console.xfyun.cn/services/image\n", 401 | "- API文档:https://www.xfyun.cn/doc/spark/ImageUnderstanding.html" 402 | ] 403 | }, 404 | { 405 | "cell_type": "markdown", 406 | "id": "ceff20c8-5465-46cd-89e2-26acc181b5a9", 407 | "metadata": {}, 408 | "source": [ 409 | "## 语音合成\n", 410 | "\n", 411 | "- 控制台:https://console.xfyun.cn/services/medd90fec\n", 412 | "- API文档:https://www.xfyun.cn/doc/spark/%E8%B6%85%E6%8B%9F%E2%BC%88%E5%90%88%E6%88%90%E5%8D%8F%E8%AE%AE.html" 413 | ] 414 | }, 415 | { 416 | "cell_type": "code", 417 | "execution_count": 17, 418 | "id": "f979cc5d-d30a-407d-9f49-e5bc7703b3e9", 419 | "metadata": {}, 420 | "outputs": [], 421 | "source": [ 422 | "# coding: utf-8\n", 423 | "import _thread as thread\n", 424 | "import os\n", 425 | "import time\n", 426 | "import base64\n", 427 | "\n", 428 | "import base64\n", 429 | "import datetime\n", 430 | "import hashlib\n", 431 | "import hmac\n", 432 | "import json\n", 433 | "from urllib.parse import urlparse\n", 434 | "import ssl\n", 435 | "from datetime import datetime\n", 436 | "from time import mktime\n", 437 | "from urllib.parse import urlencode\n", 438 | "from wsgiref.handlers import format_date_time\n", 439 | "import
websocket\n", 440 | "import openpyxl\n", 441 | "from concurrent.futures import ThreadPoolExecutor, as_completed\n", 442 | "import os\n", 443 | "\n", 444 | "\n", 445 | "class Ws_Param(object):\n", 446 | " # 初始化\n", 447 | " def __init__(self, APPID, APIKey, APISecret, gpt_url):\n", 448 | " self.APPID = APPID\n", 449 | " self.APIKey = APIKey\n", 450 | " self.APISecret = APISecret\n", 451 | " self.host = urlparse(gpt_url).netloc\n", 452 | " self.path = urlparse(gpt_url).path\n", 453 | " self.gpt_url = gpt_url\n", 454 | "\n", 455 | " # 生成url\n", 456 | " def create_url(self):\n", 457 | " # 生成RFC1123格式的时间戳\n", 458 | " now = datetime.now()\n", 459 | " date = format_date_time(mktime(now.timetuple()))\n", 460 | "\n", 461 | " # 拼接字符串\n", 462 | " signature_origin = \"host: \" + self.host + \"\\n\"\n", 463 | " signature_origin += \"date: \" + date + \"\\n\"\n", 464 | " signature_origin += \"GET \" + self.path + \" HTTP/1.1\"\n", 465 | "\n", 466 | " # 进行hmac-sha256进行加密\n", 467 | " signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),\n", 468 | " digestmod=hashlib.sha256).digest()\n", 469 | "\n", 470 | " signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')\n", 471 | "\n", 472 | " authorization_origin = f'api_key=\"{self.APIKey}\", algorithm=\"hmac-sha256\", headers=\"host date request-line\", signature=\"{signature_sha_base64}\"'\n", 473 | "\n", 474 | " authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')\n", 475 | "\n", 476 | " # 将请求的鉴权参数组合为字典\n", 477 | " v = {\n", 478 | " \"authorization\": authorization,\n", 479 | " \"date\": date,\n", 480 | " \"host\": self.host\n", 481 | " }\n", 482 | " # 拼接鉴权参数,生成url\n", 483 | " url = self.gpt_url + '?' 
+ urlencode(v)\n", 484 | " # 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致\n", 485 | " return url\n", 486 | "\n", 487 | "\n", 488 | "# 收到websocket错误的处理\n", 489 | "def on_error(ws, error):\n", 490 | " print(\"### error:\", error)\n", 491 | "\n", 492 | "\n", 493 | "# 收到websocket关闭的处理\n", 494 | "def on_close(ws):\n", 495 | " print(\"### closed ###\")\n", 496 | "\n", 497 | "\n", 498 | "# 收到websocket连接建立的处理\n", 499 | "def on_open(ws):\n", 500 | " thread.start_new_thread(run, (ws,))\n", 501 | "\n", 502 | "\n", 503 | "# 收到websocket消息的处理\n", 504 | "def on_message(ws, message):\n", 505 | " message = json.loads(message)\n", 506 | " code = message['header']['code']\n", 507 | " if code != 0:\n", 508 | " print(\"### 请求出错: \", message)\n", 509 | " else:\n", 510 | " payload = message.get(\"payload\")\n", 511 | " status = message['header']['status']\n", 512 | " if status == 2:\n", 513 | " print(\"### 合成完毕\")\n", 514 | " ws.close()\n", 515 | " if payload and payload != \"null\":\n", 516 | " audio = payload.get(\"audio\")\n", 517 | " if audio:\n", 518 | " audio = audio[\"audio\"]\n", 519 | " with open(fr'./{ws.vcn}.mp3', 'ab') as f:\n", 520 | " f.write(base64.b64decode(audio))\n", 521 | "\n", 522 | "\n", 523 | "def run(ws, *args):\n", 524 | " body = {\n", 525 | " \"header\": {\n", 526 | " \"app_id\": ws.appid,\n", 527 | " \"status\": 0\n", 528 | " },\n", 529 | " \"parameter\": {\n", 530 | " \"oral\": {\n", 531 | " \"spark_assist\": 1,\n", 532 | " \"oral_level\": \"mid\"\n", 533 | " },\n", 534 | " \"tts\": {\n", 535 | " \"vcn\": ws.vcn,\n", 536 | " \"speed\": 50,\n", 537 | " \"volume\": 50,\n", 538 | " \"pitch\": 50,\n", 539 | " \"bgs\": 0,\n", 540 | " \"reg\": 0,\n", 541 | " \"rdn\": 0,\n", 542 | " \"rhy\": 0,\n", 543 | " \"scn\": 0,\n", 544 | " \"version\": 0,\n", 545 | " \"L5SilLen\": 0,\n", 546 | " \"ParagraphSilLen\": 0,\n", 547 | " \"audio\": {\n", 548 | " \"encoding\": \"lame\",\n", 549 | " \"sample_rate\": 16000,\n", 550 | " \"channels\": 1,\n", 551 | " 
\"bit_depth\": 16,\n", 552 | " \"frame_size\": 0\n", 553 | " },\n", 554 | " \"pybuf\": {\n", 555 | " \"encoding\": \"utf8\",\n", 556 | " \"compress\": \"raw\",\n", 557 | " \"format\": \"plain\"\n", 558 | " }\n", 559 | " }\n", 560 | " },\n", 561 | " \"payload\": {\n", 562 | " \"text\": {\n", 563 | " \"encoding\": \"utf8\",\n", 564 | " \"compress\": \"raw\",\n", 565 | " \"format\": \"json\",\n", 566 | " \"status\": 0,\n", 567 | " \"seq\": 0,\n", 568 | " \"text\": str(base64.b64encode(ws.text.encode('utf-8')), \"UTF8\")\n", 569 | " }\n", 570 | " }\n", 571 | " }\n", 572 | "\n", 573 | " ws.send(json.dumps(body))\n", 574 | "\n", 575 | "\n", 576 | "def main(appid, api_secret, api_key, url, text, vcn):\n", 577 | " wsParam = Ws_Param(appid, api_key, api_secret, url)\n", 578 | " wsUrl = wsParam.create_url()\n", 579 | " ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open)\n", 580 | " websocket.enableTrace(False)\n", 581 | " ws.appid = appid\n", 582 | " ws.text = text\n", 583 | " ws.vcn = vcn\n", 584 | " ws.run_forever(sslopt={\"cert_reqs\": ssl.CERT_NONE})" 585 | ] 586 | }, 587 | { 588 | "cell_type": "code", 589 | "execution_count": 20, 590 | "id": "f4818028-563f-412d-ac42-66d9b7dfc4fe", 591 | "metadata": {}, 592 | "outputs": [ 593 | { 594 | "name": "stdout", 595 | "output_type": "stream", 596 | "text": [ 597 | "### 合成完毕\n" 598 | ] 599 | } 600 | ], 601 | "source": [ 602 | "main(\n", 603 | " appid=\"XXX\",\n", 604 | " api_secret=\"XXX\",\n", 605 | " api_key=\"XXX\",\n", 606 | " url=\"wss://cbm01.cn-huabei-1.xf-yun.com/v1/private/medd90fec\",\n", 607 | " # 待合成文本\n", 608 | " text=\"今天天气很不错,适合打羽毛球。\",\n", 609 | " # 发音人参数\n", 610 | " vcn = \"x4_lingfeizhe_oral\"\n", 611 | ")" 612 | ] 613 | }, 614 | { 615 | "cell_type": "code", 616 | "execution_count": null, 617 | "id": "23828570-1fcd-403b-a729-16d2b8c19ab1", 618 | "metadata": {}, 619 | "outputs": [], 620 | "source": [] 621 | } 622 | ], 623 | "metadata": { 624 | 
"kernelspec": { 625 | "display_name": "Python 3 (ipykernel)", 626 | "language": "python", 627 | "name": "python3.10" 628 | }, 629 | "language_info": { 630 | "codemirror_mode": { 631 | "name": "ipython", 632 | "version": 3 633 | }, 634 | "file_extension": ".py", 635 | "mimetype": "text/x-python", 636 | "name": "python", 637 | "nbconvert_exporter": "python", 638 | "pygments_lexer": "ipython3", 639 | "version": "3.9.10" 640 | } 641 | }, 642 | "nbformat": 4, 643 | "nbformat_minor": 5 644 | } 645 | -------------------------------------------------------------------------------- /notebooks/nlp/README.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /notebooks/nlp/bert-mlm-example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "41f42f46-af1e-40ec-a736-e0a6ff814399", 6 | "metadata": {}, 7 | "source": [ 8 | "# 引入依赖" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "8032c1ee-6786-4679-9660-31e3a080a39f", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import pandas as pd\n", 19 | "\n", 20 | "from transformers import AutoTokenizer\n", 21 | "\n", 22 | "import torch\n", 23 | "from torch.utils.data import DataLoader, Dataset\n", 24 | "\n", 25 | "from transformers import AutoConfig, AutoModelForMaskedLM\n", 26 | "from transformers import Trainer, TrainingArguments\n", 27 | "from transformers import DataCollatorForLanguageModeling" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "id": "a7d72759-934e-4157-b0eb-ebc72ec06cc6", 33 | "metadata": {}, 34 | "source": [ 35 | "# 模拟训练部分数据" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 2, 41 | "id": "44ef2308-315d-479c-a5dc-c95d1a7868e0", 42 | "metadata": {}, 43 | "outputs": [ 44 | { 45 | "data": { 46 | "text/html": [ 47 | "
\n", 48 | "\n", 61 | "\n", 62 | " \n", 63 | " \n", 64 | " \n", 65 | " \n", 66 | " \n", 67 | " \n", 68 | " \n", 69 | " \n", 70 | " \n", 71 | " \n", 72 | " \n", 73 | " \n", 74 | " \n", 75 | " \n", 76 | " \n", 77 | " \n", 78 | " \n", 79 | " \n", 80 | " \n", 81 | " \n", 82 | " \n", 83 | " \n", 84 | " \n", 85 | " \n", 86 | " \n", 87 | " \n", 88 | " \n", 89 | " \n", 90 | " \n", 91 | " \n", 92 | " \n", 93 | " \n", 94 | " \n", 95 | " \n", 96 | " \n", 97 | " \n", 98 | " \n", 99 | " \n", 100 | " \n", 101 | " \n", 102 | " \n", 103 | " \n", 104 | " \n", 105 | " \n", 106 | " \n", 107 | " \n", 108 | " \n", 109 | " \n", 110 | " \n", 111 | " \n", 112 | " \n", 113 | " \n", 114 | " \n", 115 | " \n", 116 | " \n", 117 | " \n", 118 | " \n", 119 | " \n", 120 | " \n", 121 | " \n", 122 | " \n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | "
text
92021华为matepadpro保护套12.6带笔槽10.8寸matepad11平板电脑壳V...
11食掌门 绞肉机800瓦大功率小型肉馅机电动灌肠机
5WiFi收钱提示音响二维码语音播报器收款付款到账蓝牙小音箱色 炫酷黑 【蓝牙版】超长待机45天
1同福茂 适用华为畅享20Pro摄像头玻璃镜片畅享Z前后置大小照相机头 【】畅享20Z后主摄像头
4峰爵苹果13promax手机壳平安喜乐新款iPhone13虎年创意女款13pro全包本命年小...
8华为同款佳mp3mp4音乐播放器有屏学生随身听苹果风P5外放OTG可爱迷你 红 【带外放】可...
17源由 游戏键盘鼠标吃鸡oppoA59/A77/A83/A73/79/R15梦境手机壳保护皮套...
7QYHJD 适用荣耀9i背夹充电宝华为8/9/10青春版背夹电池nova青春版电源薄 荣耀1...
12ABS A4机柜文件夹盒PS柜配电箱威图柜WJ-1PS资料袋AE箱WJ-2文件夹 橙色带胶2...
19抖音快手热门歌曲音乐播放器迷你运动随身听mp4学生MP3 小怪兽 外响版+带卡含200首抖音...
2小米(MI)通用mp3MP4播放器迷你qq音乐学生运动跑步随身听电子书录音笔外放苹果 蓝色(...
6常用安全警示全套标示牌安全标识牌车间施工生产警告标志牌提示贴标语严禁烟火禁止吸烟有电危险标牌...
15机旺马华为畅享20plus手机壳磨砂超薄硅胶全包防摔欧美潮牌个性创意卡通图案男士款女网红软壳...
14【官方原装】荧阙 适用2021款iPad保护套10.2英英寸平板9代8电脑Air4/3带笔槽...
20胜枫 华为p50pro手机壳车载磁吸款p50镜头全包防摔保护套典藏版磨砂防指纹硬壳纯色极简保...
10韵果【小学初中高中生学习专用】蓝牙mp3英语听力随身听学生版小型便携式mp4mp5音乐播放器...
0苹果13/12/11/promax钢化膜13pro/12pro防窥iphone/x/xr/6...
16OPPOReno5手机壳男金属镜头全包reno5k磨砂OPPOreno5手机套5G誉科创SN...
18ITDK充电器插头超级快充数据充电线 适用于 【Max40W超级快充】充电器+1.5米5A数...
13鹿纯红米k40手机壳k40pro新款液态硅胶k40pro+全包防摔潮男女款软壳k40游戏增强...
3官方旗舰华强北智能watch6太空人s6代apple手表se顶配男士iwatch运动型iph...
\n", 155 | "
" 156 | ], 157 | "text/plain": [ 158 | " text\n", 159 | "9 2021华为matepadpro保护套12.6带笔槽10.8寸matepad11平板电脑壳V...\n", 160 | "11 食掌门 绞肉机800瓦大功率小型肉馅机电动灌肠机\n", 161 | "5 WiFi收钱提示音响二维码语音播报器收款付款到账蓝牙小音箱色 炫酷黑 【蓝牙版】超长待机45天\n", 162 | "1 同福茂 适用华为畅享20Pro摄像头玻璃镜片畅享Z前后置大小照相机头 【】畅享20Z后主摄像头\n", 163 | "4 峰爵苹果13promax手机壳平安喜乐新款iPhone13虎年创意女款13pro全包本命年小...\n", 164 | "8 华为同款佳mp3mp4音乐播放器有屏学生随身听苹果风P5外放OTG可爱迷你 红 【带外放】可...\n", 165 | "17 源由 游戏键盘鼠标吃鸡oppoA59/A77/A83/A73/79/R15梦境手机壳保护皮套...\n", 166 | "7 QYHJD 适用荣耀9i背夹充电宝华为8/9/10青春版背夹电池nova青春版电源薄 荣耀1...\n", 167 | "12 ABS A4机柜文件夹盒PS柜配电箱威图柜WJ-1PS资料袋AE箱WJ-2文件夹 橙色带胶2...\n", 168 | "19 抖音快手热门歌曲音乐播放器迷你运动随身听mp4学生MP3 小怪兽 外响版+带卡含200首抖音...\n", 169 | "2 小米(MI)通用mp3MP4播放器迷你qq音乐学生运动跑步随身听电子书录音笔外放苹果 蓝色(...\n", 170 | "6 常用安全警示全套标示牌安全标识牌车间施工生产警告标志牌提示贴标语严禁烟火禁止吸烟有电危险标牌...\n", 171 | "15 机旺马华为畅享20plus手机壳磨砂超薄硅胶全包防摔欧美潮牌个性创意卡通图案男士款女网红软壳...\n", 172 | "14 【官方原装】荧阙 适用2021款iPad保护套10.2英英寸平板9代8电脑Air4/3带笔槽...\n", 173 | "20 胜枫 华为p50pro手机壳车载磁吸款p50镜头全包防摔保护套典藏版磨砂防指纹硬壳纯色极简保...\n", 174 | "10 韵果【小学初中高中生学习专用】蓝牙mp3英语听力随身听学生版小型便携式mp4mp5音乐播放器...\n", 175 | "0 苹果13/12/11/promax钢化膜13pro/12pro防窥iphone/x/xr/6...\n", 176 | "16 OPPOReno5手机壳男金属镜头全包reno5k磨砂OPPOreno5手机套5G誉科创SN...\n", 177 | "18 ITDK充电器插头超级快充数据充电线 适用于 【Max40W超级快充】充电器+1.5米5A数...\n", 178 | "13 鹿纯红米k40手机壳k40pro新款液态硅胶k40pro+全包防摔潮男女款软壳k40游戏增强...\n", 179 | "3 官方旗舰华强北智能watch6太空人s6代apple手表se顶配男士iwatch运动型iph..." 180 | ] 181 | }, 182 | "execution_count": 2, 183 | "metadata": {}, 184 | "output_type": "execute_result" 185 | } 186 | ], 187 | "source": [ 188 | "data_list = [['苹果13/12/11/promax钢化膜13pro/12pro防窥iphone/x/xr/6/7皇迎 【黑边钻石防爆-高清透明膜】2片. 
iPhone X'],\n", 189 | " ['同福茂 适用华为畅享20Pro摄像头玻璃镜片畅享Z前后置大小照相机头 【】畅享20Z后主摄像头'],\n", 190 | " ['小米(MI)通用mp3MP4播放器迷你qq音乐学生运动跑步随身听电子书录音笔外放苹果 蓝色(带外放+可插内存卡) 8GB 王者套餐'],\n", 191 | " ['官方旗舰华强北智能watch6太空人s6代apple手表se顶配男士iwatch运动型iphone手 44MM/蓝高配版【太空人/长续航/双按键/高清 套餐二“送回环表带“'],\n", 192 | " ['峰爵苹果13promax手机壳平安喜乐新款iPhone13虎年创意女款13pro全包本命年小羊皮防摔 苹果13-红色-日富一日老虎'],\n", 193 | " ['WiFi收钱提示音响二维码语音播报器收款付款到账蓝牙小音箱色 炫酷黑 【蓝牙版】超长待机45天'],\n", 194 | " ['常用安全警示全套标示牌安全标识牌车间施工生产警告标志牌提示贴标语严禁烟火禁止吸烟有电危险标牌定制 禁止攀爬 20x30cm'],\n", 195 | " ['QYHJD 适用荣耀9i背夹充电宝华为8/9/10青春版背夹电池nova青春版电源薄 荣耀10青春版宝石蓝两万毫安 送耳机+数据线'],\n", 196 | " ['华为同款佳mp3mp4音乐播放器有屏学生随身听苹果风P5外放OTG可爱迷你 红 【带外放】可插卡 64GB HiFi耳机套餐【鎹彩膜贴纸】'],\n", 197 | " ['2021华为matepadpro保护套12.6带笔槽10.8寸matepad11平板电脑壳V7硅胶p 【炭灰色+钢化膜-笔槽三折款】 华为平板 C5(10.4英寸)'],\n", 198 | " ['韵果【小学初中高中生学习专用】蓝牙mp3英语听力随身听学生版小型便携式mp4mp5音乐播放器听读 2英寸 触屏【+品质耳机+蓝牙功能】 16G蓝牙版(蓝牙功能+帮下载资料+发声词典)'],\n", 199 | " ['食掌门 绞肉机800瓦大功率小型肉馅机电动灌肠机'],\n", 200 | " ['ABS A4机柜文件夹盒PS柜配电箱威图柜WJ-1PS资料袋AE箱WJ-2文件夹 橙色带胶200个起'],\n", 201 | " ['鹿纯红米k40手机壳k40pro新款液态硅胶k40pro+全包防摔潮男女款软壳k40游戏增强版直边简 红米K40游戏增强版【海军蓝】*+钢化膜'],\n", 202 | " ['【官方原装】荧阙 适用2021款iPad保护套10.2英英寸平板9代8电脑Air4/3带笔槽ipad 【笔槽书本款】 梦想成真 付费软件APP iPad mini6(8.3英寸)'],\n", 203 | " ['机旺马华为畅享20plus手机壳磨砂超薄硅胶全包防摔欧美潮牌个性创意卡通图案男士款女网红软壳保护套 巧克联名 畅享20plus【加网红挂绳】+全屏膜'],\n", 204 | " ['OPPOReno5手机壳男金属镜头全包reno5k磨砂OPPOreno5手机套5G誉科创SN84 黑红色【单壳】刀锋 oppo Reno5'],\n", 205 | " ['源由 游戏键盘鼠标吃鸡oppoA59/A77/A83/A73/79/R15梦境手机壳保护皮套 升级版(黑色)键盘皮套+鼠标'],\n", 206 | " ['ITDK充电器插头超级快充数据充电线 适用于 【Max40W超级快充】充电器+1.5米5A数据线 华为Mate20 Pro/Mate20X 5g4g'],\n", 207 | " ['抖音快手热门歌曲音乐播放器迷你运动随身听mp4学生MP3 小怪兽 外响版+带卡含200首抖音+四件套'],\n", 208 | " ['胜枫 华为p50pro手机壳车载磁吸款p50镜头全包防摔保护套典藏版磨砂防指纹硬壳纯色极简保护套 P50【岩砂黑+车载支架】磁吸套装']]\n", 209 | "\n", 210 | " \n", 211 | "data_df = pd.DataFrame(data_list, columns=['text'])\n", 212 | "data_df = data_df.sample(frac=1)\n", 213 | "data_df" 214 | ] 215 | }, 216 | { 217 | "cell_type": "markdown", 218 | "id": "ee2629e9-5683-408a-bc47-4c910b7e7611", 219 | "metadata": {}, 220 | "source": [ 221 | "# 
定义预训练模型&创建dataloader" 222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": 3, 227 | "id": "cc804bdf-11c5-4d62-b4f8-b43871c59115", 228 | "metadata": {}, 229 | "outputs": [], 230 | "source": [ 231 | "'''\n", 232 | " 定义BERT模型\n", 233 | "'''\n", 234 | "# model_checkpoint = 'chinese-roberta-wwm-ext'\n", 235 | "# tokenizer_checkpoint = 'chinese-roberta-wwm-ext'\n", 236 | "'''\n", 237 | " 本地引入模型\n", 238 | "'''\n", 239 | "model_checkpoint = 'D:/env/bert_model/hfl/chinese-roberta-wwm-ext'\n", 240 | "tokenizer_checkpoint = 'D:/env/bert_model/hfl/chinese-roberta-wwm-ext'\n", 241 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'" 242 | ] 243 | }, 244 | { 245 | "cell_type": "code", 246 | "execution_count": 4, 247 | "id": "3ccbef85-aae9-44d6-be5e-707db6162007", 248 | "metadata": { 249 | "tags": [] 250 | }, 251 | "outputs": [], 252 | "source": [ 253 | "# 简单自定义dataset\n", 254 | "class PreTrainDataset(Dataset):\n", 255 | " def __init__(self, data_list, tokenizer, max_seq_len):\n", 256 | " super(PreTrainDataset, self).__init__()\n", 257 | " self.data_list = data_list\n", 258 | " self.len = len(data_list)\n", 259 | " self.tokenizer = tokenizer\n", 260 | " self.max_seq_len = max_seq_len\n", 261 | "\n", 262 | " def __getitem__(self, index):\n", 263 | " example = self.data_list[index]\n", 264 | " data = self.tokenizer.encode_plus(example, return_token_type_ids=True, return_attention_mask=True, padding = 'max_length', max_length=self.max_seq_len)\n", 265 | " return {'input_ids': torch.tensor(data['input_ids'][:self.max_seq_len], dtype=torch.long), \n", 266 | " 'token_type_ids': torch.tensor(data['token_type_ids'][:self.max_seq_len], dtype=torch.long), \n", 267 | " 'attention_mask': torch.tensor(data['attention_mask'][:self.max_seq_len], dtype=torch.long)}\n", 268 | " \n", 269 | " def __len__(self):\n", 270 | " return self.len\n", 271 | "\n", 272 | "# 生成dataset\n", 273 | "tokenizer = AutoTokenizer.from_pretrained(tokenizer_checkpoint)\n", 274 | "test_size 
= int(len(data_df)*0.8)\n", 275 | "train_dataset = PreTrainDataset(data_df[:test_size]['text'].tolist(), tokenizer, 128)\n", 276 | "test_dataset = PreTrainDataset(data_df[test_size:]['text'].tolist(), tokenizer, 128)\n", 277 | "# 生成dataloader\n", 278 | "train_dataloader = DataLoader(train_dataset, batch_size=8, num_workers=0)\n", 279 | "test_dataloader = DataLoader(test_dataset, batch_size=8, num_workers=0)\n" 280 | ] 281 | }, 282 | { 283 | "cell_type": "markdown", 284 | "id": "a7335e8a-9103-43b1-a04a-fe3e2d45cd18", 285 | "metadata": {}, 286 | "source": [ 287 | "# 加载定义模型&定义trainer" 288 | ] 289 | }, 290 | { 291 | "cell_type": "code", 292 | "execution_count": 5, 293 | "id": "f20be05f-f05e-44e7-8a78-aa0c6ee1377d", 294 | "metadata": {}, 295 | "outputs": [], 296 | "source": [ 297 | "\n", 298 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", 299 | "# 加载模型\n", 300 | "config = AutoConfig.from_pretrained(model_checkpoint)\n", 301 | "model = AutoModelForMaskedLM.from_config(config)\n", 302 | "\n", 303 | "model.to(device)\n", 304 | "# 定义trainer\n", 305 | "training_args = TrainingArguments(\n", 306 | " output_dir='pretrain_bert',\n", 307 | " overwrite_output_dir=True,\n", 308 | " do_train=True, \n", 309 | " do_eval=True,\n", 310 | " per_device_train_batch_size=32,\n", 311 | " per_device_eval_batch_size=100,\n", 312 | " evaluation_strategy='steps',\n", 313 | " logging_steps=10000,\n", 314 | " eval_steps = None,\n", 315 | " prediction_loss_only=True,\n", 316 | " learning_rate = 1e-5,\n", 317 | " weight_decay=0.01,\n", 318 | " adam_epsilon = 1e-8,\n", 319 | " max_grad_norm = 1.0,\n", 320 | " num_train_epochs = 10,\n", 321 | " save_steps = -1,\n", 322 | " push_to_hub=False\n", 323 | ")\n", 324 | "\n", 325 | "data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)\n", 326 | "trainer = Trainer(\n", 327 | " model=model,\n", 328 | " args=training_args,\n", 329 | " train_dataset=train_dataset,\n", 330 | " eval_dataset=test_dataset,\n", 
331 | " data_collator=data_collator,\n", 332 | ")" 333 | ] 334 | }, 335 | { 336 | "cell_type": "markdown", 337 | "id": "6a856b4d-c988-45e4-bd6d-90452bb997fa", 338 | "metadata": {}, 339 | "source": [ 340 | "# 训练" 341 | ] 342 | }, 343 | { 344 | "cell_type": "code", 345 | "execution_count": null, 346 | "id": "ea3ab395-06fa-4a35-bd16-b579f4cc5681", 347 | "metadata": {}, 348 | "outputs": [], 349 | "source": [ 350 | "trainer.train()\n", 351 | "trainer.save_model()" 352 | ] 353 | } 354 | ], 355 | "metadata": { 356 | "kernelspec": { 357 | "display_name": "Python 3", 358 | "language": "python", 359 | "name": "python3" 360 | }, 361 | "language_info": { 362 | "codemirror_mode": { 363 | "name": "ipython", 364 | "version": 3 365 | }, 366 | "file_extension": ".py", 367 | "mimetype": "text/x-python", 368 | "name": "python", 369 | "nbconvert_exporter": "python", 370 | "pygments_lexer": "ipython3", 371 | "version": "3.8.8" 372 | } 373 | }, 374 | "nbformat": 4, 375 | "nbformat_minor": 5 376 | } 377 | -------------------------------------------------------------------------------- /notebooks/nlp/transformer基础.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 7, 6 | "metadata": { 7 | "execution": { 8 | "iopub.execute_input": "2021-07-10T16:28:14.869549Z", 9 | "iopub.status.busy": "2021-07-10T16:28:14.868977Z", 10 | "iopub.status.idle": "2021-07-10T16:28:24.380797Z", 11 | "shell.execute_reply": "2021-07-10T16:28:24.379799Z", 12 | "shell.execute_reply.started": "2021-07-10T16:28:14.869500Z" 13 | }, 14 | "tags": [] 15 | }, 16 | "outputs": [ 17 | { 18 | "name": "stderr", 19 | "output_type": "stream", 20 | "text": [ 21 | "Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 
'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']\n", 22 | "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", 23 | "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", 24 | "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-chinese and are newly initialized: ['classifier.weight', 'classifier.bias']\n", 25 | "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" 26 | ] 27 | } 28 | ], 29 | "source": [ 30 | "from transformers import AutoTokenizer, AutoModelForSequenceClassification\n", 31 | "\n", 32 | "# AutoTokenizer:输入文本的进行分词,并进行转换编码的操作\n", 33 | "# AutoModelForSequenceClassification:加载置顶的模型,并将输出[CLS]加上全连接进行文本分类\n", 34 | "\n", 35 | "model_name = \"bert-base-chinese\" # 可以自己修改\n", 36 | "\n", 37 | "pt_model = AutoModelForSequenceClassification.from_pretrained(model_name)\n", 38 | "tokenizer = AutoTokenizer.from_pretrained(model_name)" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 9, 44 | "metadata": { 45 | "execution": { 46 | "iopub.execute_input": "2021-07-10T16:28:43.942177Z", 47 | "iopub.status.busy": "2021-07-10T16:28:43.941642Z", 48 | "iopub.status.idle": "2021-07-10T16:28:43.947178Z", 49 | "shell.execute_reply": "2021-07-10T16:28:43.946568Z", 50 | "shell.execute_reply.started": "2021-07-10T16:28:43.942131Z" 51 | }, 52 | "tags": [] 53 | }, 54 | "outputs": [ 55 | { 56 | "data": { 57 | "text/plain": [ 58 | "{'input_ids': [101, 791, 1921, 
1921, 3698, 2523, 1962, 8024, 852, 3209, 1921, 833, 678, 7433, 511, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}" 59 | ] 60 | }, 61 | "execution_count": 9, 62 | "metadata": {}, 63 | "output_type": "execute_result" 64 | } 65 | ], 66 | "source": [ 67 | "inputs = tokenizer(\"me\")\n", 68 | "inputs\n", 69 | "# input_ids:\n", 70 | "# token_type_ids:\n", 71 | "# attention_mask:" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 12, 77 | "metadata": { 78 | "execution": { 79 | "iopub.execute_input": "2021-07-10T16:29:19.811054Z", 80 | "iopub.status.busy": "2021-07-10T16:29:19.810515Z", 81 | "iopub.status.idle": "2021-07-10T16:29:19.817458Z", 82 | "shell.execute_reply": "2021-07-10T16:29:19.816364Z", 83 | "shell.execute_reply.started": "2021-07-10T16:29:19.811009Z" 84 | }, 85 | "tags": [] 86 | }, 87 | "outputs": [], 88 | "source": [ 89 | "pt_batch = tokenizer(\n", 90 | " [\"hello\", \"now\"],\n", 91 | " padding=True,\n", 92 | " truncation=True,\n", 93 | " max_length=512,\n", 94 | " return_tensors=\"pt\"\n", 95 | ")" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": 13, 101 | "metadata": { 102 | "execution": { 103 | "iopub.execute_input": "2021-07-10T16:29:20.209047Z", 104 | "iopub.status.busy": "2021-07-10T16:29:20.208536Z", 105 | "iopub.status.idle": "2021-07-10T16:29:20.215541Z", 106 | "shell.execute_reply": "2021-07-10T16:29:20.214458Z", 107 | "shell.execute_reply.started": "2021-07-10T16:29:20.209001Z" 108 | }, 109 | "tags": [] 110 | }, 111 | "outputs": [ 112 | { 113 | "name": "stdout", 114 | "output_type": "stream", 115 | "text": [ 116 | "input_ids: [[101, 791, 1921, 678, 7433, 102, 0, 0, 0, 0], [101, 3209, 1921, 678, 7433, 1400, 1921, 678, 7433, 102]]\n", 117 | "token_type_ids: [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]\n", 118 | "attention_mask: [[1, 1, 1, 1, 1, 1, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 
1]]\n" 119 | ] 120 | } 121 | ], 122 | "source": [ 123 | "for key, value in pt_batch.items():\n", 124 | " print(f\"{key}: {value.numpy().tolist()}\")" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": 15, 130 | "metadata": { 131 | "execution": { 132 | "iopub.execute_input": "2021-07-10T16:42:33.618263Z", 133 | "iopub.status.busy": "2021-07-10T16:42:33.617692Z", 134 | "iopub.status.idle": "2021-07-10T16:42:38.171868Z", 135 | "shell.execute_reply": "2021-07-10T16:42:38.170897Z", 136 | "shell.execute_reply.started": "2021-07-10T16:42:33.618215Z" 137 | }, 138 | "tags": [] 139 | }, 140 | "outputs": [], 141 | "source": [ 142 | "from transformers import BertTokenizer\n", 143 | "tokenizer = BertTokenizer.from_pretrained(\"bert-base-cased\")\n", 144 | "\n", 145 | "sequence = \"A Titan RTX has 24GB of VRAM\"\n", 146 | "tokenized_sequence = tokenizer.tokenize(sequence)" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": 16, 152 | "metadata": { 153 | "execution": { 154 | "iopub.execute_input": "2021-07-10T16:42:38.172981Z", 155 | "iopub.status.busy": "2021-07-10T16:42:38.172786Z", 156 | "iopub.status.idle": "2021-07-10T16:42:38.176551Z", 157 | "shell.execute_reply": "2021-07-10T16:42:38.176026Z", 158 | "shell.execute_reply.started": "2021-07-10T16:42:38.172964Z" 159 | } 160 | }, 161 | "outputs": [ 162 | { 163 | "data": { 164 | "text/plain": [ 165 | "['A',\n", 166 | " 'Titan',\n", 167 | " 'R',\n", 168 | " '##T',\n", 169 | " '##X',\n", 170 | " 'has',\n", 171 | " '24',\n", 172 | " '##GB',\n", 173 | " 'of',\n", 174 | " 'V',\n", 175 | " '##RA',\n", 176 | " '##M']" 177 | ] 178 | }, 179 | "execution_count": 16, 180 | "metadata": {}, 181 | "output_type": "execute_result" 182 | } 183 | ], 184 | "source": [ 185 | "tokenized_sequence" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": 17, 191 | "metadata": { 192 | "execution": { 193 | "iopub.execute_input": "2021-07-10T16:42:43.226055Z", 194 | 
"iopub.status.busy": "2021-07-10T16:42:43.225524Z", 195 | "iopub.status.idle": "2021-07-10T16:42:43.231297Z", 196 | "shell.execute_reply": "2021-07-10T16:42:43.230292Z", 197 | "shell.execute_reply.started": "2021-07-10T16:42:43.226009Z" 198 | } 199 | }, 200 | "outputs": [], 201 | "source": [ 202 | "inputs = tokenizer(sequence)" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": 18, 208 | "metadata": { 209 | "execution": { 210 | "iopub.execute_input": "2021-07-10T16:42:50.235952Z", 211 | "iopub.status.busy": "2021-07-10T16:42:50.235433Z", 212 | "iopub.status.idle": "2021-07-10T16:42:50.241011Z", 213 | "shell.execute_reply": "2021-07-10T16:42:50.240507Z", 214 | "shell.execute_reply.started": "2021-07-10T16:42:50.235905Z" 215 | } 216 | }, 217 | "outputs": [ 218 | { 219 | "data": { 220 | "text/plain": [ 221 | "{'input_ids': [101, 138, 18696, 155, 1942, 3190, 1144, 1572, 13745, 1104, 159, 9664, 2107, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}" 222 | ] 223 | }, 224 | "execution_count": 18, 225 | "metadata": {}, 226 | "output_type": "execute_result" 227 | } 228 | ], 229 | "source": [ 230 | "inputs" 231 | ] 232 | }, 233 | { 234 | "cell_type": "code", 235 | "execution_count": 19, 236 | "metadata": { 237 | "execution": { 238 | "iopub.execute_input": "2021-07-10T16:43:10.712399Z", 239 | "iopub.status.busy": "2021-07-10T16:43:10.711834Z", 240 | "iopub.status.idle": "2021-07-10T16:43:16.185772Z", 241 | "shell.execute_reply": "2021-07-10T16:43:16.185210Z", 242 | "shell.execute_reply.started": "2021-07-10T16:43:10.712352Z" 243 | } 244 | }, 245 | "outputs": [], 246 | "source": [ 247 | "from transformers import BertTokenizer\n", 248 | "tokenizer = BertTokenizer.from_pretrained(\"bert-base-cased\")\n", 249 | "\n", 250 | "sequence_a = \"This is a short sequence.\"\n", 251 | "sequence_b = \"This is a rather long sequence. 
It is at least longer than the sequence A.\"\n", 252 | "\n", 253 | "encoded_sequence_a = tokenizer(sequence_a)[\"input_ids\"]\n", 254 | "encoded_sequence_b = tokenizer(sequence_b)[\"input_ids\"]" 255 | ] 256 | }, 257 | { 258 | "cell_type": "code", 259 | "execution_count": 20, 260 | "metadata": { 261 | "execution": { 262 | "iopub.execute_input": "2021-07-10T16:43:16.186812Z", 263 | "iopub.status.busy": "2021-07-10T16:43:16.186621Z", 264 | "iopub.status.idle": "2021-07-10T16:43:16.189833Z", 265 | "shell.execute_reply": "2021-07-10T16:43:16.189424Z", 266 | "shell.execute_reply.started": "2021-07-10T16:43:16.186795Z" 267 | } 268 | }, 269 | "outputs": [ 270 | { 271 | "data": { 272 | "text/plain": [ 273 | "[101, 1188, 1110, 170, 1603, 4954, 119, 102]" 274 | ] 275 | }, 276 | "execution_count": 20, 277 | "metadata": {}, 278 | "output_type": "execute_result" 279 | } 280 | ], 281 | "source": [ 282 | "encoded_sequence_a" 283 | ] 284 | }, 285 | { 286 | "cell_type": "code", 287 | "execution_count": 4, 288 | "metadata": { 289 | "execution": { 290 | "iopub.execute_input": "2021-07-11T18:31:50.882727Z", 291 | "iopub.status.busy": "2021-07-11T18:31:50.882281Z", 292 | "iopub.status.idle": "2021-07-11T18:31:55.088013Z", 293 | "shell.execute_reply": "2021-07-11T18:31:55.086956Z", 294 | "shell.execute_reply.started": "2021-07-11T18:31:50.882693Z" 295 | }, 296 | "tags": [] 297 | }, 298 | "outputs": [], 299 | "source": [ 300 | "from transformers import BertTokenizer\n", 301 | "tokenizer = BertTokenizer.from_pretrained(\"bert-base-cased\")\n", 302 | "sequence_a = \"HuggingFace is based in NYC. 
So my\"\n", 303 | "sequence_b = \"Where is HuggingFace based?\"\n", 304 | "\n", 305 | "encoded_dict = tokenizer(sequence_a)\n", 306 | "decoded = tokenizer.decode(encoded_dict[\"input_ids\"])" 307 | ] 308 | }, 309 | { 310 | "cell_type": "code", 311 | "execution_count": 5, 312 | "metadata": { 313 | "execution": { 314 | "iopub.execute_input": "2021-07-11T18:31:58.636084Z", 315 | "iopub.status.busy": "2021-07-11T18:31:58.635532Z", 316 | "iopub.status.idle": "2021-07-11T18:31:58.641397Z", 317 | "shell.execute_reply": "2021-07-11T18:31:58.640814Z", 318 | "shell.execute_reply.started": "2021-07-11T18:31:58.636038Z" 319 | }, 320 | "tags": [] 321 | }, 322 | "outputs": [ 323 | { 324 | "data": { 325 | "text/plain": [ 326 | "'[CLS] HuggingFace is based in NYC. So my [SEP]'" 327 | ] 328 | }, 329 | "execution_count": 5, 330 | "metadata": {}, 331 | "output_type": "execute_result" 332 | } 333 | ], 334 | "source": [ 335 | "decoded" 336 | ] 337 | }, 338 | { 339 | "cell_type": "code", 340 | "execution_count": 6, 341 | "metadata": { 342 | "execution": { 343 | "iopub.execute_input": "2021-07-11T18:31:59.258210Z", 344 | "iopub.status.busy": "2021-07-11T18:31:59.257692Z", 345 | "iopub.status.idle": "2021-07-11T18:31:59.411923Z", 346 | "shell.execute_reply": "2021-07-11T18:31:59.410914Z", 347 | "shell.execute_reply.started": "2021-07-11T18:31:59.258165Z" 348 | }, 349 | "tags": [] 350 | }, 351 | "outputs": [ 352 | { 353 | "data": { 354 | "text/plain": [ 355 | "{'input_ids': [101, 20164, 10932, 2271, 7954, 1110, 1359, 1107, 17520, 119, 1573, 1139, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}" 356 | ] 357 | }, 358 | "execution_count": 6, 359 | "metadata": {}, 360 | "output_type": "execute_result" 361 | } 362 | ], 363 | "source": [ 364 | "encoded_dict" 365 | ] 366 | }, 367 | { 368 | "cell_type": "code", 369 | "execution_count": null, 370 | "metadata": {}, 371 | "outputs": [], 372 | "source": [] 373 | } 374 | ], 375 
| "metadata": { 376 | "kernelspec": { 377 | "display_name": "Python 3.6 (ipykernel)", 378 | "language": "python", 379 | "name": "python3" 380 | }, 381 | "language_info": { 382 | "codemirror_mode": { 383 | "name": "ipython", 384 | "version": 3 385 | }, 386 | "file_extension": ".py", 387 | "mimetype": "text/x-python", 388 | "name": "python", 389 | "nbconvert_exporter": "python", 390 | "pygments_lexer": "ipython3", 391 | "version": "3.6.9" 392 | }, 393 | "widgets": { 394 | "application/vnd.jupyter.widget-state+json": { 395 | "state": {}, 396 | "version_major": 2, 397 | "version_minor": 0 398 | } 399 | } 400 | }, 401 | "nbformat": 4, 402 | "nbformat_minor": 4 403 | } 404 | -------------------------------------------------------------------------------- /streamlit/chatgpt_demo.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | import os 3 | 4 | import time 5 | import jwt 6 | import requests 7 | 8 | def ask_gpt(key, model, max_tokens, temperature, content): 9 | url = "https://openai.api2d.net/v1/chat/completions" 10 | # 可以从 https://api2d.com 购买api,然后国内直接链接 11 | 12 | headers = { 13 | 'Content-Type': 'application/json', 14 | 'Authorization': 'Bearer ' + key 15 | } 16 | 17 | data = { 18 | "model": model, 19 | "max_tokens": max_tokens, 20 | "temperature": temperature, 21 | "messages": content 22 | } 23 | 24 | response = requests.post(url, headers=headers, json=data) 25 | return response.json() 26 | 27 | st.set_page_config(page_title="聊天机器人") 28 | 29 | with st.sidebar: 30 | st.title('通过API与大模型的对话') 31 | # st.write('支持的大模型包括的ChatGLM3和4') 32 | if 'API_TOKEN' in st.session_state and len(st.session_state['API_TOKEN']) > 1: 33 | st.success('API Token已经配置', icon='✅') 34 | key = st.session_state['API_TOKEN'] 35 | else: 36 | key = "" 37 | 38 | key = st.text_input('输入Token:', type='password', value=key) 39 | 40 | st.session_state['API_TOKEN'] = key 41 | 42 | model = st.selectbox("选择模型", ["gpt-3.5-turbo-0613", 
"gpt-4-0613"]) 43 | max_tokens = st.slider("max_tokens", 0, 2000, value=512) 44 | temperature = st.slider("temperature", 0.0, 2.0, value=0.8) 45 | 46 | # 初始化的对话 47 | if "messages" not in st.session_state.keys(): 48 | st.session_state.messages = [{"role": "assistant", "content": "你好我是ChatGPT,有什么可以帮助你的?"}] 49 | 50 | for message in st.session_state.messages: 51 | with st.chat_message(message["role"]): 52 | st.write(message["content"]) 53 | 54 | def clear_chat_history(): 55 | st.session_state.messages = [{"role": "assistant", "content": "你好我是ChatGPT,有什么可以帮助你的?"}] 56 | 57 | st.sidebar.button('清空聊天记录', on_click=clear_chat_history) 58 | 59 | 60 | if len(key) > 1: 61 | if prompt := st.chat_input(): 62 | st.session_state.messages.append({"role": "user", "content": prompt}) 63 | with st.chat_message("user"): 64 | st.markdown(prompt) 65 | 66 | with st.chat_message("assistant"): 67 | with st.spinner("请求中..."): 68 | full_response = ask_gpt(key, model, max_tokens, temperature, st.session_state.messages)['choices'][0]['message']['content'] 69 | st.markdown(full_response) 70 | 71 | message = {"role": "assistant", "content": full_response} 72 | st.session_state.messages.append(message) -------------------------------------------------------------------------------- /streamlit/glm_chat2pandas.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pickle import NONE 3 | import time 4 | from xml.etree.ElementInclude import include 5 | import jwt 6 | 7 | import streamlit as st 8 | import pandas as pd 9 | import requests 10 | 11 | import re 12 | def extract_code_blocks(markdown_content): 13 | code_blocks = re.findall(r'```(.*?)```', markdown_content, re.DOTALL) 14 | return code_blocks 15 | 16 | # 实际KEY,过期时间 17 | def generate_token(apikey: str, exp_seconds: int): 18 | try: 19 | id, secret = apikey.split(".") 20 | except Exception as e: 21 | raise Exception("invalid apikey", e) 22 | 23 | payload = { 24 | "api_key": id, 25 | "exp": 
int(round(time.time() * 1000)) + exp_seconds * 1000, 26 | "timestamp": int(round(time.time() * 1000)), 27 | } 28 | return jwt.encode( 29 | payload, 30 | secret, 31 | algorithm="HS256", 32 | headers={"alg": "HS256", "sign_type": "SIGN"}, 33 | ) 34 | 35 | def ask_glm(key, model, max_tokens, temperature, content): 36 | url = "https://open.bigmodel.cn/api/paas/v4/chat/completions" 37 | headers = { 38 | 'Content-Type': 'application/json', 39 | 'Authorization': generate_token(key, 1000) 40 | } 41 | 42 | data = { 43 | "model": model, 44 | "max_tokens": max_tokens, 45 | "temperature": temperature, 46 | "messages": content 47 | } 48 | 49 | response = requests.post(url, headers=headers, json=data) 50 | return response.json() 51 | 52 | st.set_page_config(page_title="聊天机器人") 53 | 54 | with st.sidebar: 55 | st.title('通过大模型进行数据分析') 56 | if 'API_TOKEN' in st.session_state and len(st.session_state['API_TOKEN']) > 1: 57 | st.success('API Token已经配置', icon='✅') 58 | key = st.session_state['API_TOKEN'] 59 | else: 60 | key = "7bf001734ef2fd7f7a55bf51dadd7cbb.BMAsoKRDFTmTEPwj" 61 | 62 | key = st.text_input('输入Token:', type='password', value=key) 63 | 64 | st.session_state['API_TOKEN'] = key 65 | 66 | model = st.selectbox("选择模型", ["glm-3-turbo", "glm-4"]) 67 | max_tokens = st.slider("max_tokens", 0, 2000, value=512) 68 | temperature = st.slider("temperature", 0.0, 2.0, value=0.8) 69 | 70 | if "messages" not in st.session_state.keys(): 71 | st.session_state.messages = [{"role": "assistant", "content": "你好我是ChatGLM,我可以自动的帮你进行数据分析和建模。"}] 72 | 73 | for message in st.session_state.messages: 74 | with st.chat_message(message["role"]): 75 | st.write(message["content"]) 76 | 77 | if "dataframe" not in st.session_state.keys(): 78 | uploaded_file = st.file_uploader("上传你需要分析的文件") 79 | if uploaded_file is not None: 80 | dataframe = pd.read_csv(uploaded_file) 81 | st.write(dataframe.head(10)) 82 | st.session_state["dataframe"] = dataframe 83 | else: 84 | dataframe = None 85 | else: 86 | dataframe = 
st.session_state["dataframe"] 87 | 88 | def clear_chat_history(): 89 | st.session_state.messages = [{"role": "assistant", "content": "你好我是ChatGLM,有什么可以帮助你的?"}] 90 | 91 | st.sidebar.button('清空聊天记录', on_click=clear_chat_history) 92 | 93 | print(len(st.session_state.messages)) 94 | 95 | if len(key) > 1 and len(st.session_state.messages) == 1 and dataframe is not None: 96 | data_info = f'''在python代码总,表格变量命名为df,数据集维度为:{dataframe.shape},列名为:{dataframe.columns} 97 | 98 | 字段类型为: 99 | {dataframe.dtypes.to_markdown()} 100 | ''' 101 | with st.chat_message("user"): 102 | st.markdown(data_info) 103 | st.session_state.messages.append({"role": "user", "content": data_info}) 104 | if len(key) > 1 and dataframe is not None: 105 | if prompt := st.chat_input(): 106 | st.session_state.messages.append({"role": "user", "content": prompt}) 107 | with st.chat_message("user"): 108 | st.markdown(prompt) 109 | 110 | with st.chat_message("assistant"): 111 | with st.spinner("请求中..."): 112 | full_response = ask_glm(key, model, max_tokens, temperature, st.session_state.messages)['choices'][0]['message']['content'] 113 | full_code = extract_code_blocks(full_response) 114 | full_code = [x.replace("python\n", "") for x in full_code if 'python' in x] 115 | for x in full_code: 116 | st.code(x) 117 | st.session_state.messages = st.session_state.messages[:-1] -------------------------------------------------------------------------------- /streamlit/glm_demo.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | import os 3 | 4 | import time 5 | import jwt 6 | import requests 7 | 8 | # 实际KEY,过期时间 9 | def generate_token(apikey: str, exp_seconds: int): 10 | try: 11 | id, secret = apikey.split(".") 12 | except Exception as e: 13 | raise Exception("invalid apikey", e) 14 | 15 | payload = { 16 | "api_key": id, 17 | "exp": int(round(time.time() * 1000)) + exp_seconds * 1000, 18 | "timestamp": int(round(time.time() * 1000)), 19 | } 20 | return 
jwt.encode( 21 | payload, 22 | secret, 23 | algorithm="HS256", 24 | headers={"alg": "HS256", "sign_type": "SIGN"}, 25 | ) 26 | 27 | def ask_glm(key, model, max_tokens, temperature, content): 28 | url = "https://open.bigmodel.cn/api/paas/v4/chat/completions" 29 | headers = { 30 | 'Content-Type': 'application/json', 31 | 'Authorization': generate_token(key, 1000) 32 | } 33 | 34 | data = { 35 | "model": model, 36 | "max_tokens": max_tokens, 37 | "temperature": temperature, 38 | "messages": content 39 | } 40 | 41 | response = requests.post(url, headers=headers, json=data) 42 | return response.json() 43 | 44 | st.set_page_config(page_title="聊天机器人") 45 | 46 | with st.sidebar: 47 | st.title('通过API与大模型的对话') 48 | # st.write('支持的大模型包括的ChatGLM3和4') 49 | if 'API_TOKEN' in st.session_state and len(st.session_state['API_TOKEN']) > 1: 50 | st.success('API Token已经配置', icon='✅') 51 | key = st.session_state['API_TOKEN'] 52 | else: 53 | key = "" 54 | 55 | key = st.text_input('输入Token:', type='password', value=key) 56 | 57 | st.session_state['API_TOKEN'] = key 58 | 59 | model = st.selectbox("选择模型", ["glm-3-turbo", "glm-4"]) 60 | max_tokens = st.slider("max_tokens", 0, 2000, value=512) 61 | temperature = st.slider("temperature", 0.0, 2.0, value=0.8) 62 | 63 | # 初始化的对话 64 | if "messages" not in st.session_state.keys(): 65 | st.session_state.messages = [{"role": "assistant", "content": "你好我是ChatGLM,有什么可以帮助你的?"}] 66 | 67 | for message in st.session_state.messages: 68 | with st.chat_message(message["role"]): 69 | st.write(message["content"]) 70 | 71 | def clear_chat_history(): 72 | st.session_state.messages = [{"role": "assistant", "content": "你好我是ChatGLM,有什么可以帮助你的?"}] 73 | 74 | st.sidebar.button('清空聊天记录', on_click=clear_chat_history) 75 | 76 | 77 | if len(key) > 1: 78 | if prompt := st.chat_input(): 79 | st.session_state.messages.append({"role": "user", "content": prompt}) 80 | with st.chat_message("user"): 81 | st.markdown(prompt) 82 | 83 | with st.chat_message("assistant"): 84 | with 
st.spinner("请求中..."): 85 | full_response = ask_glm(key, model, max_tokens, temperature, st.session_state.messages)['choices'][0]['message']['content'] 86 | st.markdown(full_response) 87 | 88 | message = {"role": "assistant", "content": full_response} 89 | st.session_state.messages.append(message) --------------------------------------------------------------------------------