├── README.md
└── 国科大-深度学习作业
    ├── LSTM自动写诗
    │   ├── classmodel.py
    │   ├── function.py
    │   ├── lstm自动写诗.pptx
    │   └── main.py
    ├── VIT实现CAFIR10分类
    │   ├── VIT实现CIFAR10分类.docx
    │   ├── VIT实现CIFAR10分类.ipynb
    │   ├── VIT实现CIFAR10分类.pptx
    │   └── 助教代码魔改.ipynb
    ├── 手写数字识别
    │   ├── 手写数字识别.docx
    │   ├── 手写数字识别.ipynb
    │   └── 手写数字识别.pptx
    └── 机器翻译
        ├── bloom5-1.4b-半精度lora微调-机器翻译.ipynb
        ├── bloom5-2.5b-半精度lora微调-机器翻译.ipynb
        ├── bloom5-6.4b-4bit Qlora微调-机器翻译.ipynb
        ├── 大模型微调数据集
        │   ├── test.csv
        │   └── train.csv
        ├── 手搓transformer数据集
        │   ├── src.txt
        │   └── tgt.txt
        └── 机器翻译.ipynb

/README.md:
--------------------------------------------------------------------------------
## UCAS Deep Learning Homework, Spring 2024
The author is not very experienced; this repo is shared for exchange and learning only. The PPT and Word reports are rough, but the code carries fairly detailed comments.
## Notes
#### LSTM: train it longer than I did; the results uploaded to GitHub come from a short training run, so quality is mediocre.
#### ViT: my machine cannot run large networks, so the hand-written ViT here is only a simplified demo. The modified TA version I also uploaded is simpler and is the better reference.
#### Machine translation: I provide a train-from-scratch version (currently BLEU-4 of 17, on a dataset I designed carefully) plus three fine-tuned models whose BLEU scores all exceed 14, the best reaching 19. If you need the models for direct evaluation, contact me on WeChat: hmf3053529702.
--------------------------------------------------------------------------------
/国科大-深度学习作业/LSTM自动写诗/classmodel.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset


class myDataset(Dataset):  # dataset class
    def __init__(self, data, wvec, word2index):
        """
        :param data: raw poem strings
        :param wvec: word-embedding matrix
        :param word2index: mapping from character to index
        """
        self.data = data
        self.wvec = wvec
        self.word2index = word2index

    def __getitem__(self, index):  # fetch one sample and process it
        poem_index = [self.word2index[i] for i in self.data[index]]  # look up the index of every character in the poem
        x = poem_index[:-1]  # input indices (all but the last character)
        y = poem_index[1:]  # target indices (shifted by one character)
        x_emb = self.wvec[x]  # input embeddings
        return x_emb, np.array(y).astype(np.int64)

    def __len__(self):  # dataset size
        return len(self.data)


class PoemLstm(nn.Module):  # the model
    def __init__(self, embedding_dim, hidden_dim, output_dim, num_layers):
        # input_size: input dimension; hidden_dim: hidden-state dimension; batch_first: put batch_size first;
        # bidirectional: whether to use a bidirectional LSTM; num_layers: number of stacked layers
        super(PoemLstm, self).__init__()
        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_dim, num_layers=num_layers, batch_first=True,
                            bidirectional=False)

        self.dropout = nn.Dropout(0.2)  # dropout

        self.flatten = nn.Flatten(0, 1)  # flatten dimensions 0 and 1 (batch and sequence)

        # output_dim: vocabulary size
        self.linear = nn.Linear(hidden_dim, output_dim)

        self.loss = nn.CrossEntropyLoss()  # loss function

        self.hidden_num = hidden_dim
        self.num_layers = num_layers

    def forward(self, x, h0=None, c0=None):  # forward pass (assumes a CUDA device is available)
        x = x.to('cuda')
        if h0 is None or c0 is None:
            # h0 and c0 have shape (num_layers, batch_size, hidden_dim); dim 0 of x is batch_size
            h0 = torch.zeros((self.num_layers, x.shape[0], self.hidden_num), dtype=torch.float32)
            c0 = torch.zeros((self.num_layers, x.shape[0], self.hidden_num), dtype=torch.float32)
        h0 = h0.to('cuda')
        c0 = c0.to('cuda')
        x, (h0, c0) = self.lstm(x, (h0, c0))  # if h0/c0 are not provided, they default to zeros
        x = self.dropout(x)
        x = self.flatten(x)
        x = self.linear(x)
        return x, h0, c0
--------------------------------------------------------------------------------
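A minimal shape check for PoemLstm (not part of the repo; it assumes a CUDA device, since forward moves tensors to 'cuda' itself, and uses the dimensions the training code picks, with a hypothetical vocabulary of 5000):

    import torch
    from classmodel import PoemLstm

    model = PoemLstm(embedding_dim=100, hidden_dim=600, output_dim=5000, num_layers=2).to('cuda')
    x = torch.randn(4, 46, 100)   # (batch, sequence length, embedding dim)
    logits, h0, c0 = model(x)     # h0/c0 are allocated as zeros inside forward
    print(logits.shape)           # torch.Size([184, 5000]): batch and time flattened for CrossEntropyLoss
    print(h0.shape, c0.shape)     # torch.Size([2, 4, 600]) each
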
/国科大-深度学习作业/LSTM自动写诗/function.py:
--------------------------------------------------------------------------------
import os
import json
import classmodel
import numpy as np
import torch
from gensim.models import Word2Vec
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt


def process_1(path, path2):  # extract the poems from the raw files and save them
    # path: directory containing the raw json files
    # path2: directory where the processed poems are saved

    poem_5 = []  # five-character poems
    poem_7 = []  # seven-character poems
    jsons = os.listdir(path)  # list the files in the poem directory
    for js in jsons:  # iterate over the json files
        with open(os.path.join(path, js), 'r', encoding='utf-8') as f:
            data = json.load(f)

        for i in data:  # keep only five- and seven-character quatrains
            da = ' '.join(''.join(i['paragraphs']))  # join each poem into one space-separated string
            if len(da) == 47:  # five-character quatrain (24 characters, space-joined)
                if da[10] != ',' or da[22] != '。' or da[34] != ',' or da[-1] != '。':
                    continue  # punctuation not where a regular quatrain puts it
                else:
                    poem_5.append(da)
            if len(da) == 63:  # seven-character quatrain (32 characters, space-joined)
                if da[14] != ',' or da[30] != '。' or da[46] != ',' or da[-1] != '。':
                    continue
                poem_7.append(da)
    with open(f'{path2}/poem_5.txt', 'w', encoding='utf-8') as f:  # save the poems as txt files
        for i in poem_5:
            f.write(i + '\n')
    with open(f'{path2}/poem_7.txt', 'w', encoding='utf-8') as f:
        for i in poem_7:
            f.write(i + '\n')


def process_2(path):  # train character embeddings with word2vec
    # path: poem txt file
    with open(path, 'r', encoding='utf-8') as f:
        data = f.read().split('\n')  # read and split into lines; each line (a string) is treated as a sequence of character tokens

    # data: poem lines; vector_size: embedding dim; min_count: ignore tokens appearing fewer than 1 time; workers: threads
    model = Word2Vec(data, vector_size=100, min_count=1, workers=6, epochs=20)
    model.save(f'{path[:-4]}_W.bin')
    # embedding matrix: model.syn1neg
    # key_to_index: model.wv.key_to_index
    # index_to_key: model.wv.index_to_key


def gen_poetry(wordsize, index_2_word, type1, wvec, model):  # generate a poem
    """
    Poem generation
    :param wordsize: vocabulary size
    :param index_2_word: maps an index back to its character
    :param type1: number of characters to generate
    :param wvec: embedding matrix
    :param model: the trained model
    :return:
    """
    result = ""  # the generated poem
    wordindex = np.random.randint(0, wordsize, 1)[0]  # pick a random index
    result += index_2_word[wordindex]  # its character becomes the first character of the poem
    h0, c0 = None, None  # the two LSTM state tensors
    model.eval()  # switch to evaluation mode
    for i in range(type1):  # generate characters one by one
        wordemd = torch.tensor(wvec[wordindex]).reshape(1, 1, -1)  # embedding of the current character
        pre, h0, c0 = model(wordemd, h0, c0)  # predict the next character, carrying the state forward
        wordindex = int(torch.argmax(pre))  # greedy decoding: take the arg-max index
        pre = index_2_word[wordindex]
        result += pre
    print(''.join(result.split(' ')))
    return result  # return the generated poem

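# Added sketch (not part of the original repo): greedy arg-max decoding above tends to
# loop on common characters. A common remedy, assuming the same model interface, is
# temperature plus top-k sampling of the next index:
def sample_index(pre, temperature=0.8, topk=10):
    """Sample the next character index from the logits instead of taking argmax."""
    logits = pre.reshape(-1) / temperature      # pre has shape (1, vocab_size)
    values, indices = torch.topk(logits, topk)  # keep the k most likely characters
    probs = torch.softmax(values, dim=-1)
    choice = torch.multinomial(probs, 1)        # draw one of the k candidates
    return int(indices[choice])
# In gen_poetry, replacing `wordindex = int(torch.argmax(pre))` with
# `wordindex = sample_index(pre)` yields noticeably more varied poems.
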
def train(path1, path2, peo, batchsize):
    """
    Model training
    :param path1: directory with the raw json data
    :param path2: directory for the processed files
    :param peo: 5 for five-character poems, 7 for seven-character poems
    :param batchsize: batch size
    :return:
    """
    process_1(path1, path2)
    process_2(os.path.join(path2, f'poem_{peo}.txt'))
    data = Word2Vec.load(os.path.join(path2, f'poem_{peo}_W.bin'))
    wvec = data.syn1neg  # embedding matrix
    word2index = data.wv.key_to_index  # character-to-index mapping
    with open(os.path.join(path2, f'poem_{peo}.txt'), 'r', encoding='utf-8') as f:  # read the poem file
        data2 = f.read().split('\n')  # training data
    data2 = data2[:len(data2) - len(data2) % 100]  # truncate to a multiple of 100 lines (this also drops the trailing empty line)

    dataset = classmodel.myDataset(data2, wvec, word2index)  # poem dataset
    loader = DataLoader(dataset, batch_size=batchsize, shuffle=True)
    out_put, emd_num = wvec.shape

    hidden_num = 600  # hidden units
    num_layer = 2  # LSTM layers
    lr = 3e-4
    model = classmodel.PoemLstm(emd_num, hidden_num, out_put, num_layer)  # build the model
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)  # optimizer
    model.to('cuda')
    model.train()  # training mode
    epoch = 24
    loss_all = []
    for j in range(epoch):  # training loop
        for i, (inputs, labels) in enumerate(loader):
            inputs = inputs.to('cuda')
            labels = labels.to('cuda')
            optimizer.zero_grad()
            pre, h0, c0 = model(inputs)
            loss = model.loss(pre, labels.reshape(-1))
            loss.backward()
            optimizer.step()
            if i % 60 == 0:
                loss_all.append(float(loss))
                print(j, i, loss, end='\n')
                gen_poetry(len(word2index), data.wv.index_to_key, (peo + 1) * 8 - 2, wvec, model)
                model.train()  # gen_poetry switched the model to eval mode, so switch back
        # if (j + 1) % 10 == 0:
        #     lr = lr * 0.8
        #     optimizer = torch.optim.Adam(model.parameters(), lr=lr)  # optimizer
        # if (j + 1) % 100 == 0:
        #     torch.save(model.state_dict(), f'./data/poem__{peo}.pth')
    with open('loss.txt', 'w', encoding='utf-8') as f:
        f.write(str(loss_all))
    return model
--------------------------------------------------------------------------------
/国科大-深度学习作业/LSTM自动写诗/lstm自动写诗.pptx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dxcf123/UCAS_DeepLearning_homework/c9a64b80b3377fbc590f8f64f495789eec4cb9ad/国科大-深度学习作业/LSTM自动写诗/lstm自动写诗.pptx
--------------------------------------------------------------------------------
/国科大-深度学习作业/LSTM自动写诗/main.py:
--------------------------------------------------------------------------------
import torch

import function

print(torch.cuda.is_available())
path1 = r'tangshi'  # directory with the raw poems
path2 = r'data'  # directory for the processed data

# returns the trained model for further use
# arguments: raw-data directory, processed-data directory, 5- or 7-character poems, batch size
model = function.train(path1, path2, 7, 64)
# torch.save(model.state_dict(), './data/poem_7_model.pth')
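# Added usage sketch (not part of main.py in the repo): reloading a saved state dict
# for generation, assuming the commented-out torch.save above was run first.
from gensim.models import Word2Vec
import classmodel

data = Word2Vec.load(r'data/poem_7_W.bin')
wvec = data.syn1neg
vocab = len(data.wv.key_to_index)
model = classmodel.PoemLstm(wvec.shape[1], 600, vocab, 2).to('cuda')
model.load_state_dict(torch.load('./data/poem_7_model.pth'))
# 62 = (7 + 1) * 8 - 2 characters: a space-joined seven-character quatrain minus the seed character
function.gen_poetry(vocab, data.wv.index_to_key, 62, wvec, model)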
--------------------------------------------------------------------------------
/国科大-深度学习作业/VIT实现CAFIR10分类/VIT实现CIFAR10分类.docx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dxcf123/UCAS_DeepLearning_homework/c9a64b80b3377fbc590f8f64f495789eec4cb9ad/国科大-深度学习作业/VIT实现CAFIR10分类/VIT实现CIFAR10分类.docx
--------------------------------------------------------------------------------
/国科大-深度学习作业/VIT实现CAFIR10分类/VIT实现CIFAR10分类.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "859542e6",
   "metadata": {},
   "source": [
    "## 1 Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a096699a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import random\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import Dataset\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import DataLoader\n",
    "import torchvision\n",
    "from torchvision import transforms"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6baf97a8",
   "metadata": {},
   "source": [
    "## 2 Load the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "207309e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_dataset(path, batch_size=64, transform=None):\n",
    "    \"\"\"\n",
    "    Load the CIFAR-10 dataset and wrap it in DataLoader objects.\n",
    "    :param path: dataset directory\n",
    "    :param batch_size: batch size\n",
    "    :param transform: preprocessing pipeline\n",
    "    :return: DataLoader objects for the training and test sets\n",
    "    \"\"\"\n",
    "    if transform is None:\n",
    "        trans_train = torchvision.transforms.Compose(\n",
    "            [transforms.ToTensor(),\n",
    "             transforms.RandomResizedCrop(56),  # randomly crop, then rescale the crop to 56x56\n",
    "             transforms.RandomHorizontalFlip(),  # horizontally flip with the given probability (default 0.5)\n",
    "             transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
    "                                  std=[0.229, 0.224, 0.225])])\n",
    "        trans_valid = torchvision.transforms.Compose(\n",
    "            [transforms.ToTensor(),\n",
    "             transforms.Resize(64),  # upscale to 64\n",
    "             transforms.CenterCrop(56),  # crop the given size from the center\n",
    "\n",
    "             # ToTensor converts a PIL Image / ndarray to a tensor scaled into [0, 1] (divides by 255)\n",
    "             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])\n",
    "\n",
    "    train = CIFAR_Dataset(path, train=True, transform=trans_train)\n",
    "    tset = CIFAR_Dataset(path, train=False, transform=trans_valid)\n",
    "\n",
    "    # wrap in DataLoader objects\n",
    "    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True)\n",
    "    test_loader = DataLoader(tset, batch_size=batch_size, shuffle=False)\n",
    "\n",
    "    return train_loader, test_loader"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b11bcca7",
   "metadata": {},
   "source": [
    "## 3 Dataset class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "c4928231",
   "metadata": {},
   "outputs": [],
   "source": [
    "class CIFAR_Dataset(Dataset):\n",
    "    def __init__(self, data_dir, train, transform):  # data directory, train or test split, preprocessing transform\n",
    "        super(CIFAR_Dataset, self).__init__()\n",
    "        self.data_dir = data_dir\n",
    "        self.train = train\n",
    "        self.transform = transform\n",
    "        self.data = []\n",
    "        self.targets = []\n",
    "\n",
    "        # training split or test split?\n",
    "        if self.train:\n",
    "            for i in range(5):  # the CIFAR-10 training set is stored in 5 files, so read all 5\n",
    "                with open(data_dir + '/cifar-10-batches-py/data_batch_' + str(i + 1), 'rb') as f:  # binary read\n",
    "                    entry = pickle.load(f, encoding='latin1')  # deserialize into a python object\n",
    "                    self.data.append(entry['data'])  # append the 'data' part to self.data\n",
    "                    self.targets.extend(entry['labels'])  # append the 'labels' part to self.targets\n",
    "        else:  # same as above, but for the test batch\n",
    "            with open(data_dir + '/cifar-10-batches-py/test_batch', 'rb') as f:\n",
    "                entry = pickle.load(f, encoding='latin1')\n",
    "                self.data.append(entry['data'])\n",
    "                self.targets.extend(entry['labels'])\n",
    "        # appending adds an extra leading dimension: the training data becomes\n",
    "        # 5 x (n/5) x 3 x 32 x 32 instead of n x 3 x 32 x 32, so reshape with -1\n",
    "        # to merge the 5 and n/5 dimensions back into n\n",
    "        self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)\n",
    "        # transpose to the H x W x C image layout\n",
    "        self.data = self.data.transpose((0, 2, 3, 1))\n",
    "\n",
    "    # dataset length\n",
    "    def __len__(self):\n",
    "        return len(self.data)\n",
    "\n",
    "    # allow indexing like an array\n",
    "    def __getitem__(self, idx):\n",
    "        # build the one-hot label by hand (torch.nn.functional.one_hot does the same, see below)\n",
    "        label = torch.zeros(10)\n",
    "        label[self.targets[idx]] = 1.\n",
    "\n",
    "        # apply the preprocessing transform if one was given\n",
    "        if self.transform:\n",
    "            image = self.transform(self.data[idx])\n",
    "            if self.train and idx > 0 and idx % 5 == 0:\n",
    "                # pick a random partner sample\n",
    "                mixup_idx = random.randint(0, len(self.data) - 1)\n",
    "                # build the partner's one-hot label\n",
    "                mixup_label = torch.zeros(10)\n",
    "                mixup_label[self.targets[mixup_idx]] = 1.\n",
    "\n",
    "                # preprocess the partner image as well\n",
    "                if self.transform:\n",
    "                    mixup_image = self.transform(self.data[mixup_idx])\n",
    "\n",
    "                # cutmix with a beta-distributed mixing ratio\n",
    "                mask = np.ones_like(image)  # mask used to cut a region out of the image (C, H, W after ToTensor)\n",
    "                la = float(np.random.beta(0.5, 0.5, 1))  # mixing ratio drawn from a beta distribution\n",
    "                # choose the cut rectangle (note: coordinates are drawn from [0, 32) although\n",
    "                # the transformed crop is 56 x 56, so only its top-left corner can be cut)\n",
    "                rx = np.int8(np.random.uniform(0, 32, 1))[0]\n",
    "                ry = np.int8(np.random.uniform(0, 32, 1))[0]\n",
    "                rw = np.int8(np.power(1 - la, 0.5) * 32)\n",
    "                rh = np.int8(np.power(1 - la, 0.5) * 32)\n",
    "                if rx > rw:\n",
    "                    rx, rw = rw, rx\n",
    "                if ry > rh:\n",
    "                    ry, rh = rh, ry\n",
    "                # cut: zero the rectangle (index the H and W axes, not the channel axis)\n",
    "                mask[:, rx:rw, ry:rh] = 0\n",
    "                # mix the two images and labels\n",
    "                image = image * mask + mixup_image * (1 - mask)\n",
    "                label = la * label + (1 - la) * mixup_label\n",
    "        return image, label"
   ]
  },
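  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f6e7d8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Added sketch (not in the original notebook): the hand-built one-hot labels in\n",
    "# __getitem__ above match torch.nn.functional.one_hot, e.g. for class 3 of 10:\n",
    "lab = torch.zeros(10)\n",
    "lab[3] = 1.\n",
    "assert torch.equal(lab, F.one_hot(torch.tensor(3), num_classes=10).float())"
   ]
  },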
  {
   "cell_type": "markdown",
   "id": "ea772825",
   "metadata": {},
   "source": [
    "## 4 Patch extraction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "39893c39",
   "metadata": {},
   "outputs": [],
   "source": [
    "def image2embed(image, patch_size):\n",
    "    \"\"\"\n",
    "    Split an image into flattened patches\n",
    "    :param image: batch_size * channel * h * w\n",
    "    :param patch_size: patch side length\n",
    "    :return:\n",
    "    \"\"\"\n",
    "    patch = F.unfold(image, kernel_size=patch_size, stride=patch_size).transpose(-1, -2)  # unfold extracts exactly the regions a stride-patch_size convolution would see\n",
    "    return patch "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ff8b26bf",
   "metadata": {},
   "source": [
    "## 5 Embedding layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d8695769",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Embedding(nn.Module):\n",
    "    def __init__(self, channel, batchsize, psize, patchsize, emb_dim, device):\n",
    "        \"\"\"\n",
    "        Embedding layer\n",
    "        :param batchsize: batch size\n",
    "        :param psize: sequence length for the positional encoding, (image side // patchsize)² + 1 (the +1 is the cls token)\n",
    "        :param patchsize: side length of an extracted patch\n",
    "        :param emb_dim: embedding dimension\n",
    "        :param device: compute device\n",
    "        \"\"\"\n",
    "        super(Embedding, self).__init__()\n",
    "        self.pathF = image2embed  # patch-extraction function\n",
    "        self.patchszie = patchsize  # patch side length\n",
    "        self.emb_dim = emb_dim  # embedding dimension\n",
    "        self.l1 = nn.Linear(patchsize * patchsize * channel, emb_dim)  # project each patch to emb_dim\n",
    "        # a token prepended to the input sequence to mark its start; note that both tensors\n",
    "        # below are plain tensors rather than nn.Parameter, so the optimizer never updates them\n",
    "        self.cls_token_emb = torch.randn(batchsize, 1, self.emb_dim, requires_grad=True, device=device)\n",
    "        # positional encoding\n",
    "        self.position_emb = torch.randn(1, psize, self.emb_dim, requires_grad=True, device=device)\n",
    "\n",
    "    def forward(self, x):  # forward pass\n",
    "        \"\"\"\n",
    "        Projects the patches to the embedding dimension, then prepends the cls token and adds the positional encoding.\n",
    "        :param x:\n",
    "        :return:\n",
    "        \"\"\"\n",
    "\n",
    "        x = self.pathF(x, self.patchszie)\n",
    "        x = self.l1(x)\n",
    "        x = torch.cat((self.cls_token_emb[:x.shape[0]], x), dim=1)\n",
    "        x += self.position_emb\n",
    "        return x"
   ]
  },
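  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c3d4e5f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Added sketch (not in the original notebook): shape walk-through on CPU with the\n",
    "# sizes the training code uses (56x56 crops, patchsize 4, so (56/4)^2 + 1 = 197 tokens).\n",
    "demo = torch.randn(2, 3, 56, 56)\n",
    "print(image2embed(demo, 4).shape)  # torch.Size([2, 196, 48]): 196 patches of 4*4*3 values\n",
    "emb = Embedding(channel=3, batchsize=2, psize=197, patchsize=4, emb_dim=512, device='cpu')\n",
    "print(emb(demo).shape)  # torch.Size([2, 197, 512]): cls token prepended, positions added"
   ]
  },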
  {
   "cell_type": "markdown",
   "id": "e86b94c6",
   "metadata": {},
   "source": [
    "## 6 Attention"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "94f0a575",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Attention(nn.Module):\n",
    "    def __init__(self, emb_dim=64, head=8):\n",
    "        \"\"\"\n",
    "        Multi-head attention\n",
    "        :param emb_dim: embedding dimension\n",
    "        :param head: number of heads\n",
    "        \"\"\"\n",
    "        super(Attention, self).__init__()\n",
    "        assert emb_dim % head == 0  # the embedding dimension must be divisible by the number of heads\n",
    "        self.emb_dim = emb_dim  # embedding dimension\n",
    "        self.head = head  # number of heads\n",
    "        self.head_dim = emb_dim // head\n",
    "\n",
    "        # linear layers for the q, k, v inputs: emb_dim -> emb_dim\n",
    "        self.query_L = nn.Linear(emb_dim, emb_dim)\n",
    "        self.key_L = nn.Linear(emb_dim, emb_dim)\n",
    "        self.value_L = nn.Linear(emb_dim, emb_dim)\n",
    "\n",
    "    def forward(self, q, k, v):\n",
    "        \"\"\"\n",
    "        Forward pass; q, k, v are the three transformer inputs, and the attention computation happens here.\n",
    "        :return:\n",
    "        \"\"\"\n",
    "        # q, k, v have shape (batchsize, length, emb_dim); below B, L, D, H stand for\n",
    "        # batchsize, length, embedding dimension and number of heads\n",
    "        # reshape for the multi-head computation\n",
    "        x_q = self.query_L(q)  # linear projection of q: B,L,D -> B,L,D\n",
    "        x_q = x_q.reshape(q.shape[0], q.shape[1], self.head, self.head_dim)  # B,L,D -> B,L,H,D/H\n",
    "        x_q = x_q.transpose(1, 2)  # B,L,H,D/H -> B,H,L,D/H\n",
    "        x_q = x_q.reshape(-1, q.shape[1], self.head_dim)  # B,H,L,D/H -> BH,L,D/H\n",
    "\n",
    "        # k and v are treated the same way as q\n",
    "        x_k = self.key_L(k).reshape(k.shape[0], k.shape[1], self.head, self.head_dim)\n",
    "        x_k = x_k.transpose(1, 2)\n",
    "        x_k = x_k.reshape(-1, k.shape[1], self.head_dim)\n",
    "\n",
    "        x_v = self.value_L(v).reshape(v.shape[0], v.shape[1], self.head, self.head_dim)\n",
    "        x_v = x_v.transpose(1, 2)\n",
    "        x_v = x_v.reshape(-1, v.shape[1], self.head_dim)\n",
    "        \n",
    "\n",
    "        # attention scores; x_k is transposed so the matmul shapes line up\n",
    "        x_k = x_k.transpose(1, 2)  # BH,L,D/H -> BH,D/H,L\n",
    "        x_atten = torch.matmul(x_q, x_k) / (self.head_dim ** 0.5)  # q·k divided by sqrt(D/H) -> BH,L,L\n",
    "        x_atten = F.softmax(x_atten, dim=-1)\n",
    "\n",
    "        x_out = torch.matmul(x_atten, x_v)  # -> BH,L,D/H\n",
    "        x_out = x_out.reshape(-1, self.head, x_out.shape[1], x_out.shape[2])  # BH,L,D/H -> B,H,L,D/H\n",
    "        x_out = x_out.transpose(1, 2)  # B,H,L,D/H -> B,L,H,D/H\n",
    "        x = x_out.reshape(-1, x_out.shape[1], self.head * self.head_dim)  # B,L,H,D/H -> B,L,D\n",
    "        return x"
   ]
  },
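  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a7b8c9d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Added sketch (not in the original notebook): Attention preserves the input shape.\n",
    "att = Attention(emb_dim=64, head=8)\n",
    "tok = torch.randn(2, 197, 64)  # (batch, length, emb_dim)\n",
    "print(att(tok, tok, tok).shape)  # torch.Size([2, 197, 64])"
   ]
  },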
| "id": "36570bbf", 367 | "metadata": {}, 368 | "outputs": [], 369 | "source": [ 370 | "class VIT(nn.Module):\n", 371 | " def __init__(self, channel, batchsize, psize, patchsize, emb_dim, head, device, N=3):\n", 372 | " \"\"\"\n", 373 | " VIT模型\n", 374 | " :param batchsize: 批量\n", 375 | " :param psize: 用于位置编码的一个参数,它的大小等于 图片通道数 * (一张图片一行数据的大小//patchsize)²\n", 376 | " :param patchsize: 图片块边长\n", 377 | " :param emb_dim: 嵌入维度\n", 378 | " :param head: 多头\n", 379 | " :param device: 运算设备\n", 380 | " \"\"\"\n", 381 | " super(VIT, self).__init__()\n", 382 | " self.Embed = Embedding(channel, batchsize, psize, patchsize, emb_dim, device) # 词嵌入层\n", 383 | " self.Encoder = torch.nn.ModuleList([Encoder(emb_dim, head) for _ in range(N)])\n", 384 | " # 用于分类的全连接层\n", 385 | " self.l1 = nn.Linear(256, 256)\n", 386 | " self.l2 = nn.Linear(256, 10) # CIFAR10 10分类\n", 387 | "\n", 388 | " def forward(self, x):\n", 389 | " # 词嵌入层\n", 390 | " x = self.Embed(x)\n", 391 | " # 编码器层\n", 392 | " for i in self.Encoder:\n", 393 | " x = i(x, x, x)\n", 394 | " # 分类层\n", 395 | " x = self.l1(x)\n", 396 | " x = F.relu(x)\n", 397 | " x = self.l2(x)\n", 398 | " return x" 399 | ] 400 | }, 401 | { 402 | "cell_type": "markdown", 403 | "id": "73216d6a", 404 | "metadata": {}, 405 | "source": [ 406 | "## 9 准确率函数" 407 | ] 408 | }, 409 | { 410 | "cell_type": "code", 411 | "execution_count": 9, 412 | "id": "491b4b66", 413 | "metadata": {}, 414 | "outputs": [], 415 | "source": [ 416 | "def testacc(model, test, epoch, device):\n", 417 | " \"\"\"\n", 418 | " 测试准确率\n", 419 | " :param model: 模型\n", 420 | " :param test: 测试集\n", 421 | " :param epoch: 第epoch轮\n", 422 | " :param device: 设备\n", 423 | " :return:\n", 424 | " \"\"\"\n", 425 | " all = 0 # 样本总数\n", 426 | " right = 0 # 正确个数\n", 427 | " model.eval()\n", 428 | " with torch.no_grad():\n", 429 | " for i, (data, label) in enumerate(test):\n", 430 | " all += 128\n", 431 | " data = data.to(device)\n", 432 | " label = label.to(device)\n", 433 | " pre = model(data)[:, 0, :]\n", 434 | " pre = torch.argmax(pre, dim=-1) # 获取最大值标签\n", 435 | " label=torch.argmax(label, dim=-1)\n", 436 | " right += (pre == label).sum() # 统计每轮正确的数量\n", 437 | " print(epoch, right / all)\n", 438 | " return right / all\n" 439 | ] 440 | }, 441 | { 442 | "cell_type": "markdown", 443 | "id": "f00de2ce", 444 | "metadata": {}, 445 | "source": [ 446 | "## 10 训练函数" 447 | ] 448 | }, 449 | { 450 | "cell_type": "code", 451 | "execution_count": 10, 452 | "id": "751e52f9", 453 | "metadata": {}, 454 | "outputs": [], 455 | "source": [ 456 | "import time\n", 457 | "def train(path, batchsize, patchsize, emb_dim=512, head=8, device='cpu', lr=3e-4, N=6):\n", 458 | " \"\"\"\n", 459 | " 训练模型\n", 460 | " :param path: 数据集路径\n", 461 | " :param batchsize: 批量大小\n", 462 | " :param patchsize: 块大小\n", 463 | " :param emb_dim: 嵌入纬度\n", 464 | " :param head: 多头\n", 465 | " :param device: 设备\n", 466 | " :param lr: 学习率\n", 467 | " :param N: Encoder层数\n", 468 | " :return: 模型\n", 469 | " \"\"\"\n", 470 | " train, test = get_dataset(path, batchsize)\n", 471 | " # 损失函数\n", 472 | " lossf = nn.CrossEntropyLoss()\n", 473 | "\n", 474 | " # 用于位置编码的一个参数,它的大小等于 图片通道数 * (一张图片一行数据的大小//patchsize)²\n", 475 | " psize = (56 // patchsize) * (56 // patchsize) + 1\n", 476 | " channel = 3 # 图片通道数\n", 477 | "\n", 478 | " # 创建VIT模型\n", 479 | " model = VIT(channel, batchsize, psize, patchsize, emb_dim, head, device, N=N)\n", 480 | " # 设置优化器\n", 481 | " optm = torch.optim.Adam(model.parameters(), lr=lr)\n", 482 | " model = model.to(device)\n", 483 | " loss_all=[]\n", 484 | " 
  {
   "cell_type": "markdown",
   "id": "73216d6a",
   "metadata": {},
   "source": [
    "## 9 Accuracy function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "491b4b66",
   "metadata": {},
   "outputs": [],
   "source": [
    "def testacc(model, test, epoch, device):\n",
    "    \"\"\"\n",
    "    Test-set accuracy\n",
    "    :param model: the model\n",
    "    :param test: test loader\n",
    "    :param epoch: current epoch (for printing)\n",
    "    :param device: compute device\n",
    "    :return:\n",
    "    \"\"\"\n",
    "    all = 0  # total number of samples\n",
    "    right = 0  # number of correct predictions\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        for i, (data, label) in enumerate(test):\n",
    "            all += data.shape[0]  # count the actual batch size (the last batch may be smaller)\n",
    "            data = data.to(device)\n",
    "            label = label.to(device)\n",
    "            pre = model(data)[:, 0, :]  # cls-token logits\n",
    "            pre = torch.argmax(pre, dim=-1)  # predicted class\n",
    "            label = torch.argmax(label, dim=-1)  # true class (labels are one-hot)\n",
    "            right += (pre == label).sum()  # tally the correct predictions\n",
    "    print(epoch, right / all)\n",
    "    return right / all\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f00de2ce",
   "metadata": {},
   "source": [
    "## 10 Training function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "751e52f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "def train(path, batchsize, patchsize, emb_dim=512, head=8, device='cpu', lr=3e-4, N=6):\n",
    "    \"\"\"\n",
    "    Train the model\n",
    "    :param path: dataset directory\n",
    "    :param batchsize: batch size\n",
    "    :param patchsize: patch side length\n",
    "    :param emb_dim: embedding dimension\n",
    "    :param head: number of heads\n",
    "    :param device: compute device\n",
    "    :param lr: learning rate\n",
    "    :param N: number of Encoder blocks\n",
    "    :return: the model\n",
    "    \"\"\"\n",
    "    train, test = get_dataset(path, batchsize)\n",
    "    # loss function\n",
    "    lossf = nn.CrossEntropyLoss()\n",
    "\n",
    "    # sequence length for the positional encoding: number of patches plus the cls token\n",
    "    psize = (56 // patchsize) * (56 // patchsize) + 1\n",
    "    channel = 3  # image channels\n",
    "\n",
    "    # build the ViT model\n",
    "    model = VIT(channel, batchsize, psize, patchsize, emb_dim, head, device, N=N)\n",
    "    # optimizer\n",
    "    optm = torch.optim.Adam(model.parameters(), lr=lr)\n",
    "    model = model.to(device)\n",
    "    loss_all = []\n",
    "    acc_ = []\n",
    "    t1 = time.time()\n",
    "    for epo in range(400):\n",
    "        model.train()\n",
    "        for i, (data, label) in enumerate(train):\n",
    "            data = data.to(device)\n",
    "            label = label.to(device)\n",
    "            optm.zero_grad()\n",
    "            pre = model(data)[:, 0, :]\n",
    "            loss = lossf(pre, label.float())\n",
    "            loss.backward()\n",
    "            optm.step()\n",
    "        loss_all.append(float(loss))\n",
    "        acc_.append(float(testacc(model, test, epo, device)))\n",
    "        t2 = time.time()\n",
    "        print(t2 - t1)\n",
    "    with open('loss.txt', 'w', encoding=\"utf-8\") as f:\n",
    "        f.write(str(loss_all))\n",
    "    with open('acc.txt', 'w', encoding=\"utf-8\") as f:\n",
    "        f.write(str(acc_))\n",
    "    torch.save(model.state_dict(), './model.pt')\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f5b832d",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 tensor(0.4356, device='cuda:0')\n",
      "50.01631164550781\n"
     ]
    }
   ],
   "source": [
    "batchsize = 128\n",
    "patchsize = 4\n",
    "path = r'C:\\Users\\30535\\Desktop\\CodeProgram\\Python\\deepstudy\\data'\n",
    "\n",
    "model = train(path, batchsize, patchsize, device='cuda')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e34347e",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib qt\n",
    "import matplotlib.pyplot as plt\n",
    "with open('acc.txt', 'r', encoding=\"utf-8\") as f:\n",
    "    data = eval(f.read())\n",
    "\n",
    "plt.figure()\n",
    "plt.plot([i for i in range(len(data))], data)\n",
    "plt.xlabel('epoch')\n",
    "plt.ylabel('accuracy')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
--------------------------------------------------------------------------------
/国科大-深度学习作业/VIT实现CAFIR10分类/VIT实现CIFAR10分类.pptx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dxcf123/UCAS_DeepLearning_homework/c9a64b80b3377fbc590f8f64f495789eec4cb9ad/国科大-深度学习作业/VIT实现CAFIR10分类/VIT实现CIFAR10分类.pptx
--------------------------------------------------------------------------------
/国科大-深度学习作业/VIT实现CAFIR10分类/助教代码魔改.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "66d04f04",
   "metadata": {},
   "source": [
    "## 1. Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "cd76a5d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load and preprocess the dataset\n",
    "import torch\n",
    "from einops import rearrange, repeat\n",
    "from einops.layers.torch import Rearrange\n",
    "from torch import nn\n",
    "from torchvision import transforms\n",
    "import torchvision\n",
    "from torch.utils.data import DataLoader"
   ]
  },
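  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a2b3c4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Added sketch (not in the original notebook): the two einops operations used below,\n",
    "# with shapes matching CIFAR-10 (32x32 RGB images, 4x4 patches).\n",
    "t = torch.randn(2, 3, 32, 32)\n",
    "p = rearrange(t, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=4, p2=4)\n",
    "print(p.shape)  # torch.Size([2, 64, 48]): 64 patches, each 4*4*3 values\n",
    "cls = repeat(torch.randn(1, 1, 512), '1 1 d -> b 1 d', b=2)\n",
    "print(cls.shape)  # torch.Size([2, 1, 512]): one cls token per batch element"
   ]
  },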
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "9a6e6e79",
   "metadata": {},
   "outputs": [],
   "source": [
    "trans_train = transforms.Compose(\n",
    "    [transforms.RandomCrop(32, padding=4),  # pad, then crop a random 32x32 window\n",
    "     transforms.RandomHorizontalFlip(),  # horizontally flip with the given probability (default 0.5)\n",
    "     transforms.ToTensor(),\n",
    "     transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
    "                          std=[0.229, 0.224, 0.225])])\n",
    "trans_valid = transforms.Compose(\n",
    "    [  # transforms.Resize(64),  # scale the shorter side to the given size, keeping the aspect ratio\n",
    "     # transforms.CenterCrop(28),  # crop the given size from the center\n",
    "     transforms.ToTensor(),\n",
    "     # ToTensor converts a PIL Image / ndarray to a tensor scaled into [0, 1] (divides by 255)\n",
    "     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])\n",
    "# channel-wise standardization: subtract the mean, then divide by the std\n",
    "# trainset = torchvision.datasets.CIFAR10(root=r\"H:\\datasets\\data\", train=True, download=True, transform=trans_train)\n",
    "# trainloader = DataLoader(trainset, batch_size=256, shuffle=True)\n",
    "testset = torchvision.datasets.CIFAR10(root=r'H:\\datasets\\data', train=False,\n",
    "                                       download=False, transform=trans_valid)\n",
    "testloader = DataLoader(testset, batch_size=256, shuffle=False)\n",
    "classes = ('plane', 'car', 'bird', 'cat',\n",
    "           'deer', 'dog', 'frog', 'horse ', 'ship', 'truck ')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1e23ad4",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "c0228444",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Attention(nn.Module):\n",
    "    def __init__(self, dim=128, heads=8, dim_head=64, dropout=0.):\n",
    "        super(Attention, self).__init__()\n",
    "        inner_dim = dim_head * heads\n",
    "        project_out = not (heads == 1 and dim_head == dim)\n",
    "        self.heads = heads\n",
    "        self.scale = dim_head ** -0.5\n",
    "        self.norm = nn.LayerNorm(dim)\n",
    "        self.attend = nn.Softmax(dim=-1)\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)\n",
    "        self.to_out = nn.Sequential(\n",
    "            nn.Linear(inner_dim, dim), nn.Dropout(dropout)\n",
    "        ) if project_out else nn.Identity()\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.norm(x)\n",
    "        qkv = self.to_qkv(x).chunk(3, dim=-1)\n",
    "        q, k, v = map(lambda t: rearrange(t, 'b n (h d)->b h n d', h=self.heads), qkv)\n",
    "        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale\n",
    "        attn = self.attend(dots)\n",
    "        attn = self.dropout(attn)\n",
    "        out = torch.matmul(attn, v)\n",
    "        out = rearrange(out, 'b h n d -> b n (h d)')\n",
    "        return self.to_out(out)\n",
    "\n",
    "\n",
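    "# Added sketch (not in the original notebook): a quick shape check of Attention;\n",
    "# dots is (b, heads, n, n) and the output returns to the input shape (b, n, dim).\n",
    "print(Attention(dim=128, heads=8)(torch.randn(2, 65, 128)).shape)  # torch.Size([2, 65, 128])\n",
    "\n",
    "\n",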
    "class ViT(nn.Module):\n",
    "    def __init__(self, num_classes=10, dim=512, depth=6, heads=8, mlp_dim=512, pool='cls', channels=3, dim_head=64,\n",
    "                 dropout=0.1, emb_dropout=0.1):\n",
    "        super().__init__()\n",
    "        image_height = 32\n",
    "        patch_height = 4\n",
    "        image_width = 32\n",
    "        patch_width = 4\n",
    "        num_patches = (image_height // patch_height) * (image_width // patch_width)\n",
    "        patch_dim = channels * patch_height * patch_width\n",
    "        self.to_patch_embedding = nn.Sequential(\n",
    "            Rearrange('b c (h p1) (w p2)->b (h w) (p1 p2 c) ', p1=patch_height, p2=patch_width),\n",
    "            nn.LayerNorm(patch_dim),\n",
    "            nn.Linear(patch_dim, dim), nn.LayerNorm(dim), )\n",
    "        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))\n",
    "        self.cls_token = nn.Parameter(\n",
    "            torch.randn(1, 1, dim))\n",
    "        self.dropout = nn.Dropout(emb_dropout)\n",
    "        self.transformer = Encoder(dim, depth, heads, dim_head, mlp_dim, dropout)\n",
    "        self.pool = pool\n",
    "        self.to_latent = nn.Identity()\n",
    "        self.mlp_head = nn.Linear(dim, num_classes)\n",
    "\n",
    "    def forward(self, img):\n",
    "        x = self.to_patch_embedding(img)\n",
    "        b, n, _ = x.shape\n",
    "\n",
    "        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b=b)\n",
    "        x = torch.cat((cls_tokens, x), dim=1)\n",
    "        x += self.pos_embedding[:, :(n + 1)]\n",
    "        x = self.dropout(x)\n",
    "        x = self.transformer(x)\n",
    "        x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]\n",
    "        x = self.to_latent(x)\n",
    "        return self.mlp_head(x)\n",
    "\n",
    "\n",
    "class Encoder(nn.Module):\n",
    "    def __init__(self, dim=512, depth=6, heads=8, dim_head=64, mlp_dim=512, dropout=0.):\n",
    "        super().__init__()\n",
    "        self.norm = nn.LayerNorm(dim)\n",
    "        self.layers = nn.ModuleList([])\n",
    "        for _ in range(depth):\n",
    "            self.layers.append(nn.ModuleList([\n",
    "                Attention(dim=dim, heads=heads, dim_head=dim_head, dropout=dropout),\n",
    "                FeedForward(dim, mlp_dim, dropout=dropout)\n",
    "            ]))\n",
    "\n",
    "    def forward(self, x):\n",
    "        for attn, ff in self.layers:\n",
    "            x = attn(x) + x\n",
    "            x = ff(x) + x\n",
    "        return self.norm(x)\n",
    "\n",
    "\n",
    "class FeedForward(nn.Module):\n",
    "    def __init__(self, dim, hidden_dim, dropout=0.):\n",
    "        super().__init__()\n",
    "        self.net = nn.Sequential(\n",
    "            nn.LayerNorm(dim),\n",
    "            nn.Linear(dim, hidden_dim), nn.GELU(),\n",
    "            nn.Dropout(dropout),\n",
    "            nn.Linear(hidden_dim, dim), nn.Dropout(dropout)\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.net(x)\n"
   ]
  },
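  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b9c0d1e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Added sketch (not in the original notebook): train() below re-creates the CIFAR-10\n",
    "# dataset inside its epoch loop, which is why the training log repeats\n",
    "# 'Files already downloaded and verified' once per epoch. Building the loader once,\n",
    "# outside the loop, avoids the per-epoch disk check:\n",
    "trainset = torchvision.datasets.CIFAR10(root=r\"H:\\datasets\\data\", train=True, download=True, transform=trans_train)\n",
    "trainloader = DataLoader(trainset, batch_size=256, shuffle=True)"
   ]
  },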
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a32b1faa",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "\n",
    "def train(epoch):\n",
    "    print(' \\nEpoch: %d' % epoch)\n",
    "    model = ViT()\n",
    "    device = 'cuda'\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)\n",
    "    net = model.to(device)\n",
    "    net.to(device)\n",
    "    train_loss = 0\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    criterion = torch.nn.CrossEntropyLoss()\n",
    "    t = time.time()\n",
    "    loss_all = []\n",
    "    acc_ = []\n",
    "    for e in range(epoch):\n",
    "        net.train()\n",
    "        # note: the dataset and loader are rebuilt every epoch here (see the sketch above),\n",
    "        # which is what prints 'Files already downloaded and verified' once per epoch\n",
    "        trainset = torchvision.datasets.CIFAR10(root=r\"H:\\datasets\\data\", train=True, download=True, transform=trans_train)\n",
    "        trainloader = DataLoader(trainset, batch_size=256, shuffle=True)\n",
    "        for batch_idx, (inputs, targets) in enumerate(trainloader):\n",
    "            inputs, targets = inputs.to(device), targets.to(device)\n",
    "            optimizer.zero_grad()\n",
    "            outputs = net(inputs)\n",
    "            loss = criterion(outputs, targets)\n",
    "            loss.backward()\n",
    "            # sparse_selection()\n",
    "            optimizer.step()\n",
    "            train_loss += loss.item()\n",
    "            _, predicted = outputs.max(1)\n",
    "            total += targets.size(0)\n",
    "            correct += predicted.eq(targets).sum().item()\n",
    "        loss_all.append(float(loss))\n",
    "        acc1 = te(net, device, criterion)\n",
    "        acc_.append(float(acc1))\n",
    "        print(e, acc1)\n",
    "        print(time.time() - t)\n",
    "    with open('loss.txt', 'w', encoding=\"utf-8\") as f:\n",
    "        f.write(str(loss_all))\n",
    "    with open('acc.txt', 'w', encoding=\"utf-8\") as f:\n",
    "        f.write(str(acc_))\n",
    "    torch.save(model.state_dict(), './model1.pt')\n",
    "\n",
    "\n",
    "def te(net, device, criterion):\n",
    "    test_loss = 0\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    net.eval()\n",
    "    with torch.no_grad():\n",
    "        for batch_idx, (inputs, targets) in enumerate(testloader):\n",
    "            inputs, targets = inputs.to(device), targets.to(device)\n",
    "            outputs = net(inputs)\n",
    "            loss = criterion(outputs, targets)\n",
    "            test_loss += loss.item()\n",
    "            _, predicted = outputs.max(1)\n",
    "            total += targets.size(0)\n",
    "            correct += predicted.eq(targets).sum().item()\n",
    "    return correct / total"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "462f2e28",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " \n",
      "Epoch: 400\n",
      "Files already downloaded and verified\n",
      "0 0.4761\n",
      "29.307168006896973\n",
      "Files already downloaded and verified\n",
      "1 0.5389\n",
      "58.45121216773987\n",
      "2 0.5878\n",
      "87.4048445224762\n",
      "[... roughly 250 further epochs of the same three-line pattern elided: 'Files already downloaded and verified', the epoch index with its test accuracy, and the cumulative seconds; accuracy passes 0.80 around epoch 36, 0.85 around epoch 80, and settles around 0.87 from roughly epoch 225 on ...]\n",
      "258 0.8703\n",
      "7512.001064538956\n",
      "Files already downloaded and verified\n",
      "259 0.872\n",
      "7541.01106262207\n",
      "Files already downloaded and verified\n",
      "260 0.8697\n",
      "7569.905064344406\n",
      "Files already downloaded and verified\n",
1048 | "261 0.8684\n", 1049 | "7598.985243558884\n", 1050 | "Files already downloaded and verified\n", 1051 | "262 0.8715\n", 1052 | "7627.904242277145\n", 1053 | "Files already downloaded and verified\n", 1054 | "263 0.8668\n", 1055 | "7656.911458730698\n", 1056 | "Files already downloaded and verified\n", 1057 | "264 0.8702\n", 1058 | "7685.813384056091\n", 1059 | "Files already downloaded and verified\n", 1060 | "265 0.8689\n", 1061 | "7714.840382099152\n", 1062 | "Files already downloaded and verified\n", 1063 | "266 0.8697\n", 1064 | "7743.732383728027\n", 1065 | "Files already downloaded and verified\n", 1066 | "267 0.8658\n", 1067 | "7772.648947715759\n", 1068 | "Files already downloaded and verified\n", 1069 | "268 0.8693\n", 1070 | "7801.564166069031\n", 1071 | "Files already downloaded and verified\n", 1072 | "269 0.8723\n", 1073 | "7830.4175906181335\n", 1074 | "Files already downloaded and verified\n", 1075 | "270 0.8682\n", 1076 | "7859.44679069519\n", 1077 | "Files already downloaded and verified\n", 1078 | "271 0.865\n", 1079 | "7888.371791124344\n", 1080 | "Files already downloaded and verified\n", 1081 | "272 0.8685\n", 1082 | "7917.349725723267\n", 1083 | "Files already downloaded and verified\n", 1084 | "273 0.8722\n", 1085 | "7946.2257261276245\n", 1086 | "Files already downloaded and verified\n", 1087 | "274 0.8726\n", 1088 | "7975.226726055145\n", 1089 | "Files already downloaded and verified\n", 1090 | "275 0.8714\n", 1091 | "8004.0921041965485\n", 1092 | "Files already downloaded and verified\n", 1093 | "276 0.8673\n", 1094 | "8033.001600980759\n", 1095 | "Files already downloaded and verified\n", 1096 | "277 0.8692\n", 1097 | "8061.983365535736\n", 1098 | "Files already downloaded and verified\n", 1099 | "278 0.8675\n", 1100 | "8090.912189483643\n", 1101 | "Files already downloaded and verified\n", 1102 | "279 0.8692\n", 1103 | "8119.910188436508\n", 1104 | "Files already downloaded and verified\n", 1105 | "280 0.8684\n", 1106 | "8148.76118850708\n", 1107 | "Files already downloaded and verified\n", 1108 | "281 0.8676\n", 1109 | "8177.749190568924\n", 1110 | "Files already downloaded and verified\n", 1111 | "282 0.87\n", 1112 | "8206.656646251678\n", 1113 | "Files already downloaded and verified\n", 1114 | "283 0.8699\n", 1115 | "8235.712460756302\n", 1116 | "Files already downloaded and verified\n", 1117 | "284 0.8653\n", 1118 | "8264.724035978317\n", 1119 | "Files already downloaded and verified\n", 1120 | "285 0.8689\n", 1121 | "8293.543662309647\n", 1122 | "Files already downloaded and verified\n", 1123 | "286 0.8686\n", 1124 | "8322.527196645737\n", 1125 | "Files already downloaded and verified\n", 1126 | "287 0.8713\n", 1127 | "8351.529844999313\n", 1128 | "Files already downloaded and verified\n", 1129 | "288 0.8694\n", 1130 | "8380.490220546722\n", 1131 | "Files already downloaded and verified\n", 1132 | "289 0.8705\n", 1133 | "8409.314220428467\n", 1134 | "Files already downloaded and verified\n", 1135 | "290 0.8732\n", 1136 | "8438.293220758438\n", 1137 | "Files already downloaded and verified\n", 1138 | "291 0.868\n", 1139 | "8467.266277551651\n", 1140 | "Files already downloaded and verified\n", 1141 | "292 0.8717\n", 1142 | "8496.26927781105\n", 1143 | "Files already downloaded and verified\n", 1144 | "293 0.8715\n", 1145 | "8525.324277639389\n", 1146 | "Files already downloaded and verified\n", 1147 | "294 0.8668\n", 1148 | "8554.318277359009\n", 1149 | "Files already downloaded and verified\n", 1150 | "295 0.8701\n", 1151 | "8583.343279123306\n", 
1152 | "Files already downloaded and verified\n", 1153 | "296 0.869\n", 1154 | "8612.264277219772\n", 1155 | "Files already downloaded and verified\n", 1156 | "297 0.8694\n", 1157 | "8641.254587650299\n", 1158 | "Files already downloaded and verified\n", 1159 | "298 0.866\n", 1160 | "8670.124589681625\n", 1161 | "Files already downloaded and verified\n", 1162 | "299 0.8684\n", 1163 | "8699.161587238312\n", 1164 | "Files already downloaded and verified\n", 1165 | "300 0.8701\n", 1166 | "8728.061585903168\n", 1167 | "Files already downloaded and verified\n", 1168 | "301 0.8746\n", 1169 | "8756.969587564468\n", 1170 | "Files already downloaded and verified\n", 1171 | "302 0.8696\n", 1172 | "8785.97458577156\n", 1173 | "Files already downloaded and verified\n", 1174 | "303 0.8731\n", 1175 | "8814.947585821152\n", 1176 | "Files already downloaded and verified\n", 1177 | "304 0.8674\n", 1178 | "8843.993586301804\n", 1179 | "Files already downloaded and verified\n", 1180 | "305 0.8728\n", 1181 | "8872.878713130951\n", 1182 | "Files already downloaded and verified\n", 1183 | "306 0.8718\n", 1184 | "8901.823713302612\n", 1185 | "Files already downloaded and verified\n", 1186 | "307 0.8716\n", 1187 | "8930.777712345123\n", 1188 | "Files already downloaded and verified\n", 1189 | "308 0.8682\n", 1190 | "8959.677854061127\n", 1191 | "Files already downloaded and verified\n", 1192 | "309 0.8736\n", 1193 | "8988.506853818893\n", 1194 | "Files already downloaded and verified\n", 1195 | "310 0.8663\n", 1196 | "9017.350752830505\n", 1197 | "Files already downloaded and verified\n", 1198 | "311 0.8753\n", 1199 | "9046.346252679825\n", 1200 | "Files already downloaded and verified\n", 1201 | "312 0.873\n", 1202 | "9075.24724984169\n", 1203 | "Files already downloaded and verified\n", 1204 | "313 0.8741\n", 1205 | "9104.29703092575\n", 1206 | "Files already downloaded and verified\n", 1207 | "314 0.8714\n", 1208 | "9133.227599859238\n", 1209 | "Files already downloaded and verified\n", 1210 | "315 0.8736\n", 1211 | "9162.228678941727\n", 1212 | "Files already downloaded and verified\n", 1213 | "316 0.8646\n", 1214 | "9191.143491983414\n", 1215 | "Files already downloaded and verified\n", 1216 | "317 0.8668\n", 1217 | "9219.958491563797\n", 1218 | "Files already downloaded and verified\n", 1219 | "318 0.8745\n", 1220 | "9249.023736715317\n", 1221 | "Files already downloaded and verified\n", 1222 | "319 0.879\n", 1223 | "9277.962262392044\n", 1224 | "Files already downloaded and verified\n", 1225 | "320 0.8764\n", 1226 | "9306.960386037827\n", 1227 | "Files already downloaded and verified\n", 1228 | "321 0.8737\n", 1229 | "9335.868592500687\n", 1230 | "Files already downloaded and verified\n", 1231 | "322 0.8727\n", 1232 | "9364.918592214584\n", 1233 | "Files already downloaded and verified\n", 1234 | "323 0.8756\n", 1235 | "9393.86459517479\n", 1236 | "Files already downloaded and verified\n", 1237 | "324 0.8739\n", 1238 | "9422.836590766907\n", 1239 | "Files already downloaded and verified\n", 1240 | "325 0.876\n", 1241 | "9451.738594532013\n", 1242 | "Files already downloaded and verified\n", 1243 | "326 0.876\n", 1244 | "9480.619082927704\n", 1245 | "Files already downloaded and verified\n", 1246 | "327 0.877\n", 1247 | "9509.576842308044\n", 1248 | "Files already downloaded and verified\n", 1249 | "328 0.8764\n", 1250 | "9538.58263874054\n", 1251 | "Files already downloaded and verified\n", 1252 | "329 0.8739\n", 1253 | "9567.516907930374\n", 1254 | "Files already downloaded and verified\n", 1255 | "330 
0.8742\n", 1256 | "9596.388902664185\n", 1257 | "Files already downloaded and verified\n", 1258 | "331 0.8744\n", 1259 | "9625.361789226532\n", 1260 | "Files already downloaded and verified\n", 1261 | "332 0.8741\n", 1262 | "9655.885634183884\n", 1263 | "Files already downloaded and verified\n", 1264 | "333 0.8764\n", 1265 | "9685.418642282486\n", 1266 | "Files already downloaded and verified\n", 1267 | "334 0.8676\n", 1268 | "9715.110464811325\n", 1269 | "Files already downloaded and verified\n", 1270 | "335 0.8766\n", 1271 | "9745.10642695427\n", 1272 | "Files already downloaded and verified\n", 1273 | "336 0.8724\n", 1274 | "9775.122428417206\n", 1275 | "Files already downloaded and verified\n", 1276 | "337 0.8741\n", 1277 | "9804.514398813248\n", 1278 | "Files already downloaded and verified\n", 1279 | "338 0.8747\n", 1280 | "9833.624905347824\n", 1281 | "Files already downloaded and verified\n", 1282 | "339 0.8742\n", 1283 | "9862.706157207489\n", 1284 | "Files already downloaded and verified\n", 1285 | "340 0.8725\n", 1286 | "9892.180146932602\n", 1287 | "Files already downloaded and verified\n", 1288 | "341 0.8726\n", 1289 | "9921.79210782051\n", 1290 | "Files already downloaded and verified\n" 1291 | ] 1292 | } 1293 | ], 1294 | "source": [ 1295 | "train(400)" 1296 | ] 1297 | } 1298 | ], 1299 | "metadata": { 1300 | "kernelspec": { 1301 | "display_name": "Python 3 (ipykernel)", 1302 | "language": "python", 1303 | "name": "python3" 1304 | }, 1305 | "language_info": { 1306 | "codemirror_mode": { 1307 | "name": "ipython", 1308 | "version": 3 1309 | }, 1310 | "file_extension": ".py", 1311 | "mimetype": "text/x-python", 1312 | "name": "python", 1313 | "nbconvert_exporter": "python", 1314 | "pygments_lexer": "ipython3", 1315 | "version": "3.11.4" 1316 | } 1317 | }, 1318 | "nbformat": 4, 1319 | "nbformat_minor": 5 1320 | } 1321 | -------------------------------------------------------------------------------- /国科大-深度学习作业/手写数字识别/手写数字识别.docx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dxcf123/UCAS_DeepLearning_homework/c9a64b80b3377fbc590f8f64f495789eec4cb9ad/国科大-深度学习作业/手写数字识别/手写数字识别.docx -------------------------------------------------------------------------------- /国科大-深度学习作业/手写数字识别/手写数字识别.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "a21ac2aa", 6 | "metadata": {}, 7 | "source": [ 8 | "## 1 导入相关库" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 25, 14 | "id": "eb30e745", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import torch\n", 19 | "from torch import nn\n", 20 | "import torchvision\n", 21 | "from torch.utils.data import DataLoader\n", 22 | "from torch.nn import functional as F" 23 | ] 24 | }, 25 | { 26 | "cell_type": "markdown", 27 | "id": "e0bd19ea", 28 | "metadata": {}, 29 | "source": [ 30 | "## 2 获取数据集" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 26, 36 | "id": "bb34e1eb", 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "def get_dataset(path, batch_size=32, transform=None):\n", 41 | " \"\"\"\n", 42 | " 加载MNIST数据集并将其转换为DataLoader对象。\n", 43 | " :param path: 数据集路径\n", 44 | " :param batch_size: 批处理大小\n", 45 | " :param transform: 数据预处理\n", 46 | " :return: 训练集与测试集的DataLoader对象\n", 47 | " \"\"\"\n", 48 | " if transform is None:\n", 49 | " transform = torchvision.transforms.Compose([ # 对图像进行预处理\n", 50 | " 
torchvision.transforms.ToTensor(), # 将图片转换成张量\n", 51 | " torchvision.transforms.Normalize(mean=(0.5,), std=(0.5,)) # 对图像进行归一化处理\n", 52 | " ])\n", 53 | "\n", 54 | " # 训练集\n", 55 | " mnist_train = torchvision.datasets.MNIST( # 加载MNIST数据集,如果本地没有会自动下载\n", 56 | " root=path, train=True, transform=transform, download=True)\n", 57 | " # 测试集\n", 58 | " mnist_test = torchvision.datasets.MNIST(\n", 59 | " root=path, train=False, transform=transform, download=True)\n", 60 | "\n", 61 | " # 创建dataloader对象\n", 62 | " train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)\n", 63 | " test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False)\n", 64 | "\n", 65 | " return train_loader, test_loader" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 27, 71 | "id": "fc6685e8", 72 | "metadata": {}, 73 | "outputs": [ 74 | { 75 | "name": "stdout", 76 | "output_type": "stream", 77 | "text": [ 78 | "5\n" 79 | ] 80 | }, 81 | { 82 | "data": { 83 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABAElEQVR4nGNgGMyAWUhIqK5jvdSy/9/rGRgYGFhgEnJsVjYCwQwMDAxPJgV+vniQgYGBgREqZ7iXH8r6l/SV4dn7m8gmCt3++/fv37/Htn3/iMW+gDnZf/+e5WbQnoXNNXyMs/5GoQoxwVmf/n9kSGFiwAW49/11wynJoPzx4YIcRlyygR/+/i2XxCWru+vv32nSuGQFYv/83Y3b4p9/fzpAmSyoMnohpiwM1w5h06Q+5enfv39/bcMiJVF09+/fv39P+mFKiTtd/fv3799jgZiBJLT69t+/f/8eDuDEkDJf8+jv379/v7Ryo4qzMDAwMAQGMjBc3/y35wM2V1IfAABFF16Aa0wAOwAAAABJRU5ErkJggg==\n", 84 | "text/plain": [ 85 | "" 86 | ] 87 | }, 88 | "execution_count": 27, 89 | "metadata": {}, 90 | "output_type": "execute_result" 91 | } 92 | ], 93 | "source": [ 94 | "# 查看MNIST数据集\n", 95 | "mnist_train = torchvision.datasets.MNIST( # 加载MNIST数据集,如果本地没有会自动下载\n", 96 | " root='./data', train=True, download=True)\n", 97 | "print(mnist_train[0][1])\n", 98 | "mnist_train[0][0]" 99 | ] 100 | }, 101 | { 102 | "cell_type": "markdown", 103 | "id": "55bc3fc9", 104 | "metadata": {}, 105 | "source": [ 106 | "## 3 定义模型" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": 28, 112 | "id": "f557bafd", 113 | "metadata": {}, 114 | "outputs": [], 115 | "source": [ 116 | "class Model(nn.Module): # 构建卷积神经网络\n", 117 | " def __init__(self):\n", 118 | " super(Model, self).__init__()\n", 119 | " # 输入通道,输出通道,卷积核大小,步长,填充\n", 120 | " self.cov1 = nn.Conv2d(1, 15, 3, stride=1, padding=1)\n", 121 | " self.cov2 = nn.Conv2d(15, 45, 3, stride=1, padding=1)\n", 122 | " # 池化层 核大小2*2,步长2\n", 123 | " self.maxpool1 = nn.MaxPool2d(2, stride=2)\n", 124 | " self.maxpool2 = nn.MaxPool2d(2, stride=2)\n", 125 | " # 线性层 输入参数与前面的卷积与池化层输出通道数有关\n", 126 | " self.lin1 = nn.Linear(49*45, 10)\n", 127 | "\n", 128 | " def forward(self, x):\n", 129 | " x = self.cov1(x) # 第一层卷积 28→28\n", 130 | " x = self.maxpool1(x) # 第一层池化 28→14\n", 131 | " x = torch.relu(x) # 激活函数\n", 132 | " x = self.cov2(x) # 第二层卷积 14→14\n", 133 | " x = self.maxpool2(x) # 第二层池化 14→7\n", 134 | " x = torch.relu(x) # 激活函数\n", 135 | " x = x.view(x.size(0), -1) # 将特征展平 7*7→49\n", 136 | " x = self.lin1(x) # 全连接层 49→10\n", 137 | " return x" 138 | ] 139 | }, 140 | { 141 | "cell_type": "markdown", 142 | "id": "104c1a70", 143 | "metadata": {}, 144 | "source": [ 145 | "## 4 定义准确率函数" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": 29, 151 | "id": "91d1e0a9", 152 | "metadata": {}, 153 | "outputs": [], 154 | "source": [ 155 | "def acc_test(loader, model, device):\n", 156 | " \"\"\"\n", 157 | " 计算模型在测试集上的准确率。\n", 158 | " :param loader: 测试集的DataLoader对象\n", 159 | " :param model: 模型对象\n", 160 | " :param device: 设备对象\n", 161 | " :return: 
准确率\n", 162 | " \"\"\"\n", 163 | " model.eval() # 将模型设置为评估模式\n", 164 | " acc = 0 # 准确的个数\n", 165 | " all_ = 0 # 总个数\n", 166 | " with torch.no_grad(): # 不计算梯度\n", 167 | " for i, (x, y) in enumerate(loader): # 获取输入与输出\n", 168 | " x = x.to(device) # 将图片转换为一维张量\n", 169 | " y = y.to(device)\n", 170 | " pre = model(x) # 预测\n", 171 | " pre = torch.argmax(pre, dim=1) # 获取预测结果每行中的最大值的坐标\n", 172 | " all_ += len(pre) # 记录数据总数\n", 173 | " acc += (pre == y).sum().item() # 记录准确的个数\n", 174 | " return acc / all_ # 返回准确率" 175 | ] 176 | }, 177 | { 178 | "cell_type": "markdown", 179 | "id": "077fe5cd", 180 | "metadata": {}, 181 | "source": [ 182 | "## 5 定义训练函数" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": 30, 188 | "id": "cca28ad6", 189 | "metadata": {}, 190 | "outputs": [], 191 | "source": [ 192 | "def train(path, output_=10, batch_size=128, lr=0.01, device='cpu', epochs=1):\n", 193 | " \"\"\"\n", 194 | " 训练模型\n", 195 | " :param path: 数据存放路径\n", 196 | " :param output_: 输出层神经元个数\n", 197 | " :param lr: 学习率\n", 198 | " :param device: 训练设备\n", 199 | " :param epochs: 训练轮数\n", 200 | " :param batch_size 批量大小\n", 201 | " :return: 返回训练后的模型\n", 202 | " \"\"\"\n", 203 | " # 损失函数设置为交叉熵损失\n", 204 | " lossFuction = torch.nn.CrossEntropyLoss()\n", 205 | "\n", 206 | " # 创建一个卷积神经网络的对象\n", 207 | " model = Model()\n", 208 | "\n", 209 | " # 创建优化器\n", 210 | " optimizer = torch.optim.Adam(model.parameters(), lr=lr) # 使用Adam优化器\n", 211 | "\n", 212 | " # 获取数据\n", 213 | " train_loader, test_loader = get_dataset(path, batch_size=batch_size)\n", 214 | "\n", 215 | " # 将模型移动到设备上\n", 216 | " model.to(device)\n", 217 | "\n", 218 | " # 模型设置为训练模式\n", 219 | " model.train()\n", 220 | "\n", 221 | " # 训练模型\n", 222 | " for epoch in range(epochs):\n", 223 | " all_loss=[]\n", 224 | " acc_=[]\n", 225 | " for i, (x, y) in enumerate(train_loader): # 获取输入与输出\n", 226 | " x = x.to(device) # 将图片转换移动到设备上\n", 227 | " # 将输出数据转换为one_hot编码并转换为32位浮点数并移动到设备上\n", 228 | " y = torch.tensor(F.one_hot(y, num_classes=output_), dtype=torch.float32).to(device)\n", 229 | " optimizer.zero_grad() # 将优化器梯度置零\n", 230 | " pre = model(x) # 预测数据\n", 231 | " loss = lossFuction(pre, y) # 计算损失\n", 232 | " loss.backward() # 反向传播\n", 233 | " optimizer.step() # 梯度更新\n", 234 | " if (i + 1) % 10 == 0:\n", 235 | " all_loss.append(float(loss))\n", 236 | " with open('loss.txt','w',encoding='utf-8') as f:\n", 237 | " f.write(str(all_loss))\n", 238 | " acc=acc_test(test_loader, model, device)\n", 239 | " acc_.append(acc)\n", 240 | " with open('acc.txt','w',encoding='utf-8') as f:\n", 241 | " f.write(str(acc_))\n", 242 | " print('准确率: ',acc)\n", 243 | " model.train()\n", 244 | " return model\n" 245 | ] 246 | }, 247 | { 248 | "cell_type": "markdown", 249 | "id": "8b2041d1", 250 | "metadata": {}, 251 | "source": [ 252 | "## 6 训练" 253 | ] 254 | }, 255 | { 256 | "cell_type": "code", 257 | "execution_count": 32, 258 | "id": "9f41f2ef", 259 | "metadata": {}, 260 | "outputs": [ 261 | { 262 | "name": "stderr", 263 | "output_type": "stream", 264 | "text": [ 265 | "C:\\Users\\30535\\AppData\\Local\\Temp\\ipykernel_15240\\1565514881.py:37: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", 266 | " y = torch.tensor(F.one_hot(y, num_classes=output_), dtype=torch.float32).to(device)\n" 267 | ] 268 | }, 269 | { 270 | "name": "stdout", 271 | "output_type": "stream", 272 | "text": [ 273 | "准确率: 0.7956\n", 274 | "准确率: 0.8726\n", 
275 | "准确率: 0.9195\n", 276 | "准确率: 0.94\n", 277 | "准确率: 0.9494\n", 278 | "准确率: 0.9414\n", 279 | "准确率: 0.9564\n", 280 | "准确率: 0.9627\n", 281 | "准确率: 0.9638\n", 282 | "准确率: 0.9662\n", 283 | "准确率: 0.9662\n", 284 | "准确率: 0.9675\n", 285 | "准确率: 0.9655\n", 286 | "准确率: 0.9686\n", 287 | "准确率: 0.9646\n", 288 | "准确率: 0.9661\n", 289 | "准确率: 0.971\n", 290 | "准确率: 0.9706\n", 291 | "准确率: 0.9699\n", 292 | "准确率: 0.974\n", 293 | "准确率: 0.9739\n", 294 | "准确率: 0.9762\n", 295 | "准确率: 0.9782\n", 296 | "准确率: 0.9754\n", 297 | "准确率: 0.9758\n", 298 | "准确率: 0.9773\n", 299 | "准确率: 0.9766\n", 300 | "准确率: 0.974\n", 301 | "准确率: 0.9741\n", 302 | "准确率: 0.9795\n", 303 | "准确率: 0.9795\n", 304 | "准确率: 0.9824\n", 305 | "准确率: 0.9754\n", 306 | "准确率: 0.9772\n", 307 | "准确率: 0.9791\n", 308 | "准确率: 0.9803\n", 309 | "准确率: 0.9728\n", 310 | "准确率: 0.9781\n", 311 | "准确率: 0.9807\n", 312 | "准确率: 0.9811\n", 313 | "准确率: 0.9807\n", 314 | "准确率: 0.9783\n", 315 | "准确率: 0.9809\n", 316 | "准确率: 0.9801\n", 317 | "准确率: 0.9821\n", 318 | "准确率: 0.9811\n" 319 | ] 320 | } 321 | ], 322 | "source": [ 323 | "model = train('./data',device='cuda')" 324 | ] 325 | }, 326 | { 327 | "cell_type": "markdown", 328 | "id": "3125c37e", 329 | "metadata": {}, 330 | "source": [ 331 | "## 7 loss与准确度" 332 | ] 333 | }, 334 | { 335 | "cell_type": "code", 336 | "execution_count": 33, 337 | "id": "cfd3642e", 338 | "metadata": {}, 339 | "outputs": [], 340 | "source": [ 341 | "%matplotlib qt\n", 342 | "import matplotlib.pyplot as plt" 343 | ] 344 | }, 345 | { 346 | "cell_type": "code", 347 | "execution_count": 34, 348 | "id": "3c0a7e5e", 349 | "metadata": {}, 350 | "outputs": [], 351 | "source": [ 352 | "with open('loss.txt','r',encoding='utf-8') as f:\n", 353 | " data=f.read()\n", 354 | "data=eval(data)\n", 355 | "fig=plt.figure()\n", 356 | "plt.plot([i*10 for i in range(len(data))],data)\n", 357 | "plt.xlabel('batch_num')\n", 358 | "plt.ylabel('loss')\n", 359 | "fig.show()" 360 | ] 361 | }, 362 | { 363 | "cell_type": "code", 364 | "execution_count": 35, 365 | "id": "3e48eb07", 366 | "metadata": {}, 367 | "outputs": [], 368 | "source": [ 369 | "with open('acc.txt','r',encoding='utf-8') as f:\n", 370 | " data=f.read()\n", 371 | "data=eval(data)\n", 372 | "fig=plt.figure()\n", 373 | "plt.plot([i for i in range(len(data))],data)\n", 374 | "plt.xlabel('batch_num')\n", 375 | "plt.ylabel('acc')\n", 376 | "fig.show()" 377 | ] 378 | }, 379 | { 380 | "cell_type": "code", 381 | "execution_count": 8, 382 | "id": "118e29d0", 383 | "metadata": {}, 384 | "outputs": [ 385 | { 386 | "data": { 387 | "text/plain": [ 388 | "[]" 389 | ] 390 | }, 391 | "execution_count": 8, 392 | "metadata": {}, 393 | "output_type": "execute_result" 394 | } 395 | ], 396 | "source": [] 397 | } 398 | ], 399 | "metadata": { 400 | "kernelspec": { 401 | "display_name": "Python 3 (ipykernel)", 402 | "language": "python", 403 | "name": "python3" 404 | }, 405 | "language_info": { 406 | "codemirror_mode": { 407 | "name": "ipython", 408 | "version": 3 409 | }, 410 | "file_extension": ".py", 411 | "mimetype": "text/x-python", 412 | "name": "python", 413 | "nbconvert_exporter": "python", 414 | "pygments_lexer": "ipython3", 415 | "version": "3.11.4" 416 | } 417 | }, 418 | "nbformat": 4, 419 | "nbformat_minor": 5 420 | } 421 | -------------------------------------------------------------------------------- /国科大-深度学习作业/手写数字识别/手写数字识别.pptx: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/dxcf123/UCAS_DeepLearning_homework/c9a64b80b3377fbc590f8f64f495789eec4cb9ad/国科大-深度学习作业/手写数字识别/手写数字识别.pptx -------------------------------------------------------------------------------- /国科大-深度学习作业/机器翻译/bloom5-1.4b-半精度lora微调-机器翻译.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "5575fad4", 6 | "metadata": {}, 7 | "source": [ 8 | "# 源数据处理" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "286db6b0", 14 | "metadata": {}, 15 | "source": [ 16 | "## 1 导入相关包" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 1, 22 | "id": "71e4ad43", 23 | "metadata": {}, 24 | "outputs": [ 25 | { 26 | "name": "stdout", 27 | "output_type": "stream", 28 | "text": [ 29 | "bin D:\\anac\\Lib\\site-packages\\bitsandbytes\\libbitsandbytes_cuda118.dll\n" 30 | ] 31 | } 32 | ], 33 | "source": [ 34 | "import pandas as pd\n", 35 | "import random\n", 36 | "import os\n", 37 | "from datasets import load_dataset\n", 38 | "from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainer\n", 39 | "from peft import PeftModel\n", 40 | "from transformers import pipeline" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 4, 46 | "id": "308115ed", 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "en_path = r'H:\\datasets\\data\\翻译1\\test.en.txt'\n", 51 | "ch_path = r'H:\\datasets\\data\\翻译1\\test.ch.txt'\n", 52 | "csv_path=r'C:\\Users\\30535\\Desktop'" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 3, 58 | "id": "4ab20aec", 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "class TextToCsv:\n", 63 | " ## 将原始的中英平行语料处理成csv格式的数据集\n", 64 | " def __init__(self, en_path, ch_path,csv_path,text_pair_nums=30000):\n", 65 | " \"\"\"\n", 66 | " 初始化\n", 67 | " :param en_path: 英文数据路径\n", 68 | " :param ch_path: 中文数据路径\n", 69 | " :param csv_path: 文件保存路径\n", 70 | " :param text_pair_nums: 使用多少对数据\n", 71 | " \"\"\"\n", 72 | " self.en_path = en_path # 英文路径\n", 73 | " self.ch_path = ch_path # 中文路径\n", 74 | " self.text_pair_nums=text_pair_nums\n", 75 | " \n", 76 | " # 读取原始英文数据\n", 77 | " self.en_data = self.__read_ori_data(en_path)\n", 78 | " # 读取原始中文数据\n", 79 | " self.ch_data = self.__read_ori_data(ch_path)\n", 80 | " self.x=self.return_csv(csv_path)\n", 81 | "\n", 82 | " def __read_ori_data(self, path):\n", 83 | " \"\"\"\n", 84 | " 读取原始数据\n", 85 | " :param path: 数据路径\n", 86 | " :return: 返回一个列表,每个元素是一条数据\n", 87 | " \"\"\"\n", 88 | " with open(path, 'r', encoding='utf-8') as f:\n", 89 | " data = f.read().split('\\n')[:-1]\n", 90 | " self.text_pair_nums =self.text_pair_nums if self.text_pair_nums <=len(data) else len(data)\n", 91 | " data = data[:self.text_pair_nums] \n", 92 | " return data\n", 93 | " \n", 94 | " def return_csv(self,csv_path):\n", 95 | " \"\"\"\n", 96 | " 将源数据处理成csv文件\n", 97 | " :param csv_path: 文件保存路径\n", 98 | " \"\"\"\n", 99 | " data=[]\n", 100 | " # 遍历所有数据,中英文任一侧长度大于127的数据对抛弃\n", 101 | " for i in range(self.text_pair_nums):\n", 102 | " if len(self.en_data[i])>127 or len(self.ch_data[i])>127:\n", 103 | " continue\n", 104 | " # 英文→中文\n", 105 | " data.append([\n", 106 | " self.en_data[i],\n", 107 | " self.ch_data[i]]\n", 108 | " )\n", 109 | " # 中文→英文\n", 110 | " data.append([\n", 111 | " self.ch_data[i],\n", 112 | " self.en_data[i]]\n", 113 | " )\n", 114 | " random.shuffle(data) # 数据随机打乱\n", 115 | " csv_train=os.path.join(csv_path,'train.csv') # 训练集文件\n", 116 | " 
csv_test=os.path.join(csv_path,'test.csv') # 测试集文件\n", 117 | " dat=pd.DataFrame(data[:len(data)-500],columns=['src','tgt']) # 训练集\n", 118 | " dat2=pd.DataFrame(data[len(data)-500:],columns=['src','tgt']) # 测试集\n", 119 | " dat.to_csv(csv_train,index=False) # 转换为csv文件\n", 120 | " dat2.to_csv(csv_test,index=False)\n", 121 | " " 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": 5, 127 | "id": "23e24831", 128 | "metadata": { 129 | "scrolled": true 130 | }, 131 | "outputs": [ 132 | { 133 | "data": { 134 | "text/plain": [ 135 | "<__main__.TextToCsv at 0x176b22e8850>" 136 | ] 137 | }, 138 | "execution_count": 5, 139 | "metadata": {}, 140 | "output_type": "execute_result" 141 | } 142 | ], 143 | "source": [ 144 | "TextToCsv(en_path,ch_path,csv_path)" 145 | ] 146 | }, 147 | { 148 | "cell_type": "markdown", 149 | "id": "5cfcfa66", 150 | "metadata": {}, 151 | "source": [ 152 | "## 1 导入相关包" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": 2, 158 | "id": "3dd34940", 159 | "metadata": {}, 160 | "outputs": [], 161 | "source": [ 162 | "from datasets import load_dataset\n", 163 | "from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainer" 164 | ] 165 | }, 166 | { 167 | "cell_type": "markdown", 168 | "id": "77a540d7", 169 | "metadata": {}, 170 | "source": [ 171 | "## 2 加载数据集" 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": 3, 177 | "id": "64d6fc0b", 178 | "metadata": { 179 | "scrolled": false 180 | }, 181 | "outputs": [ 182 | { 183 | "data": { 184 | "text/plain": [ 185 | "[Dataset({\n", 186 | " features: ['src', 'tgt'],\n", 187 | " num_rows: 92644\n", 188 | " }),\n", 189 | " Dataset({\n", 190 | " features: ['src', 'tgt'],\n", 191 | " num_rows: 1000\n", 192 | " })]" 193 | ] 194 | }, 195 | "execution_count": 3, 196 | "metadata": {}, 197 | "output_type": "execute_result" 198 | } 199 | ], 200 | "source": [ 201 | "data_train=r'C:\\Users\\30535\\Desktop\\train.csv'\n", 202 | "data_test=r'C:\\Users\\30535\\Desktop\\test.csv'\n", 203 | "ds=load_dataset('csv',data_files={'train':data_train, 'test': data_test},\n", 204 | " split=['train', 'test'])\n", 205 | "ds" 206 | ] 207 | }, 208 | { 209 | "cell_type": "markdown", 210 | "id": "d63ae622", 211 | "metadata": {}, 212 | "source": [ 213 | "## 4 数据处理" 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": 4, 219 | "id": "1d71b691", 220 | "metadata": { 221 | "scrolled": true 222 | }, 223 | "outputs": [], 224 | "source": [ 225 | "model_path=r'H:\\models\\bloom-1b4-zh'\n", 226 | "tokenizer = AutoTokenizer.from_pretrained(model_path)" 227 | ] 228 | }, 229 | { 230 | "cell_type": "code", 231 | "execution_count": 5, 232 | "id": "2f48676e", 233 | "metadata": {}, 234 | "outputs": [], 235 | "source": [ 236 | "def process_func(examples):\n", 237 | " MAX_LENGTH = 150\n", 238 | " contents='机器翻译:\\n' + examples['src']\n", 239 | " # 对输入与label进行编码\n", 240 | " inputs=tokenizer(contents)\n", 241 | " labels = tokenizer(text_target=examples['tgt'] + tokenizer.eos_token)\n", 242 | " input_ids=inputs[\"input_ids\"]+labels[\"input_ids\"]\n", 243 | " attention_mask=inputs[\"attention_mask\"] + labels[\"attention_mask\"]\n", 244 | " labels = [-100] * len(inputs[\"input_ids\"]) + labels[\"input_ids\"]\n", 245 | " # 数据截断\n", 246 | " if len(input_ids) > MAX_LENGTH:\n", 247 | " input_ids = input_ids[:MAX_LENGTH]\n", 248 | " attention_mask = attention_mask[:MAX_LENGTH]\n", 249 | " labels = labels[:MAX_LENGTH]\n", 250 | " return {\n", 251 | " 
\"input_ids\": input_ids,\n", 252 | " \"attention_mask\": attention_mask,\n", 253 | " \"labels\": labels\n", 254 | " }" 255 | ] 256 | }, 257 | { 258 | "cell_type": "code", 259 | "execution_count": 6, 260 | "id": "eb8f1a88", 261 | "metadata": { 262 | "scrolled": false 263 | }, 264 | "outputs": [], 265 | "source": [ 266 | "tokenized_train=ds[0].map(process_func, remove_columns=ds[0].column_names)" 267 | ] 268 | }, 269 | { 270 | "cell_type": "code", 271 | "execution_count": 7, 272 | "id": "a90825de", 273 | "metadata": { 274 | "scrolled": false 275 | }, 276 | "outputs": [ 277 | { 278 | "data": { 279 | "text/plain": [ 280 | "Dataset({\n", 281 | " features: ['input_ids', 'attention_mask', 'labels'],\n", 282 | " num_rows: 55750\n", 283 | "})" 284 | ] 285 | }, 286 | "execution_count": 7, 287 | "metadata": {}, 288 | "output_type": "execute_result" 289 | } 290 | ], 291 | "source": [ 292 | "tokenized_train" 293 | ] 294 | }, 295 | { 296 | "cell_type": "markdown", 297 | "id": "ad20e4d9", 298 | "metadata": {}, 299 | "source": [ 300 | "## 5 创建模型" 301 | ] 302 | }, 303 | { 304 | "cell_type": "code", 305 | "execution_count": 7, 306 | "id": "8f5fa333", 307 | "metadata": {}, 308 | "outputs": [], 309 | "source": [ 310 | "model=AutoModelForCausalLM.from_pretrained(model_path)" 311 | ] 312 | }, 313 | { 314 | "cell_type": "code", 315 | "execution_count": 8, 316 | "id": "8fec97cb", 317 | "metadata": {}, 318 | "outputs": [], 319 | "source": [ 320 | "model = model.half()\n", 321 | "model=model.to('cuda')" 322 | ] 323 | }, 324 | { 325 | "cell_type": "code", 326 | "execution_count": 13, 327 | "id": "e1938044", 328 | "metadata": {}, 329 | "outputs": [ 330 | { 331 | "name": "stdout", 332 | "output_type": "stream", 333 | "text": [ 334 | " 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译\n" 335 | ] 336 | } 337 | ], 338 | "source": [ 339 | "x=\"机器翻译:\\n{}\".format(\"what is this。\").strip()\n", 340 | "ipt = tokenizer(x,return_tensors='pt').to('cuda')\n", 341 | "print(tokenizer.decode(model.generate(**ipt,max_length=256, do_sample=False)[0],skip_special_tokens=True)[len(x):])" 342 | ] 343 | }, 344 | { 345 | "cell_type": "markdown", 346 | "id": "571fadbf", 347 | "metadata": {}, 348 | "source": [ 349 | "## 6 使用Lora进行微调" 350 | ] 351 | }, 352 | { 353 | "cell_type": "code", 354 | "execution_count": 9, 355 | "id": "13bdda76", 356 | "metadata": {}, 357 | "outputs": [ 358 | { 359 | "data": { 360 | "text/plain": [ 361 | "LoraConfig(peft_type=, auto_mapping=None, base_model_name_or_path=None, revision=None, task_type=, inference_mode=False, r=8, target_modules=None, lora_alpha=8, lora_dropout=0.0, fan_in_fan_out=False, bias='none', modules_to_save=None, init_lora_weights=True, layers_to_transform=None, layers_pattern=None, rank_pattern={}, alpha_pattern={})" 362 | ] 363 | }, 364 | "execution_count": 9, 365 | "metadata": {}, 366 | "output_type": "execute_result" 367 | } 368 | ], 369 | "source": [ 370 | "# 6.1 创建配置文件\n", 371 | "from peft import LoraConfig,get_peft_model,TaskType\n", 372 | "comfig = LoraConfig(task_type=TaskType.CAUSAL_LM)\n", 373 | "comfig" 374 | ] 375 | }, 376 | { 377 | "cell_type": "code", 378 | "execution_count": 10, 379 | "id": "936b46cd", 380 | "metadata": {}, 381 | "outputs": [], 382 | "source": [ 383 | "# 6.2 创建模型\n", 384 | "model_lora 
= get_peft_model(model,comfig)" 385 | ] 386 | }, 387 | { 388 | "cell_type": "code", 389 | "execution_count": 11, 390 | "id": "45513150", 391 | "metadata": {}, 392 | "outputs": [], 393 | "source": [ 394 | "model_lora=model_lora.half()" 395 | ] 396 | }, 397 | { 398 | "cell_type": "code", 399 | "execution_count": 12, 400 | "id": "33be677b", 401 | "metadata": { 402 | "scrolled": true 403 | }, 404 | "outputs": [ 405 | { 406 | "name": "stdout", 407 | "output_type": "stream", 408 | "text": [ 409 | "trainable params: 1,572,864 || all params: 1,304,684,544 || trainable%: 0.120555118647899\n" 410 | ] 411 | } 412 | ], 413 | "source": [ 414 | "model_lora.print_trainable_parameters()" 415 | ] 416 | }, 417 | { 418 | "cell_type": "markdown", 419 | "id": "f7d689a5", 420 | "metadata": {}, 421 | "source": [ 422 | "## 7 配置训练参数" 423 | ] 424 | }, 425 | { 426 | "cell_type": "code", 427 | "execution_count": 13, 428 | "id": "77a2e300", 429 | "metadata": {}, 430 | "outputs": [ 431 | { 432 | "name": "stderr", 433 | "output_type": "stream", 434 | "text": [ 435 | "Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).\n" 436 | ] 437 | } 438 | ], 439 | "source": [ 440 | "import os\n", 441 | "os.environ[\"WANDB_DISABLED\"] = \"true\" # 防止日志输出到wandb.ai\n", 442 | "args= TrainingArguments(\n", 443 | " output_dir='./modelcheak/m2',\n", 444 | " logging_dir=r'./modelcheak/m2',\n", 445 | " per_device_train_batch_size=8, # batch_size\n", 446 | " gradient_accumulation_steps=4,\n", 447 | " logging_steps=20,\n", 448 | " optim=\"adafactor\", # 使用特定的优化器优化显存\n", 449 | " save_strategy='epoch', # 每一轮保存一个模型\n", 450 | " num_train_epochs=1,\n", 451 | " adam_epsilon=1e-4\n", 452 | ")" 453 | ] 454 | }, 455 | { 456 | "cell_type": "markdown", 457 | "id": "122adaa1", 458 | "metadata": {}, 459 | "source": [ 460 | "## 8 创建训练器" 461 | ] 462 | }, 463 | { 464 | "cell_type": "code", 465 | "execution_count": 14, 466 | "id": "43b7e698", 467 | "metadata": { 468 | "scrolled": true 469 | }, 470 | "outputs": [], 471 | "source": [ 472 | "trainr=Trainer(\n", 473 | " args=args,\n", 474 | " model=model_lora,\n", 475 | " train_dataset=tokenized_train,\n", 476 | " tokenizer=tokenizer,\n", 477 | " data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer)\n", 478 | ")" 479 | ] 480 | }, 481 | { 482 | "cell_type": "code", 483 | "execution_count": 15, 484 | "id": "0943fb9e", 485 | "metadata": { 486 | "scrolled": true 487 | }, 488 | "outputs": [ 489 | { 490 | "name": "stderr", 491 | "output_type": "stream", 492 | "text": [ 493 | "You're using a BloomTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n" 494 | ] 495 | }, 496 | { 497 | "data": { 498 | "text/html": [ 499 | "\n", 500 | "
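Before the training log below, a note on the collator wired into the `Trainer` above: `DataCollatorForSeq2Seq` pads each batch dynamically to its longest sequence, filling `input_ids`/`attention_mask` with the tokenizer's pad token and `labels` with `-100`, so padded positions are excluded from the loss (consistent with the `-100` masking already done in `process_func`). A small self-contained check — the toy feature dicts are made up, and the model path is the one loaded earlier in this notebook:

```python
from transformers import AutoTokenizer, DataCollatorForSeq2Seq

tokenizer = AutoTokenizer.from_pretrained(r'H:\models\bloom-1b4-zh')
collator = DataCollatorForSeq2Seq(tokenizer=tokenizer)

features = [
    {'input_ids': [1, 2, 3], 'attention_mask': [1, 1, 1], 'labels': [4, 5]},
    {'input_ids': [1, 2, 3, 4, 5], 'attention_mask': [1, 1, 1, 1, 1], 'labels': [6, 7, 8]},
]
batch = collator(features)
print(batch['input_ids'])   # shorter row padded with tokenizer.pad_token_id
print(batch['labels'])      # shorter row padded with -100, masked out of the loss
```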
\n", 501 | " \n", 502 | " \n", 503 | " [2895/2895 16:39, Epoch 0/1]\n", 504 | "
\n", 505 | " \n", 506 | " \n", 507 | " \n", 508 | " \n", 509 | " \n", 510 | " \n", 511 | " \n", 512 | " \n", 513 | " \n", 514 | " \n", 515 | " \n", 516 | " \n", 517 | " \n", 518 | " \n", 519 | " \n", 520 | " \n", 521 | " \n", 522 | " \n", 523 | " \n", 524 | " \n", 525 | " \n", 526 | " \n", 527 | " \n", 528 | " \n", 529 | " \n", 530 | " \n", 531 | " \n", 532 | " \n", 533 | " \n", 534 | " \n", 535 | " \n", 536 | " \n", 537 | " \n", 538 | " \n", 539 | " \n", 540 | " \n", 541 | " \n", 542 | " \n", 543 | " \n", 544 | " \n", 545 | " \n", 546 | " \n", 547 | " \n", 548 | " \n", 549 | " \n", 550 | " \n", 551 | " \n", 552 | " \n", 553 | " \n", 554 | " \n", 555 | " \n", 556 | " \n", 557 | " \n", 558 | " \n", 559 | " \n", 560 | " \n", 561 | " \n", 562 | " \n", 563 | " \n", 564 | " \n", 565 | " \n", 566 | " \n", 567 | " \n", 568 | " \n", 569 | " \n", 570 | " \n", 571 | " \n", 572 | " \n", 573 | " \n", 574 | " \n", 575 | " \n", 576 | " \n", 577 | " \n", 578 | " \n", 579 | " \n", 580 | " \n", 581 | " \n", 582 | " \n", 583 | " \n", 584 | " \n", 585 | " \n", 586 | " \n", 587 | " \n", 588 | " \n", 589 | " \n", 590 | " \n", 591 | " \n", 592 | " \n", 593 | " \n", 594 | " \n", 595 | " \n", 596 | " \n", 597 | " \n", 598 | " \n", 599 | " \n", 600 | " \n", 601 | " \n", 602 | " \n", 603 | " \n", 604 | " \n", 605 | " \n", 606 | " \n", 607 | " \n", 608 | " \n", 609 | " \n", 610 | " \n", 611 | " \n", 612 | " \n", 613 | " \n", 614 | " \n", 615 | " \n", 616 | " \n", 617 | " \n", 618 | " \n", 619 | " \n", 620 | " \n", 621 | " \n", 622 | " \n", 623 | " \n", 624 | " \n", 625 | " \n", 626 | " \n", 627 | " \n", 628 | " \n", 629 | " \n", 630 | " \n", 631 | " \n", 632 | " \n", 633 | " \n", 634 | " \n", 635 | " \n", 636 | " \n", 637 | " \n", 638 | " \n", 639 | " \n", 640 | " \n", 641 | " \n", 642 | " \n", 643 | " \n", 644 | " \n", 645 | " \n", 646 | " \n", 647 | " \n", 648 | " \n", 649 | " \n", 650 | " \n", 651 | " \n", 652 | " \n", 653 | " \n", 654 | " \n", 655 | " \n", 656 | " \n", 657 | " \n", 658 | " \n", 659 | " \n", 660 | " \n", 661 | " \n", 662 | " \n", 663 | " \n", 664 | " \n", 665 | " \n", 666 | " \n", 667 | " \n", 668 | " \n", 669 | " \n", 670 | " \n", 671 | " \n", 672 | " \n", 673 | " \n", 674 | " \n", 675 | " \n", 676 | " \n", 677 | " \n", 678 | " \n", 679 | " \n", 680 | " \n", 681 | " \n", 682 | " \n", 683 | " \n", 684 | " \n", 685 | " \n", 686 | " \n", 687 | " \n", 688 | " \n", 689 | " \n", 690 | " \n", 691 | " \n", 692 | " \n", 693 | " \n", 694 | " \n", 695 | " \n", 696 | " \n", 697 | " \n", 698 | " \n", 699 | " \n", 700 | " \n", 701 | " \n", 702 | " \n", 703 | " \n", 704 | " \n", 705 | " \n", 706 | " \n", 707 | " \n", 708 | " \n", 709 | " \n", 710 | " \n", 711 | " \n", 712 | " \n", 713 | " \n", 714 | " \n", 715 | " \n", 716 | " \n", 717 | " \n", 718 | " \n", 719 | " \n", 720 | " \n", 721 | " \n", 722 | " \n", 723 | " \n", 724 | " \n", 725 | " \n", 726 | " \n", 727 | " \n", 728 | " \n", 729 | " \n", 730 | " \n", 731 | " \n", 732 | " \n", 733 | " \n", 734 | " \n", 735 | " \n", 736 | " \n", 737 | " \n", 738 | " \n", 739 | " \n", 740 | " \n", 741 | " \n", 742 | " \n", 743 | " \n", 744 | " \n", 745 | " \n", 746 | " \n", 747 | " \n", 748 | " \n", 749 | " \n", 750 | " \n", 751 | " \n", 752 | " \n", 753 | " \n", 754 | " \n", 755 | " \n", 756 | " \n", 757 | " \n", 758 | " \n", 759 | " \n", 760 | " \n", 761 | " \n", 762 | " \n", 763 | " \n", 764 | " \n", 765 | " \n", 766 | " \n", 767 | " \n", 768 | " \n", 769 | " \n", 770 | " \n", 771 | " \n", 772 | " \n", 773 | " \n", 774 | " \n", 775 | " \n", 776 | " \n", 777 | " \n", 
778 | " \n", 779 | " \n", 780 | " \n", 781 | " \n", 782 | " \n", 783 | " \n", 784 | " \n", 785 | " \n", 786 | " \n", 787 | " \n", 788 | " \n", 789 | " \n", 790 | " \n", 791 | " \n", 792 | " \n", 793 | " \n", 794 | " \n", 795 | " \n", 796 | " \n", 797 | " \n", 798 | " \n", 799 | " \n", 800 | " \n", 801 | " \n", 802 | " \n", 803 | " \n", 804 | " \n", 805 | " \n", 806 | " \n", 807 | " \n", 808 | " \n", 809 | " \n", 810 | " \n", 811 | " \n", 812 | " \n", 813 | " \n", 814 | " \n", 815 | " \n", 816 | " \n", 817 | " \n", 818 | " \n", 819 | " \n", 820 | " \n", 821 | " \n", 822 | " \n", 823 | " \n", 824 | " \n", 825 | " \n", 826 | " \n", 827 | " \n", 828 | " \n", 829 | " \n", 830 | " \n", 831 | " \n", 832 | " \n", 833 | " \n", 834 | " \n", 835 | " \n", 836 | " \n", 837 | " \n", 838 | " \n", 839 | " \n", 840 | " \n", 841 | " \n", 842 | " \n", 843 | " \n", 844 | " \n", 845 | " \n", 846 | " \n", 847 | " \n", 848 | " \n", 849 | " \n", 850 | " \n", 851 | " \n", 852 | " \n", 853 | " \n", 854 | " \n", 855 | " \n", 856 | " \n", 857 | " \n", 858 | " \n", 859 | " \n", 860 | " \n", 861 | " \n", 862 | " \n", 863 | " \n", 864 | " \n", 865 | " \n", 866 | " \n", 867 | " \n", 868 | " \n", 869 | " \n", 870 | " \n", 871 | " \n", 872 | " \n", 873 | " \n", 874 | " \n", 875 | " \n", 876 | " \n", 877 | " \n", 878 | " \n", 879 | " \n", 880 | " \n", 881 | " \n", 882 | " \n", 883 | " \n", 884 | " \n", 885 | " \n", 886 | " \n", 887 | " \n", 888 | " \n", 889 | " \n", 890 | " \n", 891 | " \n", 892 | " \n", 893 | " \n", 894 | " \n", 895 | " \n", 896 | " \n", 897 | " \n", 898 | " \n", 899 | " \n", 900 | " \n", 901 | " \n", 902 | " \n", 903 | " \n", 904 | " \n", 905 | " \n", 906 | " \n", 907 | " \n", 908 | " \n", 909 | " \n", 910 | " \n", 911 | " \n", 912 | " \n", 913 | " \n", 914 | " \n", 915 | " \n", 916 | " \n", 917 | " \n", 918 | " \n", 919 | " \n", 920 | " \n", 921 | " \n", 922 | " \n", 923 | " \n", 924 | " \n", 925 | " \n", 926 | " \n", 927 | " \n", 928 | " \n", 929 | " \n", 930 | " \n", 931 | " \n", 932 | " \n", 933 | " \n", 934 | " \n", 935 | " \n", 936 | " \n", 937 | " \n", 938 | " \n", 939 | " \n", 940 | " \n", 941 | " \n", 942 | " \n", 943 | " \n", 944 | " \n", 945 | " \n", 946 | " \n", 947 | " \n", 948 | " \n", 949 | " \n", 950 | " \n", 951 | " \n", 952 | " \n", 953 | " \n", 954 | " \n", 955 | " \n", 956 | " \n", 957 | " \n", 958 | " \n", 959 | " \n", 960 | " \n", 961 | " \n", 962 | " \n", 963 | " \n", 964 | " \n", 965 | " \n", 966 | " \n", 967 | " \n", 968 | " \n", 969 | " \n", 970 | " \n", 971 | " \n", 972 | " \n", 973 | " \n", 974 | " \n", 975 | " \n", 976 | " \n", 977 | " \n", 978 | " \n", 979 | " \n", 980 | " \n", 981 | " \n", 982 | " \n", 983 | " \n", 984 | " \n", 985 | " \n", 986 | " \n", 987 | " \n", 988 | " \n", 989 | " \n", 990 | " \n", 991 | " \n", 992 | " \n", 993 | " \n", 994 | " \n", 995 | " \n", 996 | " \n", 997 | " \n", 998 | " \n", 999 | " \n", 1000 | " \n", 1001 | " \n", 1002 | " \n", 1003 | " \n", 1004 | " \n", 1005 | " \n", 1006 | " \n", 1007 | " \n", 1008 | " \n", 1009 | " \n", 1010 | " \n", 1011 | " \n", 1012 | " \n", 1013 | " \n", 1014 | " \n", 1015 | " \n", 1016 | " \n", 1017 | " \n", 1018 | " \n", 1019 | " \n", 1020 | " \n", 1021 | " \n", 1022 | " \n", 1023 | " \n", 1024 | " \n", 1025 | " \n", 1026 | " \n", 1027 | " \n", 1028 | " \n", 1029 | " \n", 1030 | " \n", 1031 | " \n", 1032 | " \n", 1033 | " \n", 1034 | " \n", 1035 | " \n", 1036 | " \n", 1037 | " \n", 1038 | " \n", 1039 | " \n", 1040 | " \n", 1041 | " \n", 1042 | " \n", 1043 | " \n", 1044 | " \n", 1045 | " \n", 1046 | " \n", 1047 | " 
\n", 1048 | " \n", 1049 | " \n", 1050 | " \n", 1051 | " \n", 1052 | " \n", 1053 | " \n", 1054 | " \n", 1055 | " \n", 1056 | " \n", 1057 | " \n", 1058 | " \n", 1059 | " \n", 1060 | " \n", 1061 | " \n", 1062 | " \n", 1063 | " \n", 1064 | " \n", 1065 | " \n", 1066 | " \n", 1067 | " \n", 1068 | " \n", 1069 | " \n", 1070 | " \n", 1071 | " \n", 1072 | " \n", 1073 | " \n", 1074 | " \n", 1075 | " \n", 1076 | " \n", 1077 | " \n", 1078 | " \n", 1079 | " \n", 1080 | " \n", 1081 | " \n", 1082 | " \n", 1083 | " \n", 1084 | " \n", 1085 | " \n", 1086 | " \n", 1087 | " \n", 1088 | " \n", 1089 | " \n", 1090 | "
StepTraining Loss
204.216200
403.881300
603.435600
803.148200
1002.938600
1202.803400
1402.758000
1602.715000
1802.743700
2002.581500
2202.601400
2402.578200
2602.614400
2802.578600
3002.507100
3202.524400
3402.486600
3602.466500
3802.434400
4002.426700
4202.447100
4402.395600
4602.487500
4802.446500
5002.366600
5202.352500
5402.341800
5602.360900
5802.403900
6002.392700
6202.388200
6402.343200
6602.329800
6802.360600
7002.359100
7202.337200
7402.377400
7602.315700
7802.336100
8002.356400
8202.331200
8402.363500
8602.288700
8802.278200
9002.301600
9202.299200
9402.293300
9602.316700
9802.314400
10002.347900
10202.292300
10402.304800
10602.254600
10802.329300
11002.275200
11202.264700
11402.264100
11602.283500
11802.285400
12002.265200
12202.260300
12402.259500
12602.217500
12802.258600
13002.319800
13202.286100
13402.264600
13602.311700
13802.246600
14002.272400
14202.234700
14402.275300
14602.249600
14802.289000
15002.308300
15202.233400
15402.269700
15602.284000
15802.267100
16002.292000
16202.233200
16402.301200
16602.256600
16802.203100
17002.253200
17202.260900
17402.235300
17602.270900
17802.221100
18002.230900
18202.316900
18402.290100
18602.209900
18802.367500
19002.231800
19202.187900
19402.260100
19602.199900
19802.265300
20002.269600
20202.235300
20402.205000
20602.322900
20802.300000
21002.267600
21202.232900
21402.234700
21602.286400
21802.350300
22002.197800
22202.219000
22402.259900
22602.250600
22802.211000
23002.250400
23202.274900
23402.263600
23602.160700
23802.165600
24002.266600
24202.197600
24402.293700
24602.318400
24802.275400
25002.263800
25202.203200
25402.271400
25602.201900
25802.170800
26002.234300
26202.189400
26402.245800
26602.235400
26802.175000
27002.260800
27202.236400
27402.247900
27602.278200
27802.218100
28002.250500
28202.270200
28402.209600
28602.138600
28802.176500

" 1091 | ], 1092 | "text/plain": [ 1093 | "" 1094 | ] 1095 | }, 1096 | "metadata": {}, 1097 | "output_type": "display_data" 1098 | }, 1099 | { 1100 | "data": { 1101 | "text/plain": [ 1102 | "TrainOutput(global_step=2895, training_loss=2.3527866607297065, metrics={'train_runtime': 1000.5068, 'train_samples_per_second': 92.597, 'train_steps_per_second': 2.894, 'total_flos': 3.09147635810304e+16, 'train_loss': 2.3527866607297065, 'epoch': 1.0})" 1103 | ] 1104 | }, 1105 | "execution_count": 15, 1106 | "metadata": {}, 1107 | "output_type": "execute_result" 1108 | } 1109 | ], 1110 | "source": [ 1111 | "trainr.train()" 1112 | ] 1113 | }, 1114 | { 1115 | "cell_type": "markdown", 1116 | "id": "7ae220ef", 1117 | "metadata": {}, 1118 | "source": [ 1119 | "## 9 权重合并与" 1120 | ] 1121 | }, 1122 | { 1123 | "cell_type": "code", 1124 | "execution_count": 16, 1125 | "id": "5aca0d3b", 1126 | "metadata": {}, 1127 | "outputs": [], 1128 | "source": [ 1129 | "from peft import PeftModel\n", 1130 | "# model_id 是checkpoint那个路径\n", 1131 | "prft_model=PeftModel.from_pretrained(model=model,model_id=r\"C:\\Users\\30535\\Desktop\\CodeProgram\\Python\\deepstudy\\code2\\使用Transformer进行中英文翻译\\modelcheak\\m2\\checkpoint-2895\")\n", 1132 | "# 权重合并\n", 1133 | "merge_model=prft_model.merge_and_unload()" 1134 | ] 1135 | }, 1136 | { 1137 | "cell_type": "code", 1138 | "execution_count": 7, 1139 | "id": "59fc687a", 1140 | "metadata": {}, 1141 | "outputs": [], 1142 | "source": [ 1143 | "# 模型保存\n", 1144 | "merge_model.save_pretrained('./modelcheak/trans11')" 1145 | ] 1146 | }, 1147 | { 1148 | "cell_type": "code", 1149 | "execution_count": 14, 1150 | "id": "876ab7dc", 1151 | "metadata": {}, 1152 | "outputs": [ 1153 | { 1154 | "name": "stdout", 1155 | "output_type": "stream", 1156 | "text": [ 1157 | "这是什么?\n" 1158 | ] 1159 | } 1160 | ], 1161 | "source": [ 1162 | "x=\"机器翻译:\\n{}\".format(\"what is this。\").strip()\n", 1163 | "ipt = tokenizer(x,return_tensors='pt').to('cuda')\n", 1164 | "print(tokenizer.decode(merge_model.generate(**ipt,max_length=256, do_sample=False)[0],skip_special_tokens=True)[len(x):])" 1165 | ] 1166 | }, 1167 | { 1168 | "cell_type": "code", 1169 | "execution_count": 19, 1170 | "id": "ae2b6ec1", 1171 | "metadata": {}, 1172 | "outputs": [ 1173 | { 1174 | "name": "stdout", 1175 | "output_type": "stream", 1176 | "text": [ 1177 | "What is this?\n" 1178 | ] 1179 | } 1180 | ], 1181 | "source": [ 1182 | "x=\"机器翻译:\\n{}\".format(\"这又是什么呢?\").strip()\n", 1183 | "ipt = tokenizer(x,return_tensors='pt').to('cuda')\n", 1184 | "print(tokenizer.decode(merge_model.generate(**ipt,max_length=256, do_sample=False)[0],skip_special_tokens=True)[len(x):])" 1185 | ] 1186 | }, 1187 | { 1188 | "cell_type": "code", 1189 | "execution_count": 17, 1190 | "id": "ce272f6e", 1191 | "metadata": {}, 1192 | "outputs": [ 1193 | { 1194 | "name": "stdout", 1195 | "output_type": "stream", 1196 | "text": [ 1197 | "0.0\n", 1198 | "0.04\n", 1199 | "0.08\n", 1200 | "时间 18.47494339942932\n", 1201 | "15.121825586870461\n" 1202 | ] 1203 | } 1204 | ], 1205 | "source": [ 1206 | "import re\n", 1207 | "import sacrebleu\n", 1208 | "def is_english_sentence(sentence):\n", 1209 | " # 使用正则表达式检查句子中是否包含英文字母\n", 1210 | " english_pattern = re.compile(r'[a-zA-Z]')\n", 1211 | " match = english_pattern.search(sentence)\n", 1212 | " \n", 1213 | " if match:\n", 1214 | " return True\n", 1215 | " else:\n", 1216 | " return False\n", 1217 | "from nltk.translate.bleu_score import sentence_bleu\n", 1218 | "from nltk.translate.bleu_score import SmoothingFunction\n", 1219 | "\n", 1220 | 
"smooth = SmoothingFunction().method1\n", 1221 | "bleu_scores=[]\n", 1222 | "m1,m2=[],[]\n", 1223 | "m3,m4=[],[]\n", 1224 | "import time\n", 1225 | "t=time.time()\n", 1226 | "for i in range(100):\n", 1227 | " if i%40==0:\n", 1228 | " print(i/len(ds[1]['src']))\n", 1229 | " x=\"机器翻译:\\n{}\".format(ds[1]['src'][i]).strip()\n", 1230 | " ipt = tokenizer(x,return_tensors='pt').to('cuda')\n", 1231 | " y=tokenizer.decode(merge_model.generate(**ipt,max_length=150, do_sample=False)[0],skip_special_tokens=True)[len(x):]\n", 1232 | " if is_english_sentence(ds[1]['tgt'][i]):\n", 1233 | " m1.append(ds[1]['tgt'][i])\n", 1234 | " m2.append([y])\n", 1235 | " else:\n", 1236 | " m3.append(list(ds[1]['tgt'][i][:-1]))\n", 1237 | " m4.append([list(y)[:-1]])\n", 1238 | "print('时间',time.time()-t)\n", 1239 | "smooth = SmoothingFunction().method1\n", 1240 | "b1=[sacrebleu.sentence_bleu(candidate, refs).score for candidate, refs in zip(m1, m2)]\n", 1241 | "for i in range(len(m4)):\n", 1242 | " b2 = sentence_bleu(m4[i], m3[i], weights=(0.25, 0.25, 0.25, 0.25), smoothing_function=smooth)*100\n", 1243 | " b1.append(b2)\n", 1244 | "print(sum(b1)/100)" 1245 | ] 1246 | }, 1247 | { 1248 | "cell_type": "code", 1249 | "execution_count": null, 1250 | "id": "8162c152", 1251 | "metadata": {}, 1252 | "outputs": [], 1253 | "source": [] 1254 | }, 1255 | { 1256 | "cell_type": "markdown", 1257 | "id": "9a092afd", 1258 | "metadata": {}, 1259 | "source": [ 1260 | "## 9 模型推理" 1261 | ] 1262 | }, 1263 | { 1264 | "cell_type": "code", 1265 | "execution_count": 32, 1266 | "id": "df9ad78c", 1267 | "metadata": {}, 1268 | "outputs": [], 1269 | "source": [ 1270 | "from transformers import pipeline" 1271 | ] 1272 | }, 1273 | { 1274 | "cell_type": "code", 1275 | "execution_count": 33, 1276 | "id": "7145468b", 1277 | "metadata": {}, 1278 | "outputs": [ 1279 | { 1280 | "name": "stderr", 1281 | "output_type": "stream", 1282 | "text": [ 1283 | "The model 'BloomForCausalLM' is not supported for text2text-generation. 
Supported models are ['BartForConditionalGeneration', 'BigBirdPegasusForConditionalGeneration', 'BlenderbotForConditionalGeneration', 'BlenderbotSmallForConditionalGeneration', 'EncoderDecoderModel', 'FSMTForConditionalGeneration', 'GPTSanJapaneseForConditionalGeneration', 'LEDForConditionalGeneration', 'LongT5ForConditionalGeneration', 'M2M100ForConditionalGeneration', 'MarianMTModel', 'MBartForConditionalGeneration', 'MT5ForConditionalGeneration', 'MvpForConditionalGeneration', 'NllbMoeForConditionalGeneration', 'PegasusForConditionalGeneration', 'PegasusXForConditionalGeneration', 'PLBartForConditionalGeneration', 'ProphetNetForConditionalGeneration', 'SwitchTransformersForConditionalGeneration', 'T5ForConditionalGeneration', 'UMT5ForConditionalGeneration', 'XLMProphetNetForConditionalGeneration'].\n" 1284 | ] 1285 | } 1286 | ], 1287 | "source": [ 1288 | "pipe=pipeline('text2text-generation',model=merge_model,tokenizer=tokenizer,device=0)" 1289 | ] 1290 | }, 1291 | { 1292 | "cell_type": "code", 1293 | "execution_count": 35, 1294 | "id": "89d02ec1", 1295 | "metadata": { 1296 | "scrolled": true 1297 | }, 1298 | "outputs": [ 1299 | { 1300 | "data": { 1301 | "text/plain": [ 1302 | "[{'generated_text': '机器翻译:\\n我有一个苹果I have a Apple'}]" 1303 | ] 1304 | }, 1305 | "execution_count": 35, 1306 | "metadata": {}, 1307 | "output_type": "execute_result" 1308 | } 1309 | ], 1310 | "source": [ 1311 | "pipe('机器翻译:\\n'+'我有一个苹果',max_length=30,do_sample=False)" 1312 | ] 1313 | } 1314 | ], 1315 | "metadata": { 1316 | "kernelspec": { 1317 | "display_name": "Python 3 (ipykernel)", 1318 | "language": "python", 1319 | "name": "python3" 1320 | }, 1321 | "language_info": { 1322 | "codemirror_mode": { 1323 | "name": "ipython", 1324 | "version": 3 1325 | }, 1326 | "file_extension": ".py", 1327 | "mimetype": "text/x-python", 1328 | "name": "python", 1329 | "nbconvert_exporter": "python", 1330 | "pygments_lexer": "ipython3", 1331 | "version": "3.11.4" 1332 | } 1333 | }, 1334 | "nbformat": 4, 1335 | "nbformat_minor": 5 1336 | } 1337 | -------------------------------------------------------------------------------- /国科大-深度学习作业/机器翻译/bloom5-2.5b-半精度lora微调-机器翻译.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "5575fad4", 6 | "metadata": {}, 7 | "source": [ 8 | "# 源数据处理" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "286db6b0", 14 | "metadata": {}, 15 | "source": [ 16 | "## 1 导入相关包" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 1, 22 | "id": "71e4ad43", 23 | "metadata": {}, 24 | "outputs": [ 25 | { 26 | "name": "stdout", 27 | "output_type": "stream", 28 | "text": [ 29 | "bin D:\\anac\\Lib\\site-packages\\bitsandbytes\\libbitsandbytes_cuda118.dll\n" 30 | ] 31 | } 32 | ], 33 | "source": [ 34 | "import pandas as pd\n", 35 | "import random\n", 36 | "import os\n", 37 | "from datasets import load_dataset\n", 38 | "from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainer\n", 39 | "from peft import PeftModel\n", 40 | "from transformers import pipeline" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 5, 46 | "id": "308115ed", 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "en_path = r'H:\\datasets\\data\\翻译1\\test.en.txt'\n", 51 | "ch_path = r'H:\\datasets\\data\\翻译1\\test.ch.txt'\n", 52 | "csv_path=r'C:\\Users\\30535\\Desktop'" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | 
"execution_count": 6, 58 | "id": "4ab20aec", 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "class TextToCsv:\n", 63 | " ## 定义tokenizer,对原始数据进行处理\n", 64 | " def __init__(self, en_path, ch_path,csv_path,text_pair_nums=50000):\n", 65 | " \"\"\"\n", 66 | " 初始化\n", 67 | " :param en_path: 英文数据路径\n", 68 | " :param ch_path: 中文数据路径\n", 69 | " :csv_path 文件保存路径\n", 70 | " :text_pair_nums: 使用多少对数据\n", 71 | " \"\"\"\n", 72 | " self.en_path = en_path # 英文路径\n", 73 | " self.ch_path = ch_path # 中文路径\n", 74 | " self.text_pair_nums=text_pair_nums\n", 75 | " \n", 76 | " # 读取原始英文数据\n", 77 | " self.en_data = self.__read_ori_data(en_path)\n", 78 | " # 读取原始中文数据\n", 79 | " self.ch_data = self.__read_ori_data(ch_path)\n", 80 | " self.x=self.return_csv(csv_path)\n", 81 | "\n", 82 | " def __read_ori_data(self, path):\n", 83 | " \"\"\"\n", 84 | " 读取原始数据\n", 85 | " :param path: 数据路径\n", 86 | " :return: 返回一个列表,每个元素是一条数据\n", 87 | " \"\"\"\n", 88 | " with open(path, 'r', encoding='utf-8') as f:\n", 89 | " data = f.read().split('\\n')[:-1]\n", 90 | " self.text_pair_nums =self.text_pair_nums if self.text_pair_nums <=len(data) else len(data)\n", 91 | " data = data[:self.text_pair_nums] \n", 92 | " return data\n", 93 | " \n", 94 | " def return_csv(self,csv_path):\n", 95 | " \"\"\"\n", 96 | " 将源数据处理成csv文件\n", 97 | " :csv_path 文件保存路径\n", 98 | " \"\"\"\n", 99 | " data=[]\n", 100 | " # 遍历所有数据,长度大于127的数据抛弃\n", 101 | " for i in range(self.text_pair_nums):\n", 102 | " if len(self.en_data[i])>127 or len(self.en_data[i])>127:\n", 103 | " continue\n", 104 | " # 英文→中文\n", 105 | " data.append([\n", 106 | " self.en_data[i],\n", 107 | " self.ch_data[i]]\n", 108 | " )\n", 109 | " # 中文→英文\n", 110 | " data.append([\n", 111 | " self.ch_data[i],\n", 112 | " self.en_data[i]]\n", 113 | " )\n", 114 | " random.shuffle(data) # 数据随机打乱\n", 115 | " csv_train=os.path.join(csv_path,'train.csv') # 训练集文件\n", 116 | " csv_test=os.path.join(csv_path,'test.csv') # 测试集文件\n", 117 | " dat=pd.DataFrame(data[:len(data)-1000],columns=['src','tgt']) # 训练集\n", 118 | " dat2=pd.DataFrame(data[len(data)-1000:],columns=['src','tgt']) # 测试集\n", 119 | " dat.to_csv(csv_train,index=False) # 转换为csv文件\n", 120 | " dat2.to_csv(csv_test,index=False)\n", 121 | " " 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": 7, 127 | "id": "23e24831", 128 | "metadata": { 129 | "scrolled": true 130 | }, 131 | "outputs": [ 132 | { 133 | "data": { 134 | "text/plain": [ 135 | "<__main__.TextToCsv at 0x13d8a27ab10>" 136 | ] 137 | }, 138 | "execution_count": 7, 139 | "metadata": {}, 140 | "output_type": "execute_result" 141 | } 142 | ], 143 | "source": [ 144 | "TextToCsv(en_path,ch_path,csv_path)" 145 | ] 146 | }, 147 | { 148 | "cell_type": "markdown", 149 | "id": "5cfcfa66", 150 | "metadata": {}, 151 | "source": [ 152 | "## 1 导入相关包" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": 2, 158 | "id": "3dd34940", 159 | "metadata": {}, 160 | "outputs": [], 161 | "source": [ 162 | "from datasets import load_dataset\n", 163 | "from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainer" 164 | ] 165 | }, 166 | { 167 | "cell_type": "markdown", 168 | "id": "77a540d7", 169 | "metadata": {}, 170 | "source": [ 171 | "## 2 加载数据集" 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": 3, 177 | "id": "64d6fc0b", 178 | "metadata": { 179 | "scrolled": true 180 | }, 181 | "outputs": [ 182 | { 183 | "data": { 184 | "text/plain": [ 185 | "[Dataset({\n", 186 | " features: 
{ "cell_type": "markdown", "id": "5cfcfa66", "metadata": {}, "source": [ "## 1 导入相关包" ] }, { "cell_type": "code", "execution_count": 2, "id": "3dd34940", "metadata": {}, "outputs": [], "source": [ "from datasets import load_dataset\n", "from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainer" ] }, { "cell_type": "markdown", "id": "77a540d7", "metadata": {}, "source": [ "## 2 加载数据集" ] }, { "cell_type": "code", "execution_count": 3, "id": "64d6fc0b", "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "[Dataset({\n", "    features: ['src', 'tgt'],\n", "    num_rows: 92644\n", " }),\n", " Dataset({\n", "    features: ['src', 'tgt'],\n", "    num_rows: 1000\n", " })]" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data_train=r'C:\Users\30535\Desktop\train.csv'\n", "data_test=r'C:\Users\30535\Desktop\test.csv'\n", "ds=load_dataset('csv',data_files={'train':data_train, 'test': data_test},\n", "                split=['train', 'test'])\n", "ds" ] }, { "cell_type": "markdown", "id": "d63ae622", "metadata": {}, "source": [ "## 4 数据处理" ] }, { "cell_type": "code", "execution_count": 4, "id": "1d71b691", "metadata": { "scrolled": true }, "outputs": [], "source": [ "model_path=r'H:\models\bloom-2b5-zh'\n", "tokenizer = AutoTokenizer.from_pretrained(model_path)" ] }, { "cell_type": "code", "execution_count": 11, "id": "2f48676e", "metadata": {}, "outputs": [], "source": [ "def process_func(examples):\n", "    MAX_LENGTH = 150\n", "    contents='机器翻译:\\n' + examples['src']\n", "    # 对输入与label进行编码\n", "    inputs=tokenizer(contents)\n", "    labels = tokenizer(text_target=examples['tgt'] + tokenizer.eos_token)\n", "    input_ids=inputs[\"input_ids\"]+labels[\"input_ids\"]\n", "    attention_mask=inputs[\"attention_mask\"] + labels[\"attention_mask\"]\n", "    labels = [-100] * len(inputs[\"input_ids\"]) + labels[\"input_ids\"]\n", "    # 数据截断\n", "    if len(input_ids) > MAX_LENGTH:\n", "        input_ids = input_ids[:MAX_LENGTH]\n", "        attention_mask = attention_mask[:MAX_LENGTH]\n", "        labels = labels[:MAX_LENGTH]\n", "    return {\n", "        \"input_ids\": input_ids,\n", "        \"attention_mask\": attention_mask,\n", "        \"labels\": labels\n", "    }" ] }, { "cell_type": "code", "execution_count": 12, "id": "eb8f1a88", "metadata": { "scrolled": true }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "7ab01b873b7841eeac5440ce5443f642", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map:   0%|          | 0/92644 [00:00<?, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "tokenized_train=ds[0].map(process_func, remove_columns=ds[0].column_names)" ] }, { "cell_type": "markdown", "id": "ad20e4d9", "metadata": {}, "source": [ "## 5 创建模型" ] }, { "data": { "text/plain": [ "LoraConfig(peft_type=<PeftType.LORA: 'LORA'>, auto_mapping=None, base_model_name_or_path=None, revision=None, task_type=<TaskType.CAUSAL_LM: 'CAUSAL_LM'>, inference_mode=False, r=8, target_modules=None, lora_alpha=8, lora_dropout=0.0, fan_in_fan_out=False, bias='none', modules_to_save=None, init_lora_weights=True, layers_to_transform=None, layers_pattern=None, rank_pattern={}, alpha_pattern={})" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 6.1 创建配置文件\n", "from peft import LoraConfig,get_peft_model,TaskType\n", "config = LoraConfig(task_type=TaskType.CAUSAL_LM)\n", "config" ] }, { "cell_type": "code", "execution_count": 17, "id": "5cc4ec1c", "metadata": {}, "outputs": [], "source": [ "# 6.2 创建模型\n", "model_lora = get_peft_model(model,config)" ] }, { "cell_type": "code", "execution_count": 18, "id": "4495dfdb", "metadata": {}, "outputs": [], "source": [ "model_lora=model_lora.half()" ] }, { "cell_type": 
"code", 409 | "execution_count": 13, 410 | "id": "e1938044", 411 | "metadata": {}, 412 | "outputs": [ 413 | { 414 | "name": "stdout", 415 | "output_type": "stream", 416 | "text": [ 417 | " 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译: 翻译\n" 418 | ] 419 | } 420 | ], 421 | "source": [ 422 | "x=\"机器翻译:\\n{}\".format(\"what is this。\").strip()\n", 423 | "ipt = tokenizer(x,return_tensors='pt').to('cuda')\n", 424 | "print(tokenizer.decode(model.generate(**ipt,max_length=256, do_sample=False)[0],skip_special_tokens=True)[len(x):])" 425 | ] 426 | }, 427 | { 428 | "cell_type": "code", 429 | "execution_count": 19, 430 | "id": "33be677b", 431 | "metadata": {}, 432 | "outputs": [ 433 | { 434 | "name": "stdout", 435 | "output_type": "stream", 436 | "text": [ 437 | "trainable params: 2,457,600 || all params: 2,480,893,440 || trainable%: 0.09906108663820724\n" 438 | ] 439 | } 440 | ], 441 | "source": [ 442 | "model_lora.print_trainable_parameters()" 443 | ] 444 | }, 445 | { 446 | "cell_type": "markdown", 447 | "id": "f7d689a5", 448 | "metadata": {}, 449 | "source": [ 450 | "## 7 配置训练参数" 451 | ] 452 | }, 453 | { 454 | "cell_type": "code", 455 | "execution_count": 23, 456 | "id": "77a2e300", 457 | "metadata": {}, 458 | "outputs": [ 459 | { 460 | "name": "stderr", 461 | "output_type": "stream", 462 | "text": [ 463 | "Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).\n" 464 | ] 465 | } 466 | ], 467 | "source": [ 468 | "import os\n", 469 | "os.environ[\"WANDB_DISABLED\"] = \"true\" # 防止日志输出到wandb.ai\n", 470 | "args= TrainingArguments(\n", 471 | " output_dir='./modelcheak/m3',\n", 472 | " logging_dir=r'./modelcheak/m3',\n", 473 | " per_device_train_batch_size=8, # batch_size\n", 474 | " gradient_accumulation_steps=4,\n", 475 | " logging_steps=20,\n", 476 | " optim=\"adafactor\", # 使用特定的优化器优化显存\n", 477 | " save_strategy='epoch', # 每一轮保存一个模型\n", 478 | " num_train_epochs=1,\n", 479 | " adam_epsilon=1e-4\n", 480 | ")" 481 | ] 482 | }, 483 | { 484 | "cell_type": "markdown", 485 | "id": "122adaa1", 486 | "metadata": {}, 487 | "source": [ 488 | "## 8 创建训练器" 489 | ] 490 | }, 491 | { 492 | "cell_type": "code", 493 | "execution_count": 24, 494 | "id": "43b7e698", 495 | "metadata": { 496 | "scrolled": true 497 | }, 498 | "outputs": [], 499 | "source": [ 500 | "trainr=Trainer(\n", 501 | " args=args,\n", 502 | " model=model_lora,\n", 503 | " train_dataset=tokenized_train,\n", 504 | " tokenizer=tokenizer,\n", 505 | " data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer)\n", 506 | ")" 507 | ] 508 | }, 509 | { 510 | "cell_type": "code", 511 | "execution_count": 25, 512 | "id": "0943fb9e", 513 | "metadata": { 514 | "scrolled": true 515 | }, 516 | "outputs": [ 517 | { 518 | "data": { 519 | "text/html": [ 520 | "\n", 521 | "

\n", 522 | " \n", 523 | " \n", 524 | " [2895/2895 26:55, Epoch 0/1]\n", 525 | "
\n", 526 | " \n", 527 | " \n", 528 | " \n", 529 | " \n", 530 | " \n", 531 | " \n", 532 | " \n", 533 | " \n", 534 | " \n", 535 | " \n", 536 | " \n", 537 | " \n", 538 | " \n", 539 | " \n", 540 | " \n", 541 | " \n", 542 | " \n", 543 | " \n", 544 | " \n", 545 | " \n", 546 | " \n", 547 | " \n", 548 | " \n", 549 | " \n", 550 | " \n", 551 | " \n", 552 | " \n", 553 | " \n", 554 | " \n", 555 | " \n", 556 | " \n", 557 | " \n", 558 | " \n", 559 | " \n", 560 | " \n", 561 | " \n", 562 | " \n", 563 | " \n", 564 | " \n", 565 | " \n", 566 | " \n", 567 | " \n", 568 | " \n", 569 | " \n", 570 | " \n", 571 | " \n", 572 | " \n", 573 | " \n", 574 | " \n", 575 | " \n", 576 | " \n", 577 | " \n", 578 | " \n", 579 | " \n", 580 | " \n", 581 | " \n", 582 | " \n", 583 | " \n", 584 | " \n", 585 | " \n", 586 | " \n", 587 | " \n", 588 | " \n", 589 | " \n", 590 | " \n", 591 | " \n", 592 | " \n", 593 | " \n", 594 | " \n", 595 | " \n", 596 | " \n", 597 | " \n", 598 | " \n", 599 | " \n", 600 | " \n", 601 | " \n", 602 | " \n", 603 | " \n", 604 | " \n", 605 | " \n", 606 | " \n", 607 | " \n", 608 | " \n", 609 | " \n", 610 | " \n", 611 | " \n", 612 | " \n", 613 | " \n", 614 | " \n", 615 | " \n", 616 | " \n", 617 | " \n", 618 | " \n", 619 | " \n", 620 | " \n", 621 | " \n", 622 | " \n", 623 | " \n", 624 | " \n", 625 | " \n", 626 | " \n", 627 | " \n", 628 | " \n", 629 | " \n", 630 | " \n", 631 | " \n", 632 | " \n", 633 | " \n", 634 | " \n", 635 | " \n", 636 | " \n", 637 | " \n", 638 | " \n", 639 | " \n", 640 | " \n", 641 | " \n", 642 | " \n", 643 | " \n", 644 | " \n", 645 | " \n", 646 | " \n", 647 | " \n", 648 | " \n", 649 | " \n", 650 | " \n", 651 | " \n", 652 | " \n", 653 | " \n", 654 | " \n", 655 | " \n", 656 | " \n", 657 | " \n", 658 | " \n", 659 | " \n", 660 | " \n", 661 | " \n", 662 | " \n", 663 | " \n", 664 | " \n", 665 | " \n", 666 | " \n", 667 | " \n", 668 | " \n", 669 | " \n", 670 | " \n", 671 | " \n", 672 | " \n", 673 | " \n", 674 | " \n", 675 | " \n", 676 | " \n", 677 | " \n", 678 | " \n", 679 | " \n", 680 | " \n", 681 | " \n", 682 | " \n", 683 | " \n", 684 | " \n", 685 | " \n", 686 | " \n", 687 | " \n", 688 | " \n", 689 | " \n", 690 | " \n", 691 | " \n", 692 | " \n", 693 | " \n", 694 | " \n", 695 | " \n", 696 | " \n", 697 | " \n", 698 | " \n", 699 | " \n", 700 | " \n", 701 | " \n", 702 | " \n", 703 | " \n", 704 | " \n", 705 | " \n", 706 | " \n", 707 | " \n", 708 | " \n", 709 | " \n", 710 | " \n", 711 | " \n", 712 | " \n", 713 | " \n", 714 | " \n", 715 | " \n", 716 | " \n", 717 | " \n", 718 | " \n", 719 | " \n", 720 | " \n", 721 | " \n", 722 | " \n", 723 | " \n", 724 | " \n", 725 | " \n", 726 | " \n", 727 | " \n", 728 | " \n", 729 | " \n", 730 | " \n", 731 | " \n", 732 | " \n", 733 | " \n", 734 | " \n", 735 | " \n", 736 | " \n", 737 | " \n", 738 | " \n", 739 | " \n", 740 | " \n", 741 | " \n", 742 | " \n", 743 | " \n", 744 | " \n", 745 | " \n", 746 | " \n", 747 | " \n", 748 | " \n", 749 | " \n", 750 | " \n", 751 | " \n", 752 | " \n", 753 | " \n", 754 | " \n", 755 | " \n", 756 | " \n", 757 | " \n", 758 | " \n", 759 | " \n", 760 | " \n", 761 | " \n", 762 | " \n", 763 | " \n", 764 | " \n", 765 | " \n", 766 | " \n", 767 | " \n", 768 | " \n", 769 | " \n", 770 | " \n", 771 | " \n", 772 | " \n", 773 | " \n", 774 | " \n", 775 | " \n", 776 | " \n", 777 | " \n", 778 | " \n", 779 | " \n", 780 | " \n", 781 | " \n", 782 | " \n", 783 | " \n", 784 | " \n", 785 | " \n", 786 | " \n", 787 | " \n", 788 | " \n", 789 | " \n", 790 | " \n", 791 | " \n", 792 | " \n", 793 | " \n", 794 | " \n", 795 | " \n", 796 | " \n", 797 | " \n", 798 | " \n", 
799 | " \n", 800 | " \n", 801 | " \n", 802 | " \n", 803 | " \n", 804 | " \n", 805 | " \n", 806 | " \n", 807 | " \n", 808 | " \n", 809 | " \n", 810 | " \n", 811 | " \n", 812 | " \n", 813 | " \n", 814 | " \n", 815 | " \n", 816 | " \n", 817 | " \n", 818 | " \n", 819 | " \n", 820 | " \n", 821 | " \n", 822 | " \n", 823 | " \n", 824 | " \n", 825 | " \n", 826 | " \n", 827 | " \n", 828 | " \n", 829 | " \n", 830 | " \n", 831 | " \n", 832 | " \n", 833 | " \n", 834 | " \n", 835 | " \n", 836 | " \n", 837 | " \n", 838 | " \n", 839 | " \n", 840 | " \n", 841 | " \n", 842 | " \n", 843 | " \n", 844 | " \n", 845 | " \n", 846 | " \n", 847 | " \n", 848 | " \n", 849 | " \n", 850 | " \n", 851 | " \n", 852 | " \n", 853 | " \n", 854 | " \n", 855 | " \n", 856 | " \n", 857 | " \n", 858 | " \n", 859 | " \n", 860 | " \n", 861 | " \n", 862 | " \n", 863 | " \n", 864 | " \n", 865 | " \n", 866 | " \n", 867 | " \n", 868 | " \n", 869 | " \n", 870 | " \n", 871 | " \n", 872 | " \n", 873 | " \n", 874 | " \n", 875 | " \n", 876 | " \n", 877 | " \n", 878 | " \n", 879 | " \n", 880 | " \n", 881 | " \n", 882 | " \n", 883 | " \n", 884 | " \n", 885 | " \n", 886 | " \n", 887 | " \n", 888 | " \n", 889 | " \n", 890 | " \n", 891 | " \n", 892 | " \n", 893 | " \n", 894 | " \n", 895 | " \n", 896 | " \n", 897 | " \n", 898 | " \n", 899 | " \n", 900 | " \n", 901 | " \n", 902 | " \n", 903 | " \n", 904 | " \n", 905 | " \n", 906 | " \n", 907 | " \n", 908 | " \n", 909 | " \n", 910 | " \n", 911 | " \n", 912 | " \n", 913 | " \n", 914 | " \n", 915 | " \n", 916 | " \n", 917 | " \n", 918 | " \n", 919 | " \n", 920 | " \n", 921 | " \n", 922 | " \n", 923 | " \n", 924 | " \n", 925 | " \n", 926 | " \n", 927 | " \n", 928 | " \n", 929 | " \n", 930 | " \n", 931 | " \n", 932 | " \n", 933 | " \n", 934 | " \n", 935 | " \n", 936 | " \n", 937 | " \n", 938 | " \n", 939 | " \n", 940 | " \n", 941 | " \n", 942 | " \n", 943 | " \n", 944 | " \n", 945 | " \n", 946 | " \n", 947 | " \n", 948 | " \n", 949 | " \n", 950 | " \n", 951 | " \n", 952 | " \n", 953 | " \n", 954 | " \n", 955 | " \n", 956 | " \n", 957 | " \n", 958 | " \n", 959 | " \n", 960 | " \n", 961 | " \n", 962 | " \n", 963 | " \n", 964 | " \n", 965 | " \n", 966 | " \n", 967 | " \n", 968 | " \n", 969 | " \n", 970 | " \n", 971 | " \n", 972 | " \n", 973 | " \n", 974 | " \n", 975 | " \n", 976 | " \n", 977 | " \n", 978 | " \n", 979 | " \n", 980 | " \n", 981 | " \n", 982 | " \n", 983 | " \n", 984 | " \n", 985 | " \n", 986 | " \n", 987 | " \n", 988 | " \n", 989 | " \n", 990 | " \n", 991 | " \n", 992 | " \n", 993 | " \n", 994 | " \n", 995 | " \n", 996 | " \n", 997 | " \n", 998 | " \n", 999 | " \n", 1000 | " \n", 1001 | " \n", 1002 | " \n", 1003 | " \n", 1004 | " \n", 1005 | " \n", 1006 | " \n", 1007 | " \n", 1008 | " \n", 1009 | " \n", 1010 | " \n", 1011 | " \n", 1012 | " \n", 1013 | " \n", 1014 | " \n", 1015 | " \n", 1016 | " \n", 1017 | " \n", 1018 | " \n", 1019 | " \n", 1020 | " \n", 1021 | " \n", 1022 | " \n", 1023 | " \n", 1024 | " \n", 1025 | " \n", 1026 | " \n", 1027 | " \n", 1028 | " \n", 1029 | " \n", 1030 | " \n", 1031 | " \n", 1032 | " \n", 1033 | " \n", 1034 | " \n", 1035 | " \n", 1036 | " \n", 1037 | " \n", 1038 | " \n", 1039 | " \n", 1040 | " \n", 1041 | " \n", 1042 | " \n", 1043 | " \n", 1044 | " \n", 1045 | " \n", 1046 | " \n", 1047 | " \n", 1048 | " \n", 1049 | " \n", 1050 | " \n", 1051 | " \n", 1052 | " \n", 1053 | " \n", 1054 | " \n", 1055 | " \n", 1056 | " \n", 1057 | " \n", 1058 | " \n", 1059 | " \n", 1060 | " \n", 1061 | " \n", 1062 | " \n", 1063 | " \n", 1064 | " \n", 1065 | " \n", 1066 | " \n", 
1067 | " \n", 1068 | " \n", 1069 | " \n", 1070 | " \n", 1071 | " \n", 1072 | " \n", 1073 | " \n", 1074 | " \n", 1075 | " \n", 1076 | " \n", 1077 | " \n", 1078 | " \n", 1079 | " \n", 1080 | " \n", 1081 | " \n", 1082 | " \n", 1083 | " \n", 1084 | " \n", 1085 | " \n", 1086 | " \n", 1087 | " \n", 1088 | " \n", 1089 | " \n", 1090 | " \n", 1091 | " \n", 1092 | " \n", 1093 | " \n", 1094 | " \n", 1095 | " \n", 1096 | " \n", 1097 | " \n", 1098 | " \n", 1099 | " \n", 1100 | " \n", 1101 | " \n", 1102 | " \n", 1103 | " \n", 1104 | " \n", 1105 | " \n", 1106 | " \n", 1107 | " \n", 1108 | " \n", 1109 | " \n", 1110 | " \n", 1111 | "
StepTraining Loss
203.051300
402.798100
602.639400
802.523500
1002.463700
1202.394600
1402.382400
1602.352300
1802.386000
2002.254800
2202.277300
2402.282400
2602.333800
2802.297500
3002.238400
3202.272300
3402.236500
3602.234700
3802.200000
4002.168200
4202.210200
4402.161000
4602.249600
4802.242700
5002.168200
5202.131500
5402.117000
5602.153500
5802.199800
6002.192000
6202.178200
6402.116700
6602.113900
6802.149100
7002.175600
7202.148900
7402.182600
7602.117200
7802.152000
8002.163600
8202.117500
8402.143200
8602.079800
8802.101300
9002.093500
9202.111800
9402.096700
9602.136000
9802.136700
10002.162500
10202.095500
10402.116100
10602.063400
10802.125400
11002.095900
11202.088300
11402.076200
11602.100000
11802.105600
12002.086100
12202.060700
12402.068800
12602.031200
12802.091700
13002.137900
13202.107100
13402.075100
13602.118500
13802.069200
14002.089800
14202.062000
14402.085700
14602.076500
14802.075900
15002.106000
15202.060200
15402.085200
15602.098200
15802.076200
16002.116000
16202.058100
16402.114300
16602.079700
16802.032300
17002.082100
17202.073100
17402.069700
17602.087300
17802.036300
18002.067500
18202.122600
18402.090300
18602.044800
18802.181300
19002.036300
19202.008100
19402.083100
19602.022100
19802.089300
20002.084200
20202.058900
20402.021600
20602.132200
20802.114400
21002.082700
21202.054100
21402.053700
21602.083600
21802.170300
22002.032100
22202.043000
22402.077600
22602.061400
22802.016500
23002.074100
23202.101000
23402.072500
23601.994900
23801.989000
24002.077200
24202.039100
24402.113900
24602.129500
24802.087300
25002.096200
25202.025700
25402.103600
25602.031100
25801.992500
26002.047700
26202.028600
26402.057700
26602.072400
26802.011900
27002.067200
27202.052000
27402.070400
27602.095900
27802.038700
28002.063300
28202.083900
28402.022500
28601.947500
28801.995800
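The 2895 optimizer steps shown in the progress bar follow from the batching settings above; a back-of-envelope check (this arithmetic is an added illustration, assuming the last partial accumulation window is not counted as an extra step):

```python
import math

num_examples = 92644      # rows in the training split
per_device_batch = 8      # per_device_train_batch_size
grad_accum = 4            # gradient_accumulation_steps

micro_batches = math.ceil(num_examples / per_device_batch)  # 11581 forward passes
optimizer_steps = micro_batches // grad_accum               # 2895 parameter updates
print(micro_batches, optimizer_steps)
```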

" 1112 | ], 1113 | "text/plain": [ 1114 | "" 1115 | ] 1116 | }, 1117 | "metadata": {}, 1118 | "output_type": "display_data" 1119 | }, 1120 | { 1121 | "data": { 1122 | "text/plain": [ 1123 | "TrainOutput(global_step=2895, training_loss=2.131528674506153, metrics={'train_runtime': 1615.9318, 'train_samples_per_second': 57.332, 'train_steps_per_second': 1.792, 'total_flos': 6.0358179078144e+16, 'train_loss': 2.131528674506153, 'epoch': 1.0})" 1124 | ] 1125 | }, 1126 | "execution_count": 25, 1127 | "metadata": {}, 1128 | "output_type": "execute_result" 1129 | } 1130 | ], 1131 | "source": [ 1132 | "trainr.train()" 1133 | ] 1134 | }, 1135 | { 1136 | "cell_type": "markdown", 1137 | "id": "7ae220ef", 1138 | "metadata": {}, 1139 | "source": [ 1140 | "## 9 权重合并与" 1141 | ] 1142 | }, 1143 | { 1144 | "cell_type": "code", 1145 | "execution_count": 7, 1146 | "id": "5aca0d3b", 1147 | "metadata": {}, 1148 | "outputs": [], 1149 | "source": [ 1150 | "from peft import PeftModel\n", 1151 | "# model_id 是checkpoint那个路径\n", 1152 | "prft_model=PeftModel.from_pretrained(model=model,model_id=r\"C:\\Users\\30535\\Desktop\\CodeProgram\\Python\\deepstudy\\code2\\使用Transformer进行中英文翻译\\modelcheak\\m3\\checkpoint-2895\")\n", 1153 | "# 权重合并\n", 1154 | "merge_model=prft_model.merge_and_unload()" 1155 | ] 1156 | }, 1157 | { 1158 | "cell_type": "code", 1159 | "execution_count": 7, 1160 | "id": "59fc687a", 1161 | "metadata": {}, 1162 | "outputs": [], 1163 | "source": [ 1164 | "# 模型保存\n", 1165 | "merge_model.save_pretrained('./modelcheak/trans11')" 1166 | ] 1167 | }, 1168 | { 1169 | "cell_type": "code", 1170 | "execution_count": 21, 1171 | "id": "876ab7dc", 1172 | "metadata": {}, 1173 | "outputs": [ 1174 | { 1175 | "name": "stdout", 1176 | "output_type": "stream", 1177 | "text": [ 1178 | "这是什么?\n" 1179 | ] 1180 | } 1181 | ], 1182 | "source": [ 1183 | "x=\"机器翻译:\\n{}\".format(\"what is this。\").strip()\n", 1184 | "ipt = tokenizer(x,return_tensors='pt').to('cuda')\n", 1185 | "print(tokenizer.decode(merge_model.generate(**ipt,max_length=256, do_sample=False)[0],skip_special_tokens=True)[len(x):])" 1186 | ] 1187 | }, 1188 | { 1189 | "cell_type": "code", 1190 | "execution_count": 19, 1191 | "id": "ae2b6ec1", 1192 | "metadata": {}, 1193 | "outputs": [ 1194 | { 1195 | "name": "stdout", 1196 | "output_type": "stream", 1197 | "text": [ 1198 | "What is this?\n" 1199 | ] 1200 | } 1201 | ], 1202 | "source": [ 1203 | "x=\"机器翻译:\\n{}\".format(\"这又是什么呢?\").strip()\n", 1204 | "ipt = tokenizer(x,return_tensors='pt').to('cuda')\n", 1205 | "print(tokenizer.decode(merge_model.generate(**ipt,max_length=256, do_sample=False)[0],skip_special_tokens=True)[len(x):])" 1206 | ] 1207 | }, 1208 | { 1209 | "cell_type": "code", 1210 | "execution_count": 8, 1211 | "id": "ce272f6e", 1212 | "metadata": {}, 1213 | "outputs": [ 1214 | { 1215 | "name": "stdout", 1216 | "output_type": "stream", 1217 | "text": [ 1218 | "0.0\n", 1219 | "0.04\n", 1220 | "0.08\n", 1221 | "时间 25.055421829223633\n", 1222 | "17.161924767287793\n" 1223 | ] 1224 | } 1225 | ], 1226 | "source": [ 1227 | "import re\n", 1228 | "import sacrebleu\n", 1229 | "def is_english_sentence(sentence):\n", 1230 | " # 使用正则表达式检查句子中是否包含英文字母\n", 1231 | " english_pattern = re.compile(r'[a-zA-Z]')\n", 1232 | " match = english_pattern.search(sentence)\n", 1233 | " \n", 1234 | " if match:\n", 1235 | " return True\n", 1236 | " else:\n", 1237 | " return False\n", 1238 | "from nltk.translate.bleu_score import sentence_bleu\n", 1239 | "from nltk.translate.bleu_score import SmoothingFunction\n", 1240 | "\n", 1241 | 
"smooth = SmoothingFunction().method1\n", 1242 | "bleu_scores=[]\n", 1243 | "m1,m2=[],[]\n", 1244 | "m3,m4=[],[]\n", 1245 | "import time\n", 1246 | "t=time.time()\n", 1247 | "for i in range(100):\n", 1248 | " if i%40==0:\n", 1249 | " print(i/len(ds[1]['src']))\n", 1250 | " x=\"机器翻译:\\n{}\".format(ds[1]['src'][i]).strip()\n", 1251 | " ipt = tokenizer(x,return_tensors='pt').to('cuda')\n", 1252 | " y=tokenizer.decode(merge_model.generate(**ipt,max_length=150, do_sample=False)[0],skip_special_tokens=True)[len(x):]\n", 1253 | " if is_english_sentence(ds[1]['tgt'][i]):\n", 1254 | " m1.append(ds[1]['tgt'][i])\n", 1255 | " m2.append([y])\n", 1256 | " else:\n", 1257 | " m3.append(list(ds[1]['tgt'][i][:-1]))\n", 1258 | " m4.append([list(y)[:-1]])\n", 1259 | "print('时间',time.time()-t)\n", 1260 | "smooth = SmoothingFunction().method1\n", 1261 | "b1=[sacrebleu.sentence_bleu(candidate, refs).score for candidate, refs in zip(m1, m2)]\n", 1262 | "for i in range(len(m4)):\n", 1263 | " b2 = sentence_bleu(m4[i], m3[i], weights=(0.25, 0.25, 0.25, 0.25), smoothing_function=smooth)*100\n", 1264 | " b1.append(b2)\n", 1265 | "print(sum(b1)/100)" 1266 | ] 1267 | }, 1268 | { 1269 | "cell_type": "code", 1270 | "execution_count": null, 1271 | "id": "8162c152", 1272 | "metadata": {}, 1273 | "outputs": [], 1274 | "source": [] 1275 | }, 1276 | { 1277 | "cell_type": "markdown", 1278 | "id": "9a092afd", 1279 | "metadata": {}, 1280 | "source": [ 1281 | "## 9 模型推理" 1282 | ] 1283 | }, 1284 | { 1285 | "cell_type": "code", 1286 | "execution_count": 32, 1287 | "id": "df9ad78c", 1288 | "metadata": {}, 1289 | "outputs": [], 1290 | "source": [ 1291 | "from transformers import pipeline" 1292 | ] 1293 | }, 1294 | { 1295 | "cell_type": "code", 1296 | "execution_count": 33, 1297 | "id": "7145468b", 1298 | "metadata": {}, 1299 | "outputs": [ 1300 | { 1301 | "name": "stderr", 1302 | "output_type": "stream", 1303 | "text": [ 1304 | "The model 'BloomForCausalLM' is not supported for text2text-generation. 
Supported models are ['BartForConditionalGeneration', 'BigBirdPegasusForConditionalGeneration', 'BlenderbotForConditionalGeneration', 'BlenderbotSmallForConditionalGeneration', 'EncoderDecoderModel', 'FSMTForConditionalGeneration', 'GPTSanJapaneseForConditionalGeneration', 'LEDForConditionalGeneration', 'LongT5ForConditionalGeneration', 'M2M100ForConditionalGeneration', 'MarianMTModel', 'MBartForConditionalGeneration', 'MT5ForConditionalGeneration', 'MvpForConditionalGeneration', 'NllbMoeForConditionalGeneration', 'PegasusForConditionalGeneration', 'PegasusXForConditionalGeneration', 'PLBartForConditionalGeneration', 'ProphetNetForConditionalGeneration', 'SwitchTransformersForConditionalGeneration', 'T5ForConditionalGeneration', 'UMT5ForConditionalGeneration', 'XLMProphetNetForConditionalGeneration'].\n" 1305 | ] 1306 | } 1307 | ], 1308 | "source": [ 1309 | "pipe=pipeline('text2text-generation',model=merge_model,tokenizer=tokenizer,device=0)" 1310 | ] 1311 | }, 1312 | { 1313 | "cell_type": "code", 1314 | "execution_count": 35, 1315 | "id": "89d02ec1", 1316 | "metadata": { 1317 | "scrolled": true 1318 | }, 1319 | "outputs": [ 1320 | { 1321 | "data": { 1322 | "text/plain": [ 1323 | "[{'generated_text': '机器翻译:\\n我有一个苹果I have a Apple'}]" 1324 | ] 1325 | }, 1326 | "execution_count": 35, 1327 | "metadata": {}, 1328 | "output_type": "execute_result" 1329 | } 1330 | ], 1331 | "source": [ 1332 | "pipe('机器翻译:\\n'+'我有一个苹果',max_length=30,do_sample=False)" 1333 | ] 1334 | } 1335 | ], 1336 | "metadata": { 1337 | "kernelspec": { 1338 | "display_name": "Python 3 (ipykernel)", 1339 | "language": "python", 1340 | "name": "python3" 1341 | }, 1342 | "language_info": { 1343 | "codemirror_mode": { 1344 | "name": "ipython", 1345 | "version": 3 1346 | }, 1347 | "file_extension": ".py", 1348 | "mimetype": "text/x-python", 1349 | "name": "python", 1350 | "nbconvert_exporter": "python", 1351 | "pygments_lexer": "ipython3", 1352 | "version": "3.11.4" 1353 | } 1354 | }, 1355 | "nbformat": 4, 1356 | "nbformat_minor": 5 1357 | } 1358 | -------------------------------------------------------------------------------- /国科大-深度学习作业/机器翻译/bloom5-6.4b-4bit Qlora微调-机器翻译.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "5575fad4", 6 | "metadata": {}, 7 | "source": [ 8 | "# 源数据处理" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "286db6b0", 14 | "metadata": {}, 15 | "source": [ 16 | "## 1 导入相关包" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 1, 22 | "id": "71e4ad43", 23 | "metadata": {}, 24 | "outputs": [ 25 | { 26 | "name": "stdout", 27 | "output_type": "stream", 28 | "text": [ 29 | "bin D:\\anac\\Lib\\site-packages\\bitsandbytes\\libbitsandbytes_cuda118.dll\n" 30 | ] 31 | } 32 | ], 33 | "source": [ 34 | "from datasets import load_dataset\n", 35 | "from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainer" 36 | ] 37 | }, 38 | { 39 | "cell_type": "markdown", 40 | "id": "77a540d7", 41 | "metadata": {}, 42 | "source": [ 43 | "## 2 加载数据集" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 2, 49 | "id": "64d6fc0b", 50 | "metadata": { 51 | "scrolled": false 52 | }, 53 | "outputs": [ 54 | { 55 | "data": { 56 | "text/plain": [ 57 | "[Dataset({\n", 58 | " features: ['src', 'tgt'],\n", 59 | " num_rows: 92644\n", 60 | " }),\n", 61 | " Dataset({\n", 62 | " features: ['src', 'tgt'],\n", 63 | " num_rows: 1000\n", 64 | " })]" 65 
| ] 66 | }, 67 | "execution_count": 2, 68 | "metadata": {}, 69 | "output_type": "execute_result" 70 | } 71 | ], 72 | "source": [ 73 | "data_train=r'C:\\Users\\30535\\Desktop\\train.csv'\n", 74 | "data_test=r'C:\\Users\\30535\\Desktop\\test.csv'\n", 75 | "ds=load_dataset('csv',data_files={'train':data_train, 'test': data_test},\n", 76 | " split=['train', 'test'])\n", 77 | "ds" 78 | ] 79 | }, 80 | { 81 | "cell_type": "markdown", 82 | "id": "d63ae622", 83 | "metadata": {}, 84 | "source": [ 85 | "## 4 数据处理" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 3, 91 | "id": "1d71b691", 92 | "metadata": { 93 | "scrolled": true 94 | }, 95 | "outputs": [], 96 | "source": [ 97 | "model_path=r'H:\\models\\bloom-6b4-zh'\n", 98 | "tokenizer = AutoTokenizer.from_pretrained(model_path)" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": 4, 104 | "id": "2f48676e", 105 | "metadata": {}, 106 | "outputs": [], 107 | "source": [ 108 | "def process_func(examples):\n", 109 | " MAX_LENGTH = 150\n", 110 | " contents='机器翻译:\\n' + examples['src']\n", 111 | " # 对输入与label进行编码\n", 112 | " inputs=tokenizer(contents)\n", 113 | " labels = tokenizer(text_target=examples['tgt'] + tokenizer.eos_token)\n", 114 | " input_ids=inputs[\"input_ids\"]+labels[\"input_ids\"]\n", 115 | " attention_mask=inputs[\"attention_mask\"] + labels[\"attention_mask\"]\n", 116 | " labels = [-100] * len(inputs[\"input_ids\"]) + labels[\"input_ids\"]\n", 117 | " # 数据截断\n", 118 | " if len(input_ids) > MAX_LENGTH:\n", 119 | " input_ids = input_ids[:MAX_LENGTH]\n", 120 | " attention_mask = attention_mask[:MAX_LENGTH]\n", 121 | " labels = labels[:MAX_LENGTH]\n", 122 | " return {\n", 123 | " \"input_ids\": input_ids,\n", 124 | " \"attention_mask\": attention_mask,\n", 125 | " \"labels\": labels\n", 126 | " }" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 5, 132 | "id": "eb8f1a88", 133 | "metadata": { 134 | "scrolled": true 135 | }, 136 | "outputs": [], 137 | "source": [ 138 | "tokenized_train=ds[0].map(process_func, remove_columns=ds[0].column_names)" 139 | ] 140 | }, 141 | { 142 | "cell_type": "markdown", 143 | "id": "ad20e4d9", 144 | "metadata": {}, 145 | "source": [ 146 | "## 5 创建模型" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": 23, 152 | "id": "8f5fa333", 153 | "metadata": {}, 154 | "outputs": [ 155 | { 156 | "data": { 157 | "application/vnd.jupyter.widget-view+json": { 158 | "model_id": "03b5230f75f74f3586684146b472e8f5", 159 | "version_major": 2, 160 | "version_minor": 0 161 | }, 162 | "text/plain": [ 163 | "Loading checkpoint shards: 0%| | 0/3 [00:00, auto_mapping=None, base_model_name_or_path=None, revision=None, task_type=, inference_mode=False, r=8, target_modules=None, lora_alpha=8, lora_dropout=0.0, fan_in_fan_out=False, bias='none', modules_to_save=None, init_lora_weights=True, layers_to_transform=None, layers_pattern=None, rank_pattern={}, alpha_pattern={})" 200 | ] 201 | }, 202 | "execution_count": 8, 203 | "metadata": {}, 204 | "output_type": "execute_result" 205 | } 206 | ], 207 | "source": [ 208 | "# 6.1 创建配置文件\n", 209 | "from peft import LoraConfig,get_peft_model,TaskType\n", 210 | "comfig = LoraConfig(task_type=TaskType.CAUSAL_LM)\n", 211 | "comfig" 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": 9, 217 | "id": "5cc4ec1c", 218 | "metadata": {}, 219 | "outputs": [], 220 | "source": [ 221 | "# 6.2 创建模型\n", 222 | "model_lora = get_peft_model(model,comfig)" 223 | ] 224 | }, 225 | { 226 | "cell_type": 
"code", 227 | "execution_count": 10, 228 | "id": "4495dfdb", 229 | "metadata": {}, 230 | "outputs": [], 231 | "source": [ 232 | "model_lora=model_lora.half()" 233 | ] 234 | }, 235 | { 236 | "cell_type": "code", 237 | "execution_count": 11, 238 | "id": "33be677b", 239 | "metadata": {}, 240 | "outputs": [ 241 | { 242 | "name": "stdout", 243 | "output_type": "stream", 244 | "text": [ 245 | "trainable params: 3,932,160 || all params: 6,234,353,664 || trainable%: 0.06307245645536737\n" 246 | ] 247 | } 248 | ], 249 | "source": [ 250 | "model_lora.print_trainable_parameters()" 251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "execution_count": 12, 256 | "id": "43d387bd", 257 | "metadata": {}, 258 | "outputs": [ 259 | { 260 | "data": { 261 | "text/plain": [ 262 | "device(type='cuda', index=0)" 263 | ] 264 | }, 265 | "execution_count": 12, 266 | "metadata": {}, 267 | "output_type": "execute_result" 268 | } 269 | ], 270 | "source": [ 271 | "model.device" 272 | ] 273 | }, 274 | { 275 | "cell_type": "markdown", 276 | "id": "f7d689a5", 277 | "metadata": {}, 278 | "source": [ 279 | "## 7 配置训练参数" 280 | ] 281 | }, 282 | { 283 | "cell_type": "code", 284 | "execution_count": 13, 285 | "id": "77a2e300", 286 | "metadata": {}, 287 | "outputs": [ 288 | { 289 | "name": "stderr", 290 | "output_type": "stream", 291 | "text": [ 292 | "Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).\n" 293 | ] 294 | } 295 | ], 296 | "source": [ 297 | "import os\n", 298 | "os.environ[\"WANDB_DISABLED\"] = \"true\" # 防止日志输出到wandb.ai\n", 299 | "args= TrainingArguments(\n", 300 | " output_dir='./modelcheak/m5',\n", 301 | " logging_dir=r'./modelcheak/m5',\n", 302 | " per_device_train_batch_size=16, # batch_size\n", 303 | " gradient_accumulation_steps=2,\n", 304 | " logging_steps=20,\n", 305 | " optim=\"paged_adamw_32bit\", # 分页优化器,QLora要使用\n", 306 | " num_train_epochs=1,\n", 307 | " gradient_checkpointing=True\n", 308 | ")" 309 | ] 310 | }, 311 | { 312 | "cell_type": "markdown", 313 | "id": "122adaa1", 314 | "metadata": {}, 315 | "source": [ 316 | "## 8 创建训练器" 317 | ] 318 | }, 319 | { 320 | "cell_type": "code", 321 | "execution_count": 14, 322 | "id": "43b7e698", 323 | "metadata": { 324 | "scrolled": true 325 | }, 326 | "outputs": [], 327 | "source": [ 328 | "trainr=Trainer(\n", 329 | " args=args,\n", 330 | " model=model_lora,\n", 331 | " train_dataset=tokenized_train,\n", 332 | " tokenizer=tokenizer,\n", 333 | " data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer)\n", 334 | ")" 335 | ] 336 | }, 337 | { 338 | "cell_type": "code", 339 | "execution_count": 15, 340 | "id": "0943fb9e", 341 | "metadata": { 342 | "scrolled": true 343 | }, 344 | "outputs": [ 345 | { 346 | "name": "stderr", 347 | "output_type": "stream", 348 | "text": [ 349 | "You're using a BloomTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n", 350 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...\n" 351 | ] 352 | }, 353 | { 354 | "data": { 355 | "text/html": [ 356 | "\n", 357 | "

\n", 358 | " \n", 359 | " \n", 360 | " [2895/2895 1:28:01, Epoch 0/1]\n", 361 | "
\n", 362 | " \n", 363 | " \n", 364 | " \n", 365 | " \n", 366 | " \n", 367 | " \n", 368 | " \n", 369 | " \n", 370 | " \n", 371 | " \n", 372 | " \n", 373 | " \n", 374 | " \n", 375 | " \n", 376 | " \n", 377 | " \n", 378 | " \n", 379 | " \n", 380 | " \n", 381 | " \n", 382 | " \n", 383 | " \n", 384 | " \n", 385 | " \n", 386 | " \n", 387 | " \n", 388 | " \n", 389 | " \n", 390 | " \n", 391 | " \n", 392 | " \n", 393 | " \n", 394 | " \n", 395 | " \n", 396 | " \n", 397 | " \n", 398 | " \n", 399 | " \n", 400 | " \n", 401 | " \n", 402 | " \n", 403 | " \n", 404 | " \n", 405 | " \n", 406 | " \n", 407 | " \n", 408 | " \n", 409 | " \n", 410 | " \n", 411 | " \n", 412 | " \n", 413 | " \n", 414 | " \n", 415 | " \n", 416 | " \n", 417 | " \n", 418 | " \n", 419 | " \n", 420 | " \n", 421 | " \n", 422 | " \n", 423 | " \n", 424 | " \n", 425 | " \n", 426 | " \n", 427 | " \n", 428 | " \n", 429 | " \n", 430 | " \n", 431 | " \n", 432 | " \n", 433 | " \n", 434 | " \n", 435 | " \n", 436 | " \n", 437 | " \n", 438 | " \n", 439 | " \n", 440 | " \n", 441 | " \n", 442 | " \n", 443 | " \n", 444 | " \n", 445 | " \n", 446 | " \n", 447 | " \n", 448 | " \n", 449 | " \n", 450 | " \n", 451 | " \n", 452 | " \n", 453 | " \n", 454 | " \n", 455 | " \n", 456 | " \n", 457 | " \n", 458 | " \n", 459 | " \n", 460 | " \n", 461 | " \n", 462 | " \n", 463 | " \n", 464 | " \n", 465 | " \n", 466 | " \n", 467 | " \n", 468 | " \n", 469 | " \n", 470 | " \n", 471 | " \n", 472 | " \n", 473 | " \n", 474 | " \n", 475 | " \n", 476 | " \n", 477 | " \n", 478 | " \n", 479 | " \n", 480 | " \n", 481 | " \n", 482 | " \n", 483 | " \n", 484 | " \n", 485 | " \n", 486 | " \n", 487 | " \n", 488 | " \n", 489 | " \n", 490 | " \n", 491 | " \n", 492 | " \n", 493 | " \n", 494 | " \n", 495 | " \n", 496 | " \n", 497 | " \n", 498 | " \n", 499 | " \n", 500 | " \n", 501 | " \n", 502 | " \n", 503 | " \n", 504 | " \n", 505 | " \n", 506 | " \n", 507 | " \n", 508 | " \n", 509 | " \n", 510 | " \n", 511 | " \n", 512 | " \n", 513 | " \n", 514 | " \n", 515 | " \n", 516 | " \n", 517 | " \n", 518 | " \n", 519 | " \n", 520 | " \n", 521 | " \n", 522 | " \n", 523 | " \n", 524 | " \n", 525 | " \n", 526 | " \n", 527 | " \n", 528 | " \n", 529 | " \n", 530 | " \n", 531 | " \n", 532 | " \n", 533 | " \n", 534 | " \n", 535 | " \n", 536 | " \n", 537 | " \n", 538 | " \n", 539 | " \n", 540 | " \n", 541 | " \n", 542 | " \n", 543 | " \n", 544 | " \n", 545 | " \n", 546 | " \n", 547 | " \n", 548 | " \n", 549 | " \n", 550 | " \n", 551 | " \n", 552 | " \n", 553 | " \n", 554 | " \n", 555 | " \n", 556 | " \n", 557 | " \n", 558 | " \n", 559 | " \n", 560 | " \n", 561 | " \n", 562 | " \n", 563 | " \n", 564 | " \n", 565 | " \n", 566 | " \n", 567 | " \n", 568 | " \n", 569 | " \n", 570 | " \n", 571 | " \n", 572 | " \n", 573 | " \n", 574 | " \n", 575 | " \n", 576 | " \n", 577 | " \n", 578 | " \n", 579 | " \n", 580 | " \n", 581 | " \n", 582 | " \n", 583 | " \n", 584 | " \n", 585 | " \n", 586 | " \n", 587 | " \n", 588 | " \n", 589 | " \n", 590 | " \n", 591 | " \n", 592 | " \n", 593 | " \n", 594 | " \n", 595 | " \n", 596 | " \n", 597 | " \n", 598 | " \n", 599 | " \n", 600 | " \n", 601 | " \n", 602 | " \n", 603 | " \n", 604 | " \n", 605 | " \n", 606 | " \n", 607 | " \n", 608 | " \n", 609 | " \n", 610 | " \n", 611 | " \n", 612 | " \n", 613 | " \n", 614 | " \n", 615 | " \n", 616 | " \n", 617 | " \n", 618 | " \n", 619 | " \n", 620 | " \n", 621 | " \n", 622 | " \n", 623 | " \n", 624 | " \n", 625 | " \n", 626 | " \n", 627 | " \n", 628 | " \n", 629 | " \n", 630 | " \n", 631 | " \n", 632 | " \n", 633 | " \n", 634 | " \n", 
635 | " \n", 636 | " \n", 637 | " \n", 638 | " \n", 639 | " \n", 640 | " \n", 641 | " \n", 642 | " \n", 643 | " \n", 644 | " \n", 645 | " \n", 646 | " \n", 647 | " \n", 648 | " \n", 649 | " \n", 650 | " \n", 651 | " \n", 652 | " \n", 653 | " \n", 654 | " \n", 655 | " \n", 656 | " \n", 657 | " \n", 658 | " \n", 659 | " \n", 660 | " \n", 661 | " \n", 662 | " \n", 663 | " \n", 664 | " \n", 665 | " \n", 666 | " \n", 667 | " \n", 668 | " \n", 669 | " \n", 670 | " \n", 671 | " \n", 672 | " \n", 673 | " \n", 674 | " \n", 675 | " \n", 676 | " \n", 677 | " \n", 678 | " \n", 679 | " \n", 680 | " \n", 681 | " \n", 682 | " \n", 683 | " \n", 684 | " \n", 685 | " \n", 686 | " \n", 687 | " \n", 688 | " \n", 689 | " \n", 690 | " \n", 691 | " \n", 692 | " \n", 693 | " \n", 694 | " \n", 695 | " \n", 696 | " \n", 697 | " \n", 698 | " \n", 699 | " \n", 700 | " \n", 701 | " \n", 702 | " \n", 703 | " \n", 704 | " \n", 705 | " \n", 706 | " \n", 707 | " \n", 708 | " \n", 709 | " \n", 710 | " \n", 711 | " \n", 712 | " \n", 713 | " \n", 714 | " \n", 715 | " \n", 716 | " \n", 717 | " \n", 718 | " \n", 719 | " \n", 720 | " \n", 721 | " \n", 722 | " \n", 723 | " \n", 724 | " \n", 725 | " \n", 726 | " \n", 727 | " \n", 728 | " \n", 729 | " \n", 730 | " \n", 731 | " \n", 732 | " \n", 733 | " \n", 734 | " \n", 735 | " \n", 736 | " \n", 737 | " \n", 738 | " \n", 739 | " \n", 740 | " \n", 741 | " \n", 742 | " \n", 743 | " \n", 744 | " \n", 745 | " \n", 746 | " \n", 747 | " \n", 748 | " \n", 749 | " \n", 750 | " \n", 751 | " \n", 752 | " \n", 753 | " \n", 754 | " \n", 755 | " \n", 756 | " \n", 757 | " \n", 758 | " \n", 759 | " \n", 760 | " \n", 761 | " \n", 762 | " \n", 763 | " \n", 764 | " \n", 765 | " \n", 766 | " \n", 767 | " \n", 768 | " \n", 769 | " \n", 770 | " \n", 771 | " \n", 772 | " \n", 773 | " \n", 774 | " \n", 775 | " \n", 776 | " \n", 777 | " \n", 778 | " \n", 779 | " \n", 780 | " \n", 781 | " \n", 782 | " \n", 783 | " \n", 784 | " \n", 785 | " \n", 786 | " \n", 787 | " \n", 788 | " \n", 789 | " \n", 790 | " \n", 791 | " \n", 792 | " \n", 793 | " \n", 794 | " \n", 795 | " \n", 796 | " \n", 797 | " \n", 798 | " \n", 799 | " \n", 800 | " \n", 801 | " \n", 802 | " \n", 803 | " \n", 804 | " \n", 805 | " \n", 806 | " \n", 807 | " \n", 808 | " \n", 809 | " \n", 810 | " \n", 811 | " \n", 812 | " \n", 813 | " \n", 814 | " \n", 815 | " \n", 816 | " \n", 817 | " \n", 818 | " \n", 819 | " \n", 820 | " \n", 821 | " \n", 822 | " \n", 823 | " \n", 824 | " \n", 825 | " \n", 826 | " \n", 827 | " \n", 828 | " \n", 829 | " \n", 830 | " \n", 831 | " \n", 832 | " \n", 833 | " \n", 834 | " \n", 835 | " \n", 836 | " \n", 837 | " \n", 838 | " \n", 839 | " \n", 840 | " \n", 841 | " \n", 842 | " \n", 843 | " \n", 844 | " \n", 845 | " \n", 846 | " \n", 847 | " \n", 848 | " \n", 849 | " \n", 850 | " \n", 851 | " \n", 852 | " \n", 853 | " \n", 854 | " \n", 855 | " \n", 856 | " \n", 857 | " \n", 858 | " \n", 859 | " \n", 860 | " \n", 861 | " \n", 862 | " \n", 863 | " \n", 864 | " \n", 865 | " \n", 866 | " \n", 867 | " \n", 868 | " \n", 869 | " \n", 870 | " \n", 871 | " \n", 872 | " \n", 873 | " \n", 874 | " \n", 875 | " \n", 876 | " \n", 877 | " \n", 878 | " \n", 879 | " \n", 880 | " \n", 881 | " \n", 882 | " \n", 883 | " \n", 884 | " \n", 885 | " \n", 886 | " \n", 887 | " \n", 888 | " \n", 889 | " \n", 890 | " \n", 891 | " \n", 892 | " \n", 893 | " \n", 894 | " \n", 895 | " \n", 896 | " \n", 897 | " \n", 898 | " \n", 899 | " \n", 900 | " \n", 901 | " \n", 902 | " \n", 903 | " \n", 904 | " \n", 905 | " \n", 906 | " \n", 907 | " \n", 908 | 
" \n", 909 | " \n", 910 | " \n", 911 | " \n", 912 | " \n", 913 | " \n", 914 | " \n", 915 | " \n", 916 | " \n", 917 | " \n", 918 | " \n", 919 | " \n", 920 | " \n", 921 | " \n", 922 | " \n", 923 | " \n", 924 | " \n", 925 | " \n", 926 | " \n", 927 | " \n", 928 | " \n", 929 | " \n", 930 | " \n", 931 | " \n", 932 | " \n", 933 | " \n", 934 | " \n", 935 | " \n", 936 | " \n", 937 | " \n", 938 | " \n", 939 | " \n", 940 | " \n", 941 | " \n", 942 | " \n", 943 | " \n", 944 | " \n", 945 | " \n", 946 | " \n", 947 | "
StepTraining Loss
203.916700
403.506800
602.920000
802.535500
1002.373500
1202.276600
1402.214600
1602.185300
1802.239700
2002.076600
2202.107100
2402.104600
2602.128600
2802.110000
3002.053500
3202.071100
3402.038100
3602.044400
3802.023100
4001.984200
4202.018300
4401.977400
4602.079600
4802.044400
5001.980800
5201.964300
5401.936100
5601.966300
5802.018200
6002.009400
6201.987000
6401.953200
6601.938400
6801.976200
7001.980300
7201.972900
7401.992600
7601.955700
7801.959400
8002.001400
8201.944000
8401.977500
8601.914500
8801.930400
9001.932700
9201.925500
9401.922200
9601.942200
9801.955500
10001.991600
10201.927100
10401.922700
10601.911200
10801.951500
11001.915900
11201.914500
11401.914400
11601.926300
11801.920900
12001.924900
12201.885200
12401.892300
12601.865000
12801.906500
13001.973800
13201.946900
13401.899200
13601.955200
13801.887600
14001.904500
14201.883300
14401.920400
14601.921200
14801.911600
15001.945900
15201.894300
15401.927600
15601.933200
15801.903100
16001.955200
16201.904200
16401.949300
16601.912000
16801.856300
17001.922300
17201.903400
17401.897600
17601.917000
17801.884800
18001.907800
18201.950800
18401.943300
18601.878800
18801.992300
19001.884800
19201.848900
19401.913400
19601.856400
19801.898000
20001.918800
20201.899200
20401.850000
20601.958700
20801.936000
21001.918100
21201.890400
21401.883500
21601.917800
21801.987700
22001.872300
22201.879200
22401.919000
22601.899200
22801.850200
23001.914800
23201.946000
23401.899600
23601.826000
23801.820800
24001.910500
24201.864200
24401.955800
24601.964200
24801.927400
25001.918800
25201.860900
25401.932000
25601.875100
25801.808300
26001.886500
26201.866700
26401.902100
26601.916600
26801.857100
27001.928600
27201.908400
27401.908300
27601.934100
27801.874800
28001.885300
28201.904400
28401.857000
28601.793300
28801.848700
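Unlike the half-precision notebooks, this run never calls merge_and_unload — merging LoRA deltas into 4-bit quantized weights is not straightforward — so the adapter stays attached at inference time. A hedged sketch of saving and re-attaching just the adapter weights; the output path here is an assumption, not the notebook's:

```python
# save only the adapter (a few MB) instead of a merged full checkpoint
model_lora.save_pretrained('./modelcheak/qlora_adapter')  # hypothetical path

# later: reload the 4-bit base model as above, then attach the adapter
from peft import PeftModel
prft_model = PeftModel.from_pretrained(model=model, model_id='./modelcheak/qlora_adapter').to('cuda')
```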

" 948 | ], 949 | "text/plain": [ 950 | "" 951 | ] 952 | }, 953 | "metadata": {}, 954 | "output_type": "display_data" 955 | }, 956 | { 957 | "data": { 958 | "text/plain": [ 959 | "TrainOutput(global_step=2895, training_loss=1.9793793515220208, metrics={'train_runtime': 5295.8774, 'train_samples_per_second': 17.494, 'train_steps_per_second': 0.547, 'total_flos': 8.78107053612073e+16, 'train_loss': 1.9793793515220208, 'epoch': 1.0})" 960 | ] 961 | }, 962 | "execution_count": 15, 963 | "metadata": {}, 964 | "output_type": "execute_result" 965 | } 966 | ], 967 | "source": [ 968 | "trainr.train()" 969 | ] 970 | }, 971 | { 972 | "cell_type": "markdown", 973 | "id": "7ae220ef", 974 | "metadata": {}, 975 | "source": [ 976 | "## 9 权重合并与" 977 | ] 978 | }, 979 | { 980 | "cell_type": "code", 981 | "execution_count": 19, 982 | "id": "9e8cbedf", 983 | "metadata": {}, 984 | "outputs": [ 985 | { 986 | "data": { 987 | "text/plain": [ 988 | "device(type='cuda', index=0)" 989 | ] 990 | }, 991 | "execution_count": 19, 992 | "metadata": {}, 993 | "output_type": "execute_result" 994 | } 995 | ], 996 | "source": [ 997 | "model.device" 998 | ] 999 | }, 1000 | { 1001 | "cell_type": "code", 1002 | "execution_count": 35, 1003 | "id": "5aca0d3b", 1004 | "metadata": {}, 1005 | "outputs": [], 1006 | "source": [ 1007 | "from peft import PeftModel\n", 1008 | "# model_id 是checkpoint那个路径\n", 1009 | "prft_model=PeftModel.from_pretrained(model=model,model_id=r\"C:\\Users\\30535\\Desktop\\CodeProgram\\Python\\deepstudy\\code2\\使用Transformer进行中英文翻译\\modelcheak\\m5\\checkpoint-2500\")\n", 1010 | "# 权重合并\n", 1011 | "prft_model=prft_model.to('cuda')" 1012 | ] 1013 | }, 1014 | { 1015 | "cell_type": "code", 1016 | "execution_count": null, 1017 | "id": "59fc687a", 1018 | "metadata": {}, 1019 | "outputs": [], 1020 | "source": [ 1021 | "# 模型保存\n", 1022 | "# merge_model.save_pretrained('./modelcheak/trans11')" 1023 | ] 1024 | }, 1025 | { 1026 | "cell_type": "code", 1027 | "execution_count": 39, 1028 | "id": "ce272f6e", 1029 | "metadata": {}, 1030 | "outputs": [ 1031 | { 1032 | "name": "stdout", 1033 | "output_type": "stream", 1034 | "text": [ 1035 | "0.0\n", 1036 | "被翻译句子: 我只是在帮她。\n", 1037 | "翻译结果 I'm just doing her a favor.\n", 1038 | "\n", 1039 | "被翻译句子: I imagined myself in a courtroom at his trial, facing down the bearded man who has haunted my dreams over the last nine years.\n", 1040 | "翻译结果 我幻想自己在他审判时,面对着那张胡须浓密的脸,他一直在我梦里纠缠不休。\n", 1041 | "\n", 1042 | "被翻译句子: There's a good place nearby. 
\n", 1043 | "翻译结果 附近有个好地方。\n", 1044 | "\n", 1045 | "被翻译句子: 他是从外地来的货郎 \n", 1046 | "翻译结果 He's a local landlord\n", 1047 | "\n", 1048 | "被翻译句子: 但是不要忘记,埃利森是一个创业者和梦想家,他缔造了一个公司、一种文化,事实上缔造了整个行业。\n", 1049 | "翻译结果 But don't forget, Elison is a creator and a dreamer, he created a corporation, a culture, in fact he created the whole industry.\n", 1050 | "\n", 1051 | "时间 4.347065210342407\n", 1052 | "0.3682862200723244\n" 1053 | ] 1054 | } 1055 | ], 1056 | "source": [ 1057 | "import re\n", 1058 | "import sacrebleu\n", 1059 | "def is_english_sentence(sentence):\n", 1060 | " # 使用正则表达式检查句子中是否包含英文字母\n", 1061 | " english_pattern = re.compile(r'[a-zA-Z]')\n", 1062 | " match = english_pattern.search(sentence)\n", 1063 | " \n", 1064 | " if match:\n", 1065 | " return True\n", 1066 | " else:\n", 1067 | " return False\n", 1068 | "from nltk.translate.bleu_score import sentence_bleu\n", 1069 | "from nltk.translate.bleu_score import SmoothingFunction\n", 1070 | "\n", 1071 | "smooth = SmoothingFunction().method1\n", 1072 | "bleu_scores=[]\n", 1073 | "m1,m2=[],[]\n", 1074 | "m3,m4=[],[]\n", 1075 | "import time\n", 1076 | "t=time.time()\n", 1077 | "for i in range(len(ds[1]['src'])):\n", 1078 | " if i%40==0:\n", 1079 | " print(i/len(ds[1]['src']))\n", 1080 | " x=\"机器翻译:\\n{}\".format(ds[1]['src'][i]).strip()\n", 1081 | " ipt = tokenizer(x,return_tensors='pt').to('cuda')\n", 1082 | "# print('被翻译句子: ',ds[1]['src'][i])\n", 1083 | " y=tokenizer.decode(prft_model.generate(**ipt,max_length=150, do_sample=False)[0],skip_special_tokens=True)[len(x):]\n", 1084 | "# print('翻译结果: ',y)\n", 1085 | "# print()\n", 1086 | " if is_english_sentence(ds[1]['tgt'][i]):\n", 1087 | " m1.append(ds[1]['tgt'][i])\n", 1088 | " m2.append([y])\n", 1089 | " else:\n", 1090 | " m3.append(list(ds[1]['tgt'][i][:-1]))\n", 1091 | " m4.append([list(y)[:-1]])\n", 1092 | "# print('时间',time.time()-t)\n", 1093 | "smooth = SmoothingFunction().method1\n", 1094 | "b1=[sacrebleu.sentence_bleu(candidate, refs).score for candidate, refs in zip(m1, m2)]\n", 1095 | "for i in range(len(m4)):\n", 1096 | " b2 = sentence_bleu(m4[i], m3[i], weights=(0.25, 0.25, 0.25, 0.25), smoothing_function=smooth)*100\n", 1097 | " b1.append(b2)\n", 1098 | "print(sum(b1)/len(ds[1]['src']))" 1099 | ] 1100 | }, 1101 | { 1102 | "cell_type": "code", 1103 | "execution_count": null, 1104 | "id": "8162c152", 1105 | "metadata": {}, 1106 | "outputs": [], 1107 | "source": [] 1108 | }, 1109 | { 1110 | "cell_type": "markdown", 1111 | "id": "9a092afd", 1112 | "metadata": {}, 1113 | "source": [ 1114 | "## 9 模型推理" 1115 | ] 1116 | }, 1117 | { 1118 | "cell_type": "code", 1119 | "execution_count": null, 1120 | "id": "df9ad78c", 1121 | "metadata": {}, 1122 | "outputs": [], 1123 | "source": [ 1124 | "from transformers import pipeline" 1125 | ] 1126 | }, 1127 | { 1128 | "cell_type": "code", 1129 | "execution_count": null, 1130 | "id": "7145468b", 1131 | "metadata": {}, 1132 | "outputs": [], 1133 | "source": [ 1134 | "pipe=pipeline('text2text-generation',model=merge_model,tokenizer=tokenizer,device=0)" 1135 | ] 1136 | }, 1137 | { 1138 | "cell_type": "code", 1139 | "execution_count": null, 1140 | "id": "89d02ec1", 1141 | "metadata": { 1142 | "scrolled": true 1143 | }, 1144 | "outputs": [], 1145 | "source": [ 1146 | "pipe('机器翻译:\\n'+'我有一个苹果',max_length=30,do_sample=False)" 1147 | ] 1148 | } 1149 | ], 1150 | "metadata": { 1151 | "kernelspec": { 1152 | "display_name": "Python 3 (ipykernel)", 1153 | "language": "python", 1154 | "name": "python3" 1155 | }, 1156 | "language_info": { 1157 | "codemirror_mode": { 1158 
| "name": "ipython", 1159 | "version": 3 1160 | }, 1161 | "file_extension": ".py", 1162 | "mimetype": "text/x-python", 1163 | "name": "python", 1164 | "nbconvert_exporter": "python", 1165 | "pygments_lexer": "ipython3", 1166 | "version": "3.11.4" 1167 | } 1168 | }, 1169 | "nbformat": 4, 1170 | "nbformat_minor": 5 1171 | } 1172 | --------------------------------------------------------------------------------