├── README.md
├── build_dataset.py
├── config.py
├── dataset.py
├── decoder-only.png
├── emb.py
├── gpt.py
├── inference.py
├── tokenizer.py
├── train_gpt.py
├── train_tokenizer.py
└── 纳兰性德诗集.json

/README.md:
--------------------------------------------------------------------------------
# chatgpt

A simple decoder-only GPT model in PyTorch.

![](decoder-only.png)

## Dependencies

```
pip install tqdm torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu118 -i https://mirrors.aliyun.com/pypi/simple/
```

## Configuration (config.py)

GPT_MODE can be either `generate` or `chat`; after switching modes you must rebuild the dataset.

`generate`: trains directly on the poem bodies, i.e. plain next-token continuation:

```
谢却荼蘼,一片月明如水
篆香消,犹未睡,早鸦啼
嫩寒无赖罗衣薄,休傍阑干角
最愁人,灯欲落,雁还飞
```

`chat`: trains on the chatml format, where `user` carries the poem title and `assistant` carries the poem body, i.e. conversational style (given a title, the model writes the poem):

```
<|im_start|>user
酒泉子·谢却荼蘼
<|im_end|>
<|im_start|>assistant
谢却荼蘼,一片月明如水
篆香消,犹未睡,早鸦啼
嫩寒无赖罗衣薄,休傍阑干角
最愁人,灯欲落,雁还飞
<|im_end|>
```

## Training

Train the tokenizer:

```
python train_tokenizer.py
```

Build the dataset:

```
python build_dataset.py
```

Train the GPT:

```
python train_gpt.py
training set size: 258
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [19:04<00:00, 8.74it/s, loss=0.0698]
```

## Inference

With GPT_MODE='generate', the model simply continues the input text:

```
python inference.py
>山色江声共寂寥
< 山色江声共寂寥,十三陵树晚萧萧
中原事业如江左,芳草何须怨六朝
>三眠
< 三眠未歇,乍到秋时节
一树料阳蝉更咽,曾绾灞陵离别
絮己为萍风卷叶,空凄切
长条莫轻折,苏小恨,倩他说
尽飘零、游冶章台客
红板桥空,湔裙人去,依旧晓风残月
```

With GPT_MODE='chat', the model behaves like a chat assistant (under the hood it is still next-token continuation; the training data is simply laid out in chatml format):

```
python inference.py
>虞美人·曲阑深处重相见
< <|im_start|>user
虞美人·曲阑深处重相见
<|im_end|>
<|im_start|>assistant
曲阑深处重相见,匀泪偎人颤
凄凉别后两应同,最是不胜清怨月明中
半生已分孤眠过,山枕檀痕涴
忆来何事最销魂,第一折枝花样画罗裙
>美人
< <|im_start|>user
美人
<|im_end|>
<|im_start|>assistant
落�垒鸟入闵,望舞,�端缕风毸山行山
西风丝,�茼�上�
炬
残阳何时节�已教�下雨霜倦
残阳郎此但,寄银入�成�作清�曾游
```

Notes:
* A complete, known title usually yields the correct poem body, with some variation coming from top_k sampling.
* Because the training corpus is tiny, the model cannot learn enough tokens to express arbitrary poem bodies, so the generated token sequence may not form a valid UTF-8 byte sequence; the � characters are an expected artifact (see the sketch below).
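To see why � appears, here is a minimal standalone sketch (not part of this repo): the tokenizer works on raw bytes, so the model can emit a token sequence whose bytes are not valid UTF-8, and decoding with `errors='replace'` turns the damaged bytes into the U+FFFD replacement character:

```python
# Minimal illustration: 落 is E8 90 BD in UTF-8; the trailing 0xE5 is a
# dangling lead byte, so decoding replaces it with the � character.
broken = bytes([0xE8, 0x90, 0xBD, 0xE5])
print(broken.decode('utf-8', errors='replace'))  # -> 落�
```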
--------------------------------------------------------------------------------
/build_dataset.py:
--------------------------------------------------------------------------------
from dataset import NalanDataset
import pickle
import os
import sys

filename='dataset.bin'

def load_dataset():
    with open(filename,'rb') as fp:
        ds=pickle.load(fp)
    return ds

if __name__=='__main__':
    if os.path.exists(filename):
        ds=load_dataset()
        print(f'{filename} already exists, training set size: {len(ds)}, sample:')
        ids,text=ds[5]
        print(ids,text)
        sys.exit(0)

    ds=NalanDataset()
    with open(filename,'wb') as fp:
        ds.build_train_data()
        pickle.dump(ds,fp)
    print('dataset.bin generated')
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
VOCAB_SIZE=500      # vocabulary size
MAX_SEQ_LEN=2000    # GPT model input length limit

# transformer
GPT_DIM=384
GPT_HEAD=6
GPT_FF=1024
GPT_BLOCKS=6

# training
TRAIN_ITER=10000
BATCH_SIZE=50

# inference
TEMPERATURE=1.2
TOP_K=20

# special tokens
BOS='<|beginoftext|>'
EOS='<|endoftext|>'
PAD='<|padding|>'
IM_START='<|im_start|>'
IM_END='<|im_end|>'

# chat or generate
GPT_MODE='generate'
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
from torch.utils.data import Dataset
from tokenizer import BPETokenizer
from config import *
import json
from tqdm import tqdm

# dataset of Nalan Xingde's poems
class NalanDataset(Dataset):
    def __init__(self):
        super().__init__()
        with open('纳兰性德诗集.json','r',encoding='utf-8') as fp:
            self.raw_ds=json.loads(fp.read())

    def build_train_data(self):
        tokenizer=BPETokenizer()
        tokenizer.load('tokenizer.bin')

        self.data=[]
        for sample in tqdm(self.raw_ds,desc='building dataset'):
            try:
                text='\n'.join(sample['para'])
                inputs=f'{IM_START}user\n{sample["title"]}\n{IM_END}\n{IM_START}assistant\n{text}\n{IM_END}' if GPT_MODE=='chat' else f'{text}'
                ids,_=tokenizer.encode(inputs)
                if len(ids)>MAX_SEQ_LEN-2: # leave room for the BOS and EOS tokens
                    continue
                self.data.append((ids,inputs))
            except Exception: # skip malformed samples
                continue

    def __len__(self):
        return len(self.data)

    def __getitem__(self,index):
        return self.data[index]
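For reference, a standalone sketch (not part of the repo) of the string that build_train_data produces for one sample in chat mode; the sample dict below is hypothetical but follows the title/para shape of the corpus JSON:

```python
# Hypothetical sample shaped like one entry of 纳兰性德诗集.json (truncated to two lines).
IM_START, IM_END = '<|im_start|>', '<|im_end|>'
sample = {'title': '酒泉子·谢却荼蘼', 'para': ['谢却荼蘼,一片月明如水', '篆香消,犹未睡,早鸦啼']}
text = '\n'.join(sample['para'])
chatml = f'{IM_START}user\n{sample["title"]}\n{IM_END}\n{IM_START}assistant\n{text}\n{IM_END}'
print(chatml)
```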
--------------------------------------------------------------------------------
/decoder-only.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/owenliang/chatgpt/4933a4a4237dee430cbd0104656233686ad39e06/decoder-only.png
--------------------------------------------------------------------------------
/emb.py:
--------------------------------------------------------------------------------
from torch import nn
import torch
import math

class EmbeddingWithPosition(nn.Module):
    def __init__(self,vocab_size,dim,seq_max_len):
        super().__init__()

        # token embedding
        self.seq_emb=nn.Embedding(vocab_size,dim)

        # fixed sinusoidal positional encoding
        position_idx=torch.arange(0,seq_max_len,dtype=torch.float).unsqueeze(-1)
        position_emb_fill=position_idx*torch.exp(-torch.arange(0,dim,2)*math.log(10000.0)/dim)
        pos_encoding=torch.zeros(seq_max_len,dim)
        pos_encoding[:,0::2]=torch.sin(position_emb_fill)
        pos_encoding[:,1::2]=torch.cos(position_emb_fill)
        self.register_buffer('pos_encoding',pos_encoding)

    def forward(self,x): # x: (batch_size,seq_len)
        x=self.seq_emb(x) # x: (batch_size,seq_len,dim)
        x=x+self.pos_encoding.unsqueeze(0)[:,:x.size()[1],:] # x: (batch_size,seq_len,dim)
        return x
--------------------------------------------------------------------------------
/gpt.py:
--------------------------------------------------------------------------------
from torch import nn
import torch
from emb import EmbeddingWithPosition
from config import GPT_BLOCKS

class GPT(nn.Module):
    def __init__(self,d_model,nhead,feedforward,vocab_size,seq_max_len):
        super().__init__()

        # token embedding + positional encoding
        self.emb=EmbeddingWithPosition(vocab_size=vocab_size,dim=d_model,seq_max_len=seq_max_len)

        # decoder-only transformer: self-attention only, so TransformerEncoderLayer
        # is reused with a causal mask instead of a full encoder-decoder stack
        self.dec_blocks=nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=d_model,nhead=nhead,dim_feedforward=feedforward,batch_first=True) for _ in range(GPT_BLOCKS)
        ])
        # next-token probabilities
        self.prob_linear=nn.Linear(d_model,vocab_size)

    def forward(self,x,padding_mask): # x:(batch,seq)
        # causal attention mask
        src_mask=torch.triu(torch.ones(x.size()[1],x.size()[1]),diagonal=1).type(torch.bool).to(x.device)
        # embedding
        x=self.emb(x)
        # decoder
        for block in self.dec_blocks:
            x=block(x,src_mask=src_mask,src_key_padding_mask=padding_mask)
        # logits
        logits=self.prob_linear(x)
        return logits

if __name__=='__main__':
    # tokenizer
    from tokenizer import BPETokenizer
    tokenizer=BPETokenizer()
    tokenizer.load('tokenizer.bin')

    # mock input
    x=torch.randint(0,tokenizer.vocab_size(),(5,30))
    padding=torch.zeros(5,30)

    # GPT model
    from config import MAX_SEQ_LEN
    gpt=GPT(d_model=64,nhead=2,feedforward=128,vocab_size=tokenizer.vocab_size(),seq_max_len=MAX_SEQ_LEN)
    y=gpt(x,padding)
    print(y.shape)
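For intuition, here is a standalone sketch (not part of the repo) of the causal mask that GPT.forward builds with torch.triu: True marks the future positions a token is not allowed to attend to.

```python
import torch

# The mask GPT.forward builds, shown for seq_len=4.
# Row i is token i's view: it may attend only to positions <= i.
mask = torch.triu(torch.ones(4, 4), diagonal=1).type(torch.bool)
print(mask)
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])
```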
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
from gpt import GPT
from config import *
import torch
from tokenizer import BPETokenizer
import torch.nn.functional as F
import random

# device
DEVICE='cuda' if torch.cuda.is_available() else 'cpu'

# tokenizer
tokenizer=BPETokenizer()
tokenizer.load('tokenizer.bin')

# load the model
model=GPT(d_model=GPT_DIM,nhead=GPT_HEAD,feedforward=GPT_FF,vocab_size=tokenizer.vocab_size(),seq_max_len=MAX_SEQ_LEN).to(DEVICE) # model
try:
    checkpoint=torch.load('checkpoint.bin')
    model.load_state_dict(checkpoint['model'])
except Exception: # no checkpoint yet
    pass

model.eval()

# possible terminator tokens
eos_ids,_=tokenizer.encode(EOS)
pad_ids,_=tokenizer.encode(PAD)
im_end_ids,_=tokenizer.encode(IM_END)

def chat(query):
    global tokenizer,model

    inputs=f'{BOS}{IM_START}user\n{query}\n{IM_END}\n{IM_START}assistant\n' if GPT_MODE=='chat' else f'{BOS}{query}'
    ids,_=tokenizer.encode(inputs)

    while len(ids)<MAX_SEQ_LEN:
        batch_ids=torch.tensor([ids],dtype=torch.long).to(DEVICE)
        batch_padding_mask=torch.zeros((1,len(ids)),dtype=torch.bool).to(DEVICE)

        with torch.no_grad():
            logits=model(batch_ids,batch_padding_mask) # (1,seq,vocab)

        # last position's logits, scaled by temperature
        logits=logits[0,-1,:]/TEMPERATURE
        # sample the next token from the top-k candidates
        topk_logits,topk_ids=torch.topk(logits,TOP_K)
        topk_probs=F.softmax(topk_logits,dim=-1)
        next_id=topk_ids[torch.multinomial(topk_probs,num_samples=1)].item()

        # stop on any terminator token
        if next_id in eos_ids+pad_ids+im_end_ids:
            break
        ids.append(next_id)

    # drop the BOS prefix from the decoded text
    text=tokenizer.decode(ids)
    return text[len(BOS):]

if __name__=='__main__':
    while True:
        query=input('>')
        if query=='exit':
            break

        resp=chat(query)
        print('<',resp)
--------------------------------------------------------------------------------
/tokenizer.py:
--------------------------------------------------------------------------------
from collections import OrderedDict
import pickle
import re
from tqdm import tqdm

# Byte-Pair Encoding tokenization
class BPETokenizer:
    def __init__(self):
        self.b2i=OrderedDict() # bytes to id
        self.i2b=OrderedDict() # id to bytes (reverse of b2i)
        self.next_id=0

        # special tokens
        self.sp_s2i={} # str to id
        self.sp_i2s={} # id to str

    # count adjacent token pairs
    def _pair_stats(self,tokens,stats):
        for i in range(len(tokens)-1):
            new_token=tokens[i]+tokens[i+1]
            if new_token not in stats:
                stats[new_token]=0
            stats[new_token]+=1

    # merge every adjacent pair equal to new_token
    def _merge_pair(self,tokens,new_token):
        merged_tokens=[]

        i=0
        while i<len(tokens):
            if i+1<len(tokens) and tokens[i]+tokens[i+1]==new_token:
                merged_tokens.append(new_token)
                i+=2
            else:
                merged_tokens.append(tokens[i])
                i+=1
        return merged_tokens

    def train(self,text_list,vocab_size):
        # seed the vocab with all 256 single bytes
        for i in range(256):
            self.b2i[bytes([i])]=i
        self.next_id=256

        # split the corpus into single-byte tokens
        tokens_list=[[bytes([b]) for b in text.encode('utf-8')] for text in text_list]

        # progress bar
        progress=tqdm(total=vocab_size,initial=256)

        while True:
            # vocab is large enough, stop training
            if self.next_id>=vocab_size:
                break

            # count adjacent pair frequencies
            stats={}
            for tokens in tokens_list:
                self._pair_stats(tokens,stats)

            # no adjacent pairs left, cannot create new tokens, stop training
            if not stats:
                break

            # merge the most frequent pair and add it to the vocab as a new token
            new_token=max(stats,key=stats.get)

            new_tokens_list=[]
            for tokens in tokens_list:
                new_tokens_list.append(self._merge_pair(tokens,new_token))
            tokens_list=new_tokens_list

            # add the new token to the vocab
            self.b2i[new_token]=self.next_id
            self.next_id+=1

            # refresh the progress bar
            progress.update(1)

        self.i2b={v:k for k,v in self.b2i.items()}

    # vocab size
    def vocab_size(self):
        return self.next_id

    # full vocab
    def vocab(self):
        v={}
        v.update(self.i2b)
        v.update({id:token.encode('utf-8') for id,token in self.sp_i2s.items()})
        return v

    # special tokens
    def add_special_tokens(self,special_tokens):
        for token in special_tokens:
            if token not in self.sp_s2i:
                self.sp_s2i[token]=self.next_id
                self.sp_i2s[self.next_id]=token
                self.next_id+=1

    def encode(self,text):
        # split special tokens out first
        pattern='('+'|'.join([re.escape(tok) for tok in self.sp_s2i])+')'
        splits=re.split(pattern,text) # e.g. [ '<|im_start|>', 'user', '<|im_end|>' ]

        # encoding results
        enc_ids=[]
        enc_tokens=[]
        for sub_text in splits:
            if sub_text in self.sp_s2i: # special token, map directly to its id
                enc_ids.append(self.sp_s2i[sub_text])
                enc_tokens.append(sub_text.encode('utf-8'))
            else:
                tokens=[bytes([b]) for b in sub_text.encode('utf-8')]
                while True:
                    # count adjacent pair frequencies
                    stats={}
                    self._pair_stats(tokens,stats)

                    # merge the mergeable pair with the smallest id (i.e. merge shorter, earlier-learned tokens first)
                    new_token=None
                    for merge_token in stats:
                        if merge_token in self.b2i and (new_token is None or self.b2i[merge_token]<self.b2i[new_token]):
                            new_token=merge_token

                    # nothing left to merge, stop
                    if new_token is None:
                        break

                    tokens=self._merge_pair(tokens,new_token)
                enc_ids.extend([self.b2i[token] for token in tokens])
                enc_tokens.extend(tokens)
        return enc_ids,enc_tokens

    def decode(self,ids):
        bytes_list=[]
        for id in ids:
            if id in self.sp_i2s: # special token, decode from its string form
                bytes_list.append(self.sp_i2s[id].encode('utf-8'))
            else:
                bytes_list.append(self.i2b[id])
        # invalid byte sequences decode to the � replacement character
        return b''.join(bytes_list).decode('utf-8',errors='replace')

    def save(self,file):
        with open(file,'wb') as fp:
            fp.write(pickle.dumps((self.b2i,self.sp_s2i,self.next_id)))

    def load(self,file):
        with open(file,'rb') as fp:
            self.b2i,self.sp_s2i,self.next_id=pickle.loads(fp.read())
        self.i2b={v:k for k,v in self.b2i.items()}
        self.sp_i2s={v:k for k,v in self.sp_s2i.items()}

if __name__=='__main__':
    import json
    from config import VOCAB_SIZE

    # load the corpus
    with open('纳兰性德诗集.json','r',encoding='utf-8') as fp:
        ds=json.loads(fp.read())
    text_list=['\n'.join(sample['para']) for sample in ds]

    # train the tokenizer
    tokenizer=BPETokenizer()
    tokenizer.train(text_list,VOCAB_SIZE)

    # register special tokens
    tokenizer.add_special_tokens(['<|im_start|>','<|im_end|>','<|endoftext|>','<|padding|>'])

    # save
    tokenizer.save('tokenizer.bin')

    # restore
    tokenizer=BPETokenizer()
    tokenizer.load('tokenizer.bin')
    print('vocab size:',tokenizer.vocab_size())

    # encode
    ids,tokens=tokenizer.encode('<|im_start|>system\nyou are a helper assistant\n<|im_end|>\n<|im_start|>user\n今天的天气\n<|im_end|><|im_start|>assistant\n')
    print('encode:',ids,tokens)

    # decode
    s=tokenizer.decode(ids)
    print('decode:',s)

    # dump the vocab
    print('vocab:',tokenizer.vocab())
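To make the merge loop concrete, here is a standalone walkthrough (not part of the repo) of a single BPE round on a toy byte sequence, mirroring what _pair_stats and _merge_pair do:

```python
tokens = [b'a', b'b', b'a', b'b', b'c']

# Count adjacent pairs (what _pair_stats does).
stats = {}
for i in range(len(tokens) - 1):
    pair = tokens[i] + tokens[i + 1]
    stats[pair] = stats.get(pair, 0) + 1
print(stats)   # {b'ab': 2, b'ba': 1, b'bc': 1}

# Merge the most frequent pair everywhere (what train does via _merge_pair).
new_token = max(stats, key=stats.get)
merged, i = [], 0
while i < len(tokens):
    if i + 1 < len(tokens) and tokens[i] + tokens[i + 1] == new_token:
        merged.append(new_token)
        i += 2
    else:
        merged.append(tokens[i])
        i += 1
print(merged)  # [b'ab', b'ab', b'c']
```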
--------------------------------------------------------------------------------
/train_gpt.py:
--------------------------------------------------------------------------------
import torch
from build_dataset import load_dataset
from gpt import GPT
import torch.nn.functional as F
from torch.utils.data import DataLoader
from config import *
from tokenizer import BPETokenizer
from tqdm import tqdm
import os

DEVICE='cuda' if torch.cuda.is_available() else 'cpu'

dataset=load_dataset()

tokenizer=BPETokenizer()
tokenizer.load('tokenizer.bin')
pad_ids,_=tokenizer.encode(PAD)

def batch_proc(batch):
    bos_ids,_=tokenizer.encode(BOS)
    eos_ids,_=tokenizer.encode(EOS)
    pad_ids,_=tokenizer.encode(PAD)

    batch_x=[]
    batch_chatml=[]
    # bpe encode
    for sample in batch:
        ids,chatml=sample
        ids=bos_ids+ids+eos_ids
        batch_x.append(ids)
        batch_chatml.append(chatml)

    # padding
    max_len=max([len(ids) for ids in batch_x])
    for ids in batch_x:
        if len(ids)<max_len: