├── .DS_Store
├── .gitignore
├── README.md
├── WIKI_CN
│   └── .gitignore
├── images
│   ├── dataset.png
│   ├── llm.png
│   ├── log.png
│   └── terminal.png
├── pretrain.py
├── requirements.txt
└── swanlab.yaml

/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShaohonChen/transformers_from_scratch/5e9448497f729765f909596e3e2b3c409949bbd3/.DS_Store
--------------------------------------------------------------------------------

/.gitignore:
--------------------------------------------------------------------------------
tmp/
output
checkpoints/
--------------------------------------------------------------------------------

/README.md:
--------------------------------------------------------------------------------
# transformers_from_scratch

A large language model (LLM) is a deep learning model trained on massive amounts of text data; it can generate natural language text or understand the meaning of natural language.

![llm](./images/llm.png)

There are plenty of tutorials online about Transformer theory and about fine-tuning large language models, but few explain pretraining itself. This project walks through pretraining a language model from scratch on a Chinese Wikipedia dataset, and also shows how to use `swanlab launch` to train on free GPUs.

* Experiment logs: [SwanLab](https://swanlab.cn/@ShaohonChen/WikiLLM/overview)

* Dataset download: [Baidu Netdisk (code: j8ee)](https://pan.baidu.com/s/1p5F52bRlnpSY7F78q0hz7A?pwd=j8ee), [Hugging Face](https://huggingface.co/datasets/fjcanyue/wikipedia-zh-cn)

---

## Environment Setup

Python 3.10 is recommended. The required Python packages are:

```txt
swanlab
transformers
datasets
accelerate
modelscope
```

Install them all at once with:

```bash
pip install swanlab transformers datasets accelerate modelscope
```

---

## Download the Dataset

This tutorial uses Chinese Wikipedia data. In principle, the more diverse and the larger the pretraining corpus, the better; additional datasets may be added later.

![dataset](./images/dataset.png)

Hugging Face link: [wikipedia-zh-cn](https://huggingface.co/datasets/fjcanyue/wikipedia-zh-cn)

Baidu Netdisk link: [Baidu Netdisk (code: j8ee)](https://pan.baidu.com/s/1p5F52bRlnpSY7F78q0hz7A?pwd=j8ee)

Download `wikipedia-zh-cn-20240820.json` and place it in the `./WIKI_CN/` folder under the project directory.

The file is about 1.99 GB and contains roughly 1.44M records. Each record includes an article title, but the titles are not actually used during pretraining. A sample of the article text:

```txt
数学是研究数量、结构以及空间等概念及其变化的一门学科,属于形式科学的一种。数学利用抽象化和逻辑推理,从计数、计算、量度、对物体形状及运动的观察发展而成。数学家们拓展这些概念...
```

To load the dataset with [🤗 Hugging Face Datasets](https://huggingface.co/docs/datasets/index):

```python
from datasets import load_dataset

ds = load_dataset("fjcanyue/wikipedia-zh-cn")
```

If you downloaded the JSON file from Baidu Netdisk instead, load it like this:

```python
import datasets

raw_datasets = datasets.load_dataset(
    "json", data_files="./WIKI_CN/wikipedia-zh-cn-20240820.json"
)

raw_datasets = raw_datasets["train"].train_test_split(test_size=0.1, seed=2333)
print("dataset info")
print(raw_datasets)
```
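After loading, it helps to peek at a single record to confirm the data looks as described. A minimal sketch (the `title` and `text` field names follow the dataset description above; adjust if your copy differs):

```python
# inspect one record from the training split
sample = raw_datasets["train"][0]
print(sample["title"])       # article title (not used during pretraining)
print(sample["text"][:200])  # first 200 characters of the article body
```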
---

## Run Training

Run:

```bash
python pretrain.py
```

You should see training logs like the ones below. Since training takes quite a while, it is recommended to keep the job running inside a tmux session.

![terminal](./images/terminal.png)

The final training results can be viewed on [SwanLab](https://swanlab.cn):

![log](./images/log.png)
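---

## Test the Trained Model

After training finishes, `pretrain.py` saves the model weights to `./WikiLLM/Weight`, while the tokenizer it used lives in the local `./Qwen2-0.5B` directory. A minimal sketch for loading them and generating a short continuation (paths follow the defaults in `pretrain.py`; adjust them if you changed the script):

```python
import transformers

# load the from-scratch weights saved by pretrain.py and the tokenizer used for training
model = transformers.Qwen2ForCausalLM.from_pretrained("./WikiLLM/Weight")
tokenizer = transformers.AutoTokenizer.from_pretrained("./Qwen2-0.5B")

pipe = transformers.pipeline("text-generation", model=model, tokenizer=tokenizer)
print(pipe("人工智能", max_new_tokens=50, num_return_sequences=1)[0]["generated_text"])
```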
--------------------------------------------------------------------------------

/WIKI_CN/.gitignore:
--------------------------------------------------------------------------------
*
!.gitignore
--------------------------------------------------------------------------------

/images/dataset.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShaohonChen/transformers_from_scratch/5e9448497f729765f909596e3e2b3c409949bbd3/images/dataset.png
--------------------------------------------------------------------------------

/images/llm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShaohonChen/transformers_from_scratch/5e9448497f729765f909596e3e2b3c409949bbd3/images/llm.png
--------------------------------------------------------------------------------

/images/log.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShaohonChen/transformers_from_scratch/5e9448497f729765f909596e3e2b3c409949bbd3/images/log.png
--------------------------------------------------------------------------------

/images/terminal.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShaohonChen/transformers_from_scratch/5e9448497f729765f909596e3e2b3c409949bbd3/images/terminal.png
--------------------------------------------------------------------------------

/pretrain.py:
--------------------------------------------------------------------------------
import datasets
import transformers
import swanlab
from swanlab.integration.huggingface import SwanLabCallback
import modelscope


def main():
    # use swanlab to track the training run
    swanlab.init("WikiLLM")

    # load the dataset; the path points at the JSON file described in the README
    # (adjust if the file lives elsewhere, e.g. /data/WIKI_CN/ when mounted via swanlab launch)
    raw_datasets = datasets.load_dataset(
        "json", data_files="./WIKI_CN/wikipedia-zh-cn-20240820.json"
    )

    raw_datasets = raw_datasets["train"].train_test_split(test_size=0.1, seed=2333)
    print("dataset info")
    print(raw_datasets)

    # load the tokenizer
    # Hugging Face is not directly reachable from mainland China, so download the
    # model config and tokenizer through ModelScope and save a local copy
    modelscope.AutoConfig.from_pretrained("Qwen/Qwen2-0.5B").save_pretrained(
        "Qwen2-0.5B"
    )
    modelscope.AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B").save_pretrained(
        "Qwen2-0.5B"
    )
    context_length = 512  # use a small context length
    # tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        "./Qwen2-0.5B"
    )  # load from the local copy downloaded above

    # preprocess the dataset: tokenize the raw text and keep only full-length chunks,
    # so every training sample is exactly `context_length` tokens long
    def tokenize(element):
        outputs = tokenizer(
            element["text"],
            truncation=True,
            max_length=context_length,
            return_overflowing_tokens=True,  # split long articles into several chunks
            return_length=True,
        )
        input_batch = []
        for length, input_ids in zip(outputs["length"], outputs["input_ids"]):
            if length == context_length:  # drop the trailing, shorter chunk of each article
                input_batch.append(input_ids)
        return {"input_ids": input_batch}

    tokenized_datasets = raw_datasets.map(
        tokenize, batched=True, remove_columns=raw_datasets["train"].column_names
    )
    print("tokenize dataset info")
    print(tokenized_datasets)
    tokenizer.pad_token = tokenizer.eos_token
    data_collator = transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False)

    # prepare a model from scratch
    config = transformers.AutoConfig.from_pretrained(
        "./Qwen2-0.5B",
        vocab_size=len(tokenizer),
        hidden_size=512,
        intermediate_size=2048,
        num_attention_heads=8,
        num_hidden_layers=12,
        max_position_embeddings=context_length,  # Qwen2 uses max_position_embeddings rather than n_ctx
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    model = transformers.Qwen2ForCausalLM(config)
    model_size = sum(t.numel() for t in model.parameters())
    print("Model Config:")
    print(config)
    print(f"Model Size: {model_size/1000**2:.1f}M parameters")

    # train
    args = transformers.TrainingArguments(
        output_dir="WikiLLM",
        per_device_train_batch_size=32,  # training batch size per GPU
        per_device_eval_batch_size=32,  # eval batch size per GPU
        eval_strategy="steps",
        eval_steps=500,
        logging_steps=50,
        gradient_accumulation_steps=8,  # gradient accumulation steps
        num_train_epochs=2,  # number of training epochs
        weight_decay=0.1,
        warmup_steps=200,
        optim="adamw_torch",  # use the AdamW optimizer
        lr_scheduler_type="cosine",  # learning-rate decay schedule
        learning_rate=5e-4,  # base learning rate
        save_steps=500,
        save_total_limit=10,
        bf16=True,  # enable bf16 training; for GPUs older than the Ampere architecture use fp16=True instead
    )
    print("Train Args:")
    print(args)
    # enjoy training
    trainer = transformers.Trainer(
        model=model,
        tokenizer=tokenizer,
        args=args,
        data_collator=data_collator,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["test"],
        callbacks=[SwanLabCallback()],
    )
    trainer.train()

    # save the model weights
    model.save_pretrained("./WikiLLM/Weight")

    # generate a few samples and log them to swanlab
    pipe = transformers.pipeline("text-generation", model=model, tokenizer=tokenizer)
    print("GENERATE:", pipe("人工智能", num_return_sequences=1)[0]["generated_text"])
    prompts = ["牛顿", "北京市", "亚洲历史"]
    examples = []
    for i in range(3):
        # generate text from each prompt
        text = pipe(prompts[i], num_return_sequences=1)[0]["generated_text"]
        text = swanlab.Text(text)
        examples.append(text)
    swanlab.log({"Generate": examples})


if __name__ == "__main__":
    main()
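# Tip: pretraining takes a while, and the Trainer above writes checkpoints to
# output_dir ("WikiLLM") every `save_steps` steps. An interrupted run can be resumed
# by replacing the trainer.train() call with, for example:
#
#     trainer.train(resume_from_checkpoint=True)  # resume from the latest checkpoint
#
# or by passing an explicit checkpoint directory such as "WikiLLM/checkpoint-500".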
--------------------------------------------------------------------------------

/requirements.txt:
--------------------------------------------------------------------------------
swanlab==0.3.19a3
transformers
datasets
accelerate
modelscope
--------------------------------------------------------------------------------

/swanlab.yaml:
--------------------------------------------------------------------------------
apiVersion: swanlab/v1
kind: Folder
metadata:
  name: WikiLLM
  desc: Pretrain LLM using wiki data
spec:
  python: "3.10"
  entry: "pretrain.py"
  volumes:
    - name: "WIKI_CN"
      id: "9sqjot"
  exclude:
    - "WIKI_CN"
--------------------------------------------------------------------------------