├── .gitignore ├── README.md ├── example ├── en.json └── en.txt ├── inference.sh ├── model ├── inference.py ├── languages_abbreviation2fullname.txt ├── llama │ ├── __init__.py │ ├── configuration_llama.py │ ├── convert_llama_weights_to_hf.py │ ├── modeling_llama.py │ └── tokenization_llama.py └── translate.py ├── pics ├── 104langs_bleu.png ├── 70langs_gpt4.png ├── The_outline_of_Increment_pre-training.png └── corpus_distribution.png ├── requirements.txt └── translate.sh /.gitignore: -------------------------------------------------------------------------------- 1 | /log/13b/ 2 | __pycache__/ 3 | *.py[cod] 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🦙 **BigTranslate** 🚀 2 | ## 📢 News 3 | **💥[July 12, 2023] We have implemented parallel inference on multiple GPUs by utilizing [tensor_parallel](https://github.com/BlackSamorez/tensor_parallel/).** 4 | 5 | [July 7, 2023] We have changed our model name to **BigTranslate**. 6 | 7 | [June 6, 2023] We have updated [translate.sh](./translate.sh) to directly utilize BigTranslate for multilingual machine translation! Just only need to provide source and target language abbreviation, file path in script. 8 | 9 | [May 31, 2023] We have released BigTranslate model on [Hugging Face](https://huggingface.co/James-WYang/BigTranslate) for download and experience! 10 | ## 📝 Case Study 11 | *French -> English* 12 | > Instruction: 请将以下法语句子翻译成英语:Cette découverte permet également de mieux comprendre l'évolution des plumes chez les oiseaux. 13 | > 14 | > BigTranslate: This finding also provides a better understanding of the evolution of feathers in birds. 15 | > 16 | > Instruction: 请将以下句子翻译成英语:Il a été volé par des pirates, attaqué au Tibet par un chien enragé, a échappé à un mariage au Népal et a été arrêté en Inde. 17 | > 18 | > BigTranslate: Along the way he was hijacked by pirates, attacked by an angry dog in Tibet, narrowly escaped an arranged marriage in Nepal, and was arrested in India. 19 | 20 | *Danish -> English* 21 | > Instruction: 请将以下丹麦语句子翻译成英语:Rapporten starter med at opfordre til en åben debat, og til at der skabes enighed i USA om landets politik over for Mellemøsten. 22 | > 23 | > BigTranslate: The report starts by calling for an open debate and for unity in the United States on its policy towards the Middle East. 24 | > 25 | > Instruction: 请将以下丹麦语句子翻译成英语:Mere end fire millioner mennesker tog til Rom for at deltage i begravelsen 26 | > 27 | > BigTranslate: More than 4 million people went to Rome to attend the funeral. 28 | 29 | *English -> Chinese* 30 | > Instruction: 请将以下英语句子翻译成汉语:You may also wish to consult the advice of governments other than your own, but their advice is designed for their citizens. 31 | > 32 | > BigTranslate: 您也可以咨询其他国家政府的意见,但这些意见是针对本国公民而设计的。 33 | > 34 | > Instruction: 请将以下句子翻译成汉语:Typically there will be a tuition fee to enroll in these educational programs. 
35 | > 36 | > BigTranslate: 参加这些教育培训项目的学生通常需要缴纳学费。 37 | 38 | *Tibetan -> Chinese* 39 | > Instruction: 请将以下句子翻译成汉语:ཚད་ལྡན་གྱི་འཁྱགས་ཤུད་སྤོ་ལོའི་ཐང་གི་ཚད་གཞི་ཆེ་ཤོས་ནི་རིང་ཚད་ལ་སྨི་61དང་ཞེང་ཚད་ལ་སྨི་30ཡོད། 40 | > 41 | > BigTranslate: 标准冰橇长度最大的是61米,最小的是30米 42 | > 43 | > Instruction: 请将以下藏语句子翻译成汉语:ངས་ཤེས་གསལ་ལྟར་ན། ང་ཚོའི་རྐང་རྩེད་སྤོ་ལོ་རུ་ཁག་གི་ནུས་ཤུགས་ཁོ་ཚོ་ལས་བཟང་། 44 | > 45 | > BigTranslate: 就我所知,我们的足球队比他们强。 46 | 47 | *English -> Portuguese* 48 | > Instruction: 请将以下英语句子翻译成葡萄牙语:Several large television screens were installed in various places in Rome to let the people watch the ceremony. 49 | > 50 | > BigTranslate: Diversos grandes ecrãs televisivos foram instalados em diversos lugares em Roma para que as pessoas pudessem assistir à cerimónia. 51 | > 52 | > Instruction: 请将以下英语句子翻译成葡萄牙语:Scientists say the explosion caused by the collision was massive. 53 | > 54 | > BigTranslate: Os cientistas dizem que a explosão causada pela colisão foi massiva. 55 | 56 | *English -> Swedish* 57 | > Instruction: 请将以下句子翻译成瑞典语:Negotiators tried to rectify the situation, but the prisoners' demands are not clear. 58 | > 59 | > BigTranslate: Förhandlarna försöker korrigera situationen, men fångarnas krav är inte klara. 60 | > 61 | > Instruction: 请将以下英语句子翻译成瑞典语:Although the water level will only rise a few feet after the flood, officials are hoping it will be enough to restore eroded sandbars downstream. 62 | > 63 | > BigTranslate: Även om vattennivån endast ökar några fot efter översvämningen, hoppas myndigheterna att det räcker för att återställa eroderade sandbankar nedströms. 64 | 65 | ## ⭐ BigTranslate Construction 66 | ### 🌓 Large-scale Parallel Dataset Construction 67 | In order to enhance the language capabilities of the Chinese LLaMA model to support 102 languages, we constructed a comprehensive parallel corpus covering 102 languages. This dataset was employed to continue training the foundational model. The compilation of this dataset drew upon multiple sources, including widely available public parallel corpora and in-house datasets. The public datasets used in our study include IWSLT, WMT, CCMT, and OPUS-100, forming the initial corpus of our dataset. 68 | 69 | To effectively illustrate the distribution of the corpus, we present a visual representation of the language-pair distribution within the multilingual datasets. The imbalance between high-resource and low-resource language pairs remains a prominent concern in the current corpus. 70 | 71 | ![image](./pics/corpus_distribution.png) 72 | 73 | ### 🌔 Incremental Multilingual Pre-training 74 | In this incremental pre-training method, we gradually expose the model to language pairs in a curriculum-like manner. Initially, the model is exposed to high-resource language pairs, allowing it to establish a solid foundation in those languages. Subsequently, we progressively introduce low-resource language pairs, enabling the model to gradually expand its knowledge and proficiency in these languages. 75 | 76 | Specifically, we follow a three-step approach in our incremental pre-training method. Firstly, we set the sample interval size and divide language pairs into distinct intervals based on the number of instances for each language pair. Secondly, we calculate the sample mean for all language pairs in each interval. Thirdly, we dynamically determine when to add the language-pair samples of the next interval according to the sample mean of the previous interval.
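Below is a minimal, illustrative sketch of this interval-based curriculum schedule. It is not the actual training code: the bucketing rule, the uniform sampling over active language pairs, the batch-size counter, and the readiness criterion based on the most recently added interval's sample mean are all assumptions made for illustration.

```python
import random
from collections import defaultdict


def build_intervals(pair_sizes, interval_size):
    """Group language pairs into intervals by their number of training instances;
    higher-resource intervals come first in the curriculum."""
    buckets = defaultdict(list)
    for pair, count in pair_sizes.items():
        buckets[count // interval_size].append(pair)
    return [pairs for _, pairs in sorted(buckets.items(), reverse=True)]


def curriculum_batches(pair_sizes, interval_size, batch_size=32):
    """Yield (active_pairs, sampled_pair) tuples. The next interval is unlocked once
    the number of samples drawn since the last unlock reaches the sample mean of the
    interval added most recently (assumed readiness criterion)."""
    active = []
    for pairs in build_intervals(pair_sizes, interval_size):
        active = active + pairs
        mean_size = sum(pair_sizes[p] for p in pairs) / len(pairs)
        drawn = 0
        while drawn < mean_size:
            yield list(active), random.choice(active)  # uniform choice, for illustration only
            drawn += batch_size


# Toy example with made-up sentence-pair counts.
sizes = {"en-zh": 3_000_000, "en-de": 2_000_000, "en-bo": 40_000, "en-dz": 6_000}
for active_pairs, pair in list(curriculum_batches(sizes, interval_size=1_000_000, batch_size=500_000))[:6]:
    print(len(active_pairs), pair)
```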
The figure below outlines these three steps. 77 | 78 | ![image](./pics/The_outline_of_Increment_pre-training.png) 79 | 80 | ### 🌕 Multilingual Translation Instruction Tuning 81 | 82 | We have designed a set of 28 multilingual translation prompts that encompass various application scenarios for multilingual translation. For each parallel sentence, we randomly select a prompt from the set for instruction tuning. The instruction-tuning dataset is then shuffled to ensure randomness and diversity. 83 | 84 | During the training phase, we randomly select a prompt from the following 28 multilingual translation prompts for each sentence. 85 | ``` 86 | 请将以下{SRC_LANG}句子翻译成{TGT_LANG}:{SRC_Text} 87 | 请将以下{SRC_LANG}文本翻译成{TGT_LANG}:{SRC_Text} 88 | 请将以下句子翻译成{TGT_LANG}:{SRC_Text} 89 | 请将以下文本翻译成{TGT_LANG}:{SRC_Text} 90 | 请提供{SRC_LANG}句子“{SRC_Text}”的{TGT_LANG}翻译 91 | 请提供{SRC_LANG}文本“{SRC_Text}”的{TGT_LANG}翻译 92 | 请提供句子“{SRC_Text}”的{TGT_LANG}翻译 93 | 请提供文本“{SRC_Text}”的{TGT_LANG}翻译 94 | 以下{SRC_LANG}句子“{SRC_Text}”用{TGT_LANG}如何表达 95 | 以下{SRC_LANG}文本“{SRC_Text}”用{TGT_LANG}如何表达 96 | 以下句子“{SRC_Text}”用{TGT_LANG}如何表达 97 | 以下文本“{SRC_Text}”用{TGT_LANG}如何表达 98 | 以下{SRC_LANG}句子“{SRC_Text}”的{TGT_LANG}翻译是什么? 99 | 以下{SRC_LANG}文本“{SRC_Text}”的{TGT_LANG}翻译是什么? 100 | 以下句子“{SRC_Text}”的{TGT_LANG}翻译是什么? 101 | 以下文本“{SRC_Text}”的{TGT_LANG}翻译是什么? 102 | 请生成以下{SRC_LANG}句子“{SRC_Text}”的{TGT_LANG}翻译 103 | 请生成以下{SRC_LANG}文本“{SRC_Text}”的{TGT_LANG}翻译 104 | 请生成以下句子“{SRC_Text}”的{TGT_LANG}翻译 105 | 请生成以下文本“{SRC_Text}”的{TGT_LANG}翻译 106 | 如何用{TGT_LANG}表达{SRC_LANG}句子“{SRC_Text}” 107 | 如何用{TGT_LANG}表达{SRC_LANG}文本“{SRC_Text}” 108 | 如何用{TGT_LANG}表达句子“{SRC_Text}” 109 | 如何用{TGT_LANG}表达文本“{SRC_Text}” 110 | 这个{SRC_LANG}句子“{SRC_Text}”用{TGT_LANG}怎么说? 111 | 这个{SRC_LANG}文本“{SRC_Text}”用{TGT_LANG}怎么说? 112 | 这个句子“{SRC_Text}”用{TGT_LANG}怎么说? 113 | 这个文本“{SRC_Text}”用{TGT_LANG}怎么说? 114 | ``` 115 | During the inference phase, we randomly select a prompt from the following two multilingual translation prompts for each sentence. 116 | ``` 117 | 请将以下{SRC_LANG}句子翻译成{TGT_LANG}:{SRC_Text} 118 | 请将以下句子翻译成{TGT_LANG}:{SRC_Text} 119 | ``` 120 | 121 | 122 | ## 🌟 Experiments 123 | ### 🌖 Automatic Evaluation with BLEU 124 | An illustrated comparison between BigTranslate, ChatGPT and Google Translate on 102 languages translated from X to English or Chinese. Languages are sorted by BigTranslate's BLEU score in descending order. 125 | 126 | ![image](./pics/104langs_bleu.png) 127 | 128 | ### 🌗 Automatic Evaluation with GPT-4 129 | An illustrated comparison between BigTranslate, ChatGPT and Google Translate on 70 languages translated from X to English or Chinese. Languages are sorted by BigTranslate's GPT-4 score in descending order. 130 | 131 | ![image](./pics/70langs_gpt4.png) 132 | 133 | ## 🤖 BigTranslate Model 134 | 135 | ### ⚠️ User Notice (Must Read) 136 | 137 | 138 | 139 | The BigTranslate model weights are released under the [GNU General Public License v3.0](https://www.gnu.org/licenses/gpl-3.0.html), are intended for research use only, and cannot be used for commercial purposes. 140 | 141 | ***Please confirm that you are using the model in this repository with [permission](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform?usp=send_form).*** 142 | 143 | ### 📎 Model Download 144 | 145 | **BigTranslate**:[Hugging Face](https://huggingface.co/James-WYang/BigTranslate) 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 |
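After downloading the weights, a quick way to smoke-test the checkpoint is to load it with the `llama` package bundled in this repository, mirroring what `model/inference.py` does. The sketch below is only illustrative: the checkpoint path and the example sentence are placeholders, the generation settings follow the deterministic beam-search preset in `inference.sh`, and it assumes you run it from the `model/` directory (or otherwise put `model/` on `PYTHONPATH`) with a CUDA GPU available.

```python
import torch
import llama  # the model/llama package shipped in this repo

MODEL_PATH = "/path/to/BigTranslate"  # placeholder: local checkpoint or a downloaded Hugging Face snapshot

config = llama.LLaMAConfig.from_pretrained(MODEL_PATH)
tokenizer = llama.LLaMATokenizer.from_pretrained(MODEL_PATH)
model = llama.LLaMAForCausalLM.from_pretrained(
    MODEL_PATH, torch_dtype=torch.float16, config=config
).cuda()

# Wrap the source sentence in one of the inference-time translation prompts, then in the
# instruction template that model/inference.py applies when --with-instruct is set.
instruction = "请将以下英语句子翻译成汉语:The find also grants insight into the evolution of feathers in birds."
prompt = f"以下是一个描述任务的指令,请写一个完成该指令的适当回复。\n\n### 指令:\n{instruction}\n\n### 回复:"

input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()
with torch.inference_mode():
    output_ids = model.generate(
        input_ids,
        max_new_tokens=256,
        do_sample=False,          # deterministic beam search, as in inference.sh
        num_beams=5,
        no_repeat_ngram_size=6,
    )
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0])
```

For multi-GPU inference, `model/inference.py` wraps the model with `tensor_parallel` (`tp.tensor_parallel(model)`) instead of calling `.cuda()` directly.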
154 | ### 📌 Model Inference 155 | Install dependencies: 156 | 157 | ```bash 158 | pip install -r requirements.txt 159 | ``` 160 | 161 | Example usage: 162 | 163 | ```bash 164 | python -u model/inference.py \ 165 | --model ${CHECKPOINT_PATH} \ 166 | --tokenizer-path ${TOKENIZER_PATH} \ 167 | --prompt-file ${PROMPT_FILE} \ 168 | --with-instruct \ 169 | --out-file ${LOW_OUT_FILE} \ 170 | --seed ${SEED} \ 171 | --beam-search \ 172 | --num-beams ${NUM_BEAMS} \ 173 | --times ${OUT_TIME} \ 174 | --max-tokens ${MAX_TOKENS} \ 175 | --no-repeat-ngram-size ${NO_REPEAT_NGRAM_SIZE} \ 176 | --top-k ${TOP_K} \ 177 | --top-p ${TOP_P} \ 178 | --temperature ${TEMPERATURE} 2>&1 >>${LOG_FILE} 179 | ``` 180 | We can customize the hyperparameters: 181 | 182 | ```bash 183 | python -u model/inference.py \ 184 | --model ${CHECKPOINT_PATH} \ 185 | --tokenizer-path ${TOKENIZER_PATH} \ 186 | --prompt-file ${PROMPT_FILE} \ 187 | --with-instruct \ 188 | --out-file ${BEAM_OUT_FILE} \ 189 | --seed ${SEED} \ 190 | --beam-search \ 191 | --num-beams 5 \ 192 | --times 1 \ 193 | --max-tokens 256 \ 194 | --no-repeat-ngram-size 6 2>&1 >>${LOG_FILE} 195 | ``` 196 | We provide a script, [inference.sh](./inference.sh), to run model inference. 197 | 198 | ### 💡 Translate with BigTranslate 199 | 200 | Example usage: 201 | ```bash 202 | python -u model/translate.py \ 203 | --model ${CHECKPOINT_PATH} \ 204 | --tokenizer-path ${TOKENIZER_PATH} \ 205 | --prompt-file ${PROMPT_FILE} \ 206 | ${ADD_PARAMETERS} \ 207 | --out-file ${SAVE_PATH} \ 208 | --source-language ${SRC_LANG} \ 209 | --target-language ${TGT_LANG} \ 210 | --seed ${SEED} \ 211 | --beam-search \ 212 | --parameter-type ${MODEL_TYPE} \ 213 | --num-beams ${NUM_BEAMS} \ 214 | --times ${OUT_TIME} \ 215 | --max-tokens ${MAX_TOKENS} \ 216 | --no-repeat-ngram-size ${NO_REPEAT_NGRAM_SIZE} \ 217 | --temperature ${LOW_TEMPERATURE} 2>&1 >>${LOG_FILE} 218 | ``` 219 | We provide a script, [translate.sh](./translate.sh), to translate with BigTranslate. 220 | 221 | ## License 222 | 223 | Our code and documents are released under the Apache License 2.0. 224 | 225 | Following LLaMA, our pre-trained weights are released under the GNU General Public License v3.0. 226 | 227 | ## Acknowledgement 228 | 229 | We thank all contributors to the BigTranslate project. 230 | 231 | This repo benefits from [LLaMA](https://github.com/facebookresearch/llama) and [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca). Thanks for their wonderful work. 232 | 233 | ## Contact 234 | 235 | If you have any questions, please feel free to contact us by sending an email to {yangwen2023, lichong2021}@ia.ac.cn or {jjzhang, cqzong}@nlpr.ia.ac.cn.
236 | 237 | ## Citation 238 | 239 | ``` 240 | @article{yang-etal-2023-BigTranslate, 241 | author = {Wen Yang and 242 | Chong Li and 243 | Jiajun Zhang and 244 | Chengqing Zong}, 245 | title={BigTranslate: Augmenting Large Language Models with Multilingual Translation Capability over 100 Languages}, 246 | journal={arXiv preprint arXiv:2305.18098}, 247 | url={https://arxiv.org/abs/2305.18098}, 248 | year={2023} 249 | } 250 | ``` 251 | 252 | 253 | 254 | 255 | -------------------------------------------------------------------------------- /example/en.json: -------------------------------------------------------------------------------- 1 | {"input": "Previously, Ring's CEO, Jamie Siminoff, remarked the company started when his doorbell wasn't audible from his shop in his garage."} 2 | {"input": "Siminoff said sales boosted after his 2013 appearance in a Shark Tank episode where the show panel declined funding the startup."} 3 | {"input": "While one experimental vaccine appears able to reduce Ebola mortality, up until now, no drugs have been clearly demonstrated suitable for treating existing infection."} 4 | {"input": "USA Gymnastics supports the United States Olympic Committee's letter and accepts the absolute need of the Olympic family to promote a safe environment for all of our athletes."} 5 | {"input": "Throughout 1960s, Brzezinski worked for John F. Kennedy as his advisor and then the Lyndon B. Johnson administration."} 6 | {"input": "During the 1976 selections he advised Carter on foreign policy, then served as National Security Advisor (NSA) from 1977 to 1981, succeeding Henry Kissinger."} 7 | {"input": "The other nominations include Best Picture, Director, Cinematography, Costume Design, Film-editing, Original Score, Production Design, Sound Editing, Sound Mixing and Original Screenplay."} 8 | {"input": "The announcement was made after Trump had a phone conversation with Turkish President Recep Tayyip Erdoğan."} 9 | {"input": "A car bomb detonated at police headquarters in Gaziantep, Turkey yesterday morning killed two police officers and injured more than twenty other people."} 10 | {"input": "Police said they suspect an alleged Daesh (ISIL) militant of responsibility for the attack."} 11 | {"input": "A doctor who worked at Children's Hospital of Pittsburgh, Pennsylvania will be charged with aggravated murder after her mother was found dead in the trunk of her car Wednesday, authorities in Ohio say."} 12 | {"input": "The outbreak has prompted the Indian government to undertake such measures as deployment of pig catchers in seriously affected areas, distributing thousands of mosquito curtains and spraying pesticides."} 13 | {"input": "Plans for vaccines to be delivered to the historically most affected areas this year were delayed due to lack of funds and low prioritisation relative to other diseases."} 14 | {"input": "Lakkha Singh presented the chhappan bhog bhajan as well. 
Singer, Raju Khandelwal was accompanying him."} 15 | {"input": "Famous singers across the country presented bhajans, or devotional songs, to Shri Shyam's feet."} 16 | {"input": "She came to this conclusion due to the multitude of positive comments and encouragement sent to her by both female and male individuals urging that contraception medication be considered a medical necessity."} 17 | {"input": "As a result, two fish species have become extinct, and two others have become endangered, including the humpback chub."} 18 | {"input": "Although the water level will only rise a few feet after the flood, officials are hoping it will be enough to restore eroded sandbars downstream."} 19 | {"input": "Final results from Namibian presidential and parliamentary elections have indicated that the incumbent president, Hifikepunye Pohamba, has been reelected by a large margin."} 20 | {"input": "The medical charity Mangola, Medecines Sans Frontieres and the World Health Organisation say it is the worst outbreak recorded in the country."} -------------------------------------------------------------------------------- /example/en.txt: -------------------------------------------------------------------------------- 1 | The find also grants insight into the evolution of feathers in birds. 2 | The governor's office said nineteen of the injured were police officers. 3 | During his trip, Iwasaki ran into trouble on many occasions. 4 | He was robbed by pirates, attacked in Tibet by a rabid dog, escaped marriage in Nepal and was arrested in India. 5 | He did not set a figure for the cuts, saying they will be made based on China's economic output. 6 | The Report opens with plea for open debate and the formation of a consensus in the United States about the policy towards the Middle East. 7 | Over four million people went to Rome to attend the funeral. 8 | Several large television screens were installed in various places in Rome to let the people watch the ceremony. 9 | Television reports show white smoke coming from the plant. 10 | Scientists say the explosion caused by the collision was massive. 11 | Police said that the body appeared to have been there for about a day. 12 | They all ran back from where the accident had happened. 13 | Negotiators tried to rectify the situation, but the prisoners' demands are not clear. 14 | Although the water level will only rise a few feet after the flood, officials are hoping it will be enough to restore eroded sandbars downstream. 15 | "This is not going to be goodbye. This is the closing of one chapter and the opening of a new one." 16 | "They are cooler than the surrounding surface in the day and warmer at night. 17 | It was the final match for the All Blacks, who had already won the trophy two weeks ago. 18 | Chambers had sued God for "widespread death, destruction and terrorization of millions upon millions of the Earth's inhabitants." 19 | With only eighteen medals available a day, a number of countries have failed to make the medal podium. 20 | Before The Simpsons Simon had worked on several shows in various positions. 21 | This will allow players to control actions and movements in video games by moving the device through the air. 22 | "I was moved every time we did a rehearsal on this, from the bottom of my heart." 23 | That didn't seem to make sense to me; it certainly wasn't fair. 24 | Although three people were inside the house when the car impacted it, none of them were hurt. 25 | Several hostages have been rescued and least six have been confirmed dead so far. 
26 | He has been unable to take the drugs needed to overcome his pain as they are banned from the Games. 27 | He is speculated to make a run for president in 2016. 28 | When you call someone who is thousands of miles away, you are using a satellite. 29 | The wheel has changed the world in incredible ways. The biggest thing that the wheel has done for us is given us much easier and faster transportation. 30 | It has brought us the train, the car, and many other transportation devices. 31 | The females are usually closely related to each other, being a large family of sisters and daughters. 32 | The original population hasn't changed at all, they still need the same adaptations as before. 33 | In the warm climate of the Middle East, the house was not so important. 34 | Its all-pervading power affected everyone from king to commoner. 35 | For example, one might say that the motor car necessarily leads to the development of roads. 36 | However, due to the slow communication channels, styles in the west could lag behind by 25 to 30 year. 37 | The Internet combines elements of both mass and interpersonal communication. 38 | In particular, it is claimed that one can detect whether a person is lying by interpreting micro-expressions correctly. 39 | Many people don't think about them as dinosaurs because they have feathers and can fly. 40 | After all, the leader is ultimately responsible for the success and failure of the team. 41 | In 1990, it was added to the list of world heritage sites in danger, due to the threat of desert sands. 42 | In some areas boiling water for a minute is enough, in others several minutes are needed. 43 | Think of the skiing route as of a similar hiking route. 44 | You may also wish to consult the advice of governments other than your own, but their advice is designed for their citizens. 45 | Willingness of foreign governments to honour these documents is just as widely variable. 46 | People may not anticipate that patience and understanding are also necessary for travellers returning home. 47 | Typically there will be a tuition fee to enroll in these educational programs. 48 | It's worth half an hour to stroll about the intriguing village. 49 | The area is also home to an extremely wide variety of animal and bird species. 50 | The concept came from China where plum blossoms were the flower of choice. 
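The example files above appear to serve as sample inputs for the inference and translation scripts: `example/en.txt` holds raw source sentences, and `example/en.json` holds the same kind of inputs as JSON lines of the form `{"input": ...}`. Below is a small, illustrative helper (the function name, the output path, and the fixed English-to-Chinese direction are assumptions) that wraps each raw sentence in one of the two inference-time prompt templates from the README; `model/inference.py` reads such prompts one per line via `--prompt-file`.

```python
import json
import random

# The two inference-time prompt templates listed in the README.
INFERENCE_PROMPTS = [
    "请将以下{SRC_LANG}句子翻译成{TGT_LANG}:{SRC_Text}",
    "请将以下句子翻译成{TGT_LANG}:{SRC_Text}",
]


def build_prompt_file(in_path="example/en.txt", out_path="example/en_prompts.txt",
                      src_lang="英语", tgt_lang="汉语", seed=0):
    """Wrap each raw source sentence in a randomly chosen translation prompt,
    writing one prompt per line (the format read_prompt in model/inference.py reads)."""
    random.seed(seed)
    with open(in_path, encoding="utf-8") as fin, open(out_path, "w", encoding="utf-8") as fout:
        for line in fin:
            line = line.strip()
            if not line:
                continue
            template = random.choice(INFERENCE_PROMPTS)
            fout.write(template.format(SRC_LANG=src_lang, TGT_LANG=tgt_lang, SRC_Text=line) + "\n")


if __name__ == "__main__":
    build_prompt_file()
    # example/en.json stores equivalent inputs as JSON lines: {"input": "..."}
    with open("example/en.json", encoding="utf-8") as f:
        print(json.loads(f.readline())["input"])
```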
-------------------------------------------------------------------------------- /inference.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | export SEED=0 3 | 4 | export PROMPT_FILE= #PROMPT_FILE_PATH 5 | # export PROMPT_PATH= #PROMPT_FILES_DIR (change --prompt-file ${PROMPT_FILE} with --prompt-path ${PROMPT_PATH} for multiple files in one directory) 6 | export CHECKPOINT_PATH= #CHECKPOINT_PATH (e.g., /PATH2BigTrans or decapoda-research/llama-7b-hf) 7 | export TOKENIZER_PATH= #TOKENIZER_PATH (e.g., /PATH2BigTrans or decapoda-research/llama-7b-hf) 8 | 9 | export INSTRUCT=True 10 | 11 | # export HIGH_OUT_FILE= #OUT_FILE_PATH 12 | # export LOW_OUT_FILE= #OUT_FILE_PATH 13 | export BEAM_OUT_FILE= #OUT_FILE_PATH 14 | 15 | export MAX_TOKENS=256 16 | export TOP_K=50 17 | export TOP_P=0.95 18 | export NO_REPEAT_NGRAM_SIZE=6 19 | 20 | export HIGH_TEMPERATURE=0.7 21 | export LOW_TEMPERATURE=0.01 22 | export NUM_BEAMS=5 23 | 24 | export ADD_PARAMETERS="" 25 | if [ "${INSTRUCT}" != "False" ]; 26 | then 27 | ADD_PARAMETERS="--with-instruct " 28 | fi 29 | 30 | LOG_FILE="bigtrans_inference_local.log" 31 | 32 | 33 | # HIGH TEPERATURE, MORE CREATIVE 34 | # export OUT_TIME=3 35 | # python -u model/inference.py \ 36 | # --model ${CHECKPOINT_PATH} \ 37 | # --tokenizer-path ${TOKENIZER_PATH} \ 38 | # --prompt-file ${PROMPT_FILE} \ 39 | # ${ADD_PARAMETERS} \ 40 | # --out-file ${HIGH_OUT_FILE} \ 41 | # --seed ${SEED} \ 42 | # --times ${OUT_TIME} \ 43 | # --max-tokens ${MAX_TOKENS} \ 44 | # --no-repeat-ngram-size ${NO_REPEAT_NGRAM_SIZE} \ 45 | # --top-k ${TOP_K} \ 46 | # --top-p ${TOP_P} \ 47 | # --temperature ${HIGH_TEMPERATURE} 2>&1 >>${LOG_FILE} 48 | 49 | 50 | # LOW TEPERATURE, MORE REALIABLE 51 | # export OUT_TIME=3 52 | # python -u model/inference.py \ 53 | # --model ${CHECKPOINT_PATH} \ 54 | # --tokenizer-path ${TOKENIZER_PATH} \ 55 | # --prompt-file ${PROMPT_FILE} \ 56 | # ${ADD_PARAMETERS} \ 57 | # --out-file ${LOW_OUT_FILE} \ 58 | # --seed ${SEED} \ 59 | # --times ${OUT_TIME} \ 60 | # --max-tokens ${MAX_TOKENS} \ 61 | # --no-repeat-ngram-size ${NO_REPEAT_NGRAM_SIZE} \ 62 | # --top-k ${TOP_K} \ 63 | # --top-p ${TOP_P} \ 64 | # --temperature ${LOW_TEMPERATURE} 2>&1 >>${LOG_FILE} 65 | 66 | 67 | # BEAM SEARCH, DETERMINISTIC 68 | export OUT_TIME=1 69 | python -u model/inference.py \ 70 | --model ${CHECKPOINT_PATH} \ 71 | --tokenizer-path ${TOKENIZER_PATH} \ 72 | --prompt-file ${PROMPT_FILE} \ 73 | ${ADD_PARAMETERS} \ 74 | --out-file ${BEAM_OUT_FILE} \ 75 | --seed ${SEED} \ 76 | --beam-search \ 77 | --num-beams ${NUM_BEAMS} \ 78 | --times ${OUT_TIME} \ 79 | --max-tokens ${MAX_TOKENS} \ 80 | --no-repeat-ngram-size ${NO_REPEAT_NGRAM_SIZE} 2>&1 >>${LOG_FILE} 81 | -------------------------------------------------------------------------------- /model/inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | from nltk.translate.bleu_score import sentence_bleu 4 | import jieba 5 | import llama 6 | import argparse 7 | from accelerate.utils import set_seed 8 | import tensor_parallel as tp 9 | 10 | import os 11 | 12 | 13 | PROMPT_DICT = { 14 | # "prompt_instruct": ( 15 | # "以下是一个描述任务的指令,并配有一个提供详细上下文信息的输入。" 16 | # "请写一个完成该指令的适当回复。\n\n" 17 | # "### 指令:\n{instruction}\n\n### 输入:\n{input}\n\n### 回复:" 18 | # ), 19 | "prompt_input": ( 20 | "以下是一个描述任务的指令,请写一个完成该指令的适当回复。\n\n" 21 | "### 指令:\n{0}\n\n### 回复:" 22 | ), 23 | } 24 | 25 | 26 | def read_prompt(file_path:str): 27 | 
file_handle = open(file_path) 28 | prompts = [] 29 | 30 | while True: 31 | line = file_handle.readline() 32 | if not line: 33 | break 34 | line = line.strip() 35 | prompts.append(line) 36 | 37 | return prompts 38 | 39 | 40 | def cut2list(line): 41 | line_cut = jieba.cut(line, cut_all=True) 42 | line_list = [c for c in line_cut] 43 | out_list = [] 44 | for c in line_list: 45 | if len(c) == 0: 46 | continue 47 | if c == ' ': 48 | continue 49 | out_list.append(c) 50 | return out_list 51 | 52 | def single_prompt(model, tokenizer, prompt="Hello, I'm am conscious and", max_new_tokens:int=128, do_sample:bool=True, num_beams:int=1, top_k:int=50, top_p:float=0.95, no_repeat_ngram_size=6, temperature:float=0.7, cuda=True, verbose=False): 53 | input_ids = tokenizer(prompt, return_tensors="pt").input_ids 54 | if cuda: 55 | # model = model.cuda() 56 | 57 | model = tp.tensor_parallel(model) 58 | input_ids = input_ids.cuda() 59 | 60 | with torch.inference_mode(): 61 | generated_ids = model.generate(input_ids, max_new_tokens=max_new_tokens, do_sample=do_sample, num_beams=num_beams, top_k=top_k, top_p=top_p, temperature=temperature, no_repeat_ngram_size=no_repeat_ngram_size) 62 | 63 | results = tokenizer.batch_decode(generated_ids, skip_special_tokens=True, spaces_between_special_tokens=False) 64 | 65 | if verbose: 66 | print(results) 67 | return results 68 | 69 | def batch_prompt(model, tokenizer, prompts:list=["Hello, I'm am conscious and"], max_new_tokens:int=128, do_sample:bool=True, num_beams:int=1, top_k:int=50, top_p:float=0.95, temperature:float=0.7, no_repeat_ngram_size=6, cuda=True, verbose=False): 70 | tokenizer.padding_side="left" 71 | tokenizer.pad_token_id = tokenizer.eos_token_id 72 | model.config.pad_token_id = model.config.eos_token_id 73 | tokens = tokenizer(prompts, return_tensors="pt", padding=True) 74 | input_ids = tokens["input_ids"] 75 | attention_mask = tokens["attention_mask"] 76 | if cuda: 77 | # model = model.cuda() 78 | 79 | model = tp.tensor_parallel(model) 80 | input_ids = input_ids.cuda() 81 | attention_mask = attention_mask.cuda() 82 | with torch.inference_mode(): 83 | generated_ids = model.generate(input_ids,attention_mask=attention_mask, max_new_tokens=max_new_tokens, do_sample=do_sample, num_beams=num_beams, top_k=top_k, top_p=top_p, temperature=temperature, no_repeat_ngram_size=no_repeat_ngram_size) 84 | results = tokenizer.batch_decode(generated_ids, skip_special_tokens=True, spaces_between_special_tokens=False) 85 | if verbose: 86 | print(results) 87 | return results 88 | 89 | def back_process(str_out:str): 90 | if "\n" in str_out: 91 | str_out = str_out[:str_out.find("\n")] 92 | 93 | return str_out 94 | 95 | def eval_translate(model, tokenizer, prompts, gold_path, context="Chinese: 我想回家。\nEnglish: I want to go home.\n\nChinese: 我不知道。\nEnglish: I don't know.\n\nChinese: {}\nEnglish: ", cuda=True, split_translate=False, generation_args:dict=None): 96 | file_handler = open(gold_path,"r",encoding="utf-8") 97 | 98 | total_b = 0 99 | for prompt in prompts: 100 | 101 | s_t = time.time() 102 | if not split_translate: 103 | str_in = context.format(prompt) 104 | print(f"Input:{str_in}\n") 105 | decode_res = single_prompt(model=model, tokenizer=tokenizer, prompt=str_in, cuda=cuda, verbose=False, **generation_args) 106 | predict_res = back_process(decode_res[0][len(str_in):]) 107 | else: 108 | inputs = prompt.split(",") 109 | preds = [] 110 | for str_id, str_in in enumerate(inputs): 111 | if len(str_in) == 0: 112 | continue 113 | str_in = context.format(str_in) 114 | 
print(f"Input-{str_id+1}:{str_in}\n") 115 | 116 | decode_res = single_prompt(model=model, tokenizer=tokenizer, prompt=str_in, cuda=cuda, verbose=False, **generation_args) 117 | decode_res = back_process(decode_res[0][len(str_in):]) 118 | 119 | print(f"Output-{str_id+1}:{decode_res}\n") 120 | preds.append(decode_res) 121 | 122 | predict_res = ", ".join(preds) 123 | 124 | gold_line = file_handler.readline() 125 | 126 | print("Output:",predict_res) 127 | print("Gold:",gold_line) 128 | 129 | gold_list = cut2list(gold_line) 130 | 131 | predict_list = cut2list(predict_res) 132 | 133 | curr_b = sentence_bleu([gold_list], predict_list) 134 | total_b += curr_b 135 | 136 | e_t = time.time() 137 | print(f"Time cost:{e_t-s_t}s, bleu:{curr_b} \n\n") 138 | 139 | file_handler.close() 140 | 141 | print(f"Average bleu: {total_b/len(prompts)}") 142 | 143 | 144 | def to_matrix(l, n): 145 | return [l[i:i+n] for i in range(0, len(l), n)] 146 | 147 | if __name__ == '__main__': 148 | parser = argparse.ArgumentParser() 149 | 150 | parser.add_argument( 151 | "--model", 152 | type=str, 153 | default=None, 154 | help="The name of model to use.", 155 | ) 156 | 157 | parser.add_argument( 158 | "--tokenizer-path", 159 | type=str, 160 | default=None, 161 | help="The path of tokenizer to use.", 162 | ) 163 | 164 | parser.add_argument( 165 | "--beam-search", 166 | action='store_true', 167 | help="Whether to run beam search." 168 | ) 169 | 170 | parser.add_argument( 171 | "--with-instruct", 172 | action='store_true', 173 | help="Whether to run beam search." 174 | ) 175 | 176 | parser.add_argument( 177 | "--num-beams", 178 | type=int, 179 | default=5 180 | ) 181 | 182 | parser.add_argument( 183 | "--no-repeat-ngram-size", 184 | type=int, 185 | default=0 186 | ) 187 | 188 | parser.add_argument( 189 | "--checkpoint", 190 | type=str, 191 | default=None 192 | ) 193 | 194 | parser.add_argument( 195 | "--prompt-file", 196 | type=str, 197 | default=None 198 | ) 199 | parser.add_argument( 200 | "--prompt-path", 201 | type=str, 202 | default=None 203 | ) 204 | 205 | parser.add_argument( 206 | "--out-file", 207 | type=str, 208 | default=None 209 | ) 210 | 211 | parser.add_argument( 212 | "--seed", 213 | type=int, 214 | default=0 215 | ) 216 | 217 | parser.add_argument( 218 | "--translate", 219 | action='store_true', 220 | help="Whether to run translate." 221 | ) 222 | 223 | parser.add_argument( 224 | "--batch-inference", 225 | action='store_true', 226 | help="Whether to run inference in batch." 227 | ) 228 | 229 | parser.add_argument( 230 | "--batch-size", 231 | type=int, 232 | default=3, 233 | help="Batch size of run inference in batch." 234 | ) 235 | 236 | parser.add_argument( 237 | "--split-translate", 238 | action='store_true', 239 | help="Whether to run translate by split on dot." 
240 | ) 241 | 242 | 243 | parser.add_argument( 244 | "--gold-file", 245 | type=str, 246 | default=None 247 | ) 248 | 249 | parser.add_argument( 250 | "--times", 251 | type=int, 252 | default=3, 253 | help="Number of generation for each prompt.", 254 | ) 255 | 256 | parser.add_argument( 257 | "--max-tokens", 258 | type=int, 259 | default=256 260 | ) 261 | 262 | parser.add_argument( 263 | "--top-k", 264 | type=int, 265 | default=80, 266 | help="The configuration top k tokens in the generation of model.", 267 | ) 268 | 269 | parser.add_argument( 270 | "--top-p", 271 | type=float, 272 | default=0.95, 273 | help="The configuration top p tokens in the generation of model.", 274 | ) 275 | 276 | parser.add_argument( 277 | "--temperature", 278 | type=float, 279 | default=0.7, 280 | help="The configuration temperature in the generation of model.", 281 | ) 282 | 283 | args = parser.parse_args() 284 | 285 | set_seed(args.seed) 286 | 287 | 288 | print(args.out_file) 289 | 290 | config = llama.LLaMAConfig.from_pretrained(args.model) 291 | tokenizer = llama.LLaMATokenizer.from_pretrained(args.tokenizer_path) 292 | model = llama.LLaMAForCausalLM.from_pretrained( 293 | args.model, 294 | torch_dtype=torch.float16, 295 | config=config, 296 | state_dict=torch.load(args.checkpoint) if args.checkpoint is not None else None 297 | ) 298 | 299 | 300 | 301 | generation_config = { 302 | "do_sample": not args.beam_search, 303 | "num_beams": args.num_beams if args.beam_search else 1, 304 | "max_new_tokens": args.max_tokens, 305 | "no_repeat_ngram_size": args.no_repeat_ngram_size, 306 | "top_k": args.top_k, 307 | "top_p": args.top_p, 308 | "temperature": args.temperature 309 | } 310 | 311 | 312 | if args.prompt_path is not None: 313 | for file in os.listdir(args.prompt_path): 314 | print(f"Inference {file}...") 315 | prompts = read_prompt(os.path.join(args.prompt_path,file)) 316 | print(len(prompts)) 317 | if args.translate: 318 | # do translation 319 | eval_translate(model=model, tokenizer=tokenizer, prompts=prompts, gold_path=args.gold_file, cuda=True, split_translate=args.split_translate, generation_args=generation_config) 320 | elif args.batch_inference: 321 | # batch inference 322 | out_handle = open(args.out_file,"w",encoding="utf-8") 323 | for attr, value in sorted(args.__dict__.items()): 324 | print(f"\t{attr}={value}") 325 | # out_handle.write(f"\t{attr}={value}") 326 | t_id = 1 327 | prompt_list = to_matrix(prompts, args.batch_size) 328 | for prompts_in in prompt_list: 329 | outputs = batch_prompt(model, tokenizer, prompts=prompts_in, cuda=True, **generation_config) 330 | for out_str in outputs: 331 | out_handle.write(f"T-{t_id} {out_str}\n") 332 | t_id += 1 333 | out_handle.close() 334 | else: 335 | # inference 336 | out_handle = open(os.path.join(args.out_file, file),"w",encoding="utf-8") 337 | for attr, value in sorted(args.__dict__.items()): 338 | print(f"\t{attr}={value}") 339 | # out_handle.write(f"\t{attr}={value}") 340 | 341 | for prompt in prompts: 342 | if args.with_instruct: 343 | prompt = PROMPT_DICT["prompt_input"].format(prompt) 344 | 345 | # out_handle.write(f"\n*****Input: {prompt}\n") 346 | 347 | for i in range(args.times): 348 | s_t = time.time() 349 | results = single_prompt(model=model, tokenizer=tokenizer, prompt=prompt, cuda=True, **generation_config) 350 | step_time = time.time() - s_t 351 | 352 | out_handle.write(f"\n*****Output(Time-{i+1},cost {step_time:.2f}s): ") 353 | out_handle.write(results[0]+"\n") 354 | 355 | out_handle.close() 356 | else: 357 | prompts = 
read_prompt(args.prompt_file) 358 | 359 | if args.translate: 360 | # do translation 361 | eval_translate(model=model, tokenizer=tokenizer, prompts=prompts, gold_path=args.gold_file, cuda=True, split_translate=args.split_translate, generation_args=generation_config) 362 | elif args.batch_inference: 363 | # batch inference 364 | out_handle = open(args.out_file,"w",encoding="utf-8") 365 | for attr, value in sorted(args.__dict__.items()): 366 | print(f"\t{attr}={value}") 367 | # out_handle.write(f"\t{attr}={value}") 368 | t_id = 1 369 | prompt_list = to_matrix(prompts, args.batch_size) 370 | for prompts_in in prompt_list: 371 | outputs = batch_prompt(model, tokenizer, prompts=prompts_in, cuda=True, **generation_config) 372 | for out_str in outputs: 373 | out_handle.write(f"T-{t_id} {out_str}\n") 374 | t_id += 1 375 | 376 | out_handle.close() 377 | else: 378 | # inference 379 | out_handle = open(args.out_file,"w",encoding="utf-8") 380 | for attr, value in sorted(args.__dict__.items()): 381 | print(f"\t{attr}={value}") 382 | out_handle.write(f"\t{attr}={value}") 383 | 384 | for prompt in prompts: 385 | if args.with_instruct: 386 | prompt = PROMPT_DICT["prompt_input"].format(prompt) 387 | 388 | # out_handle.write(f"\n*****Input: {prompt}\n") 389 | 390 | for i in range(args.times): 391 | s_t = time.time() 392 | results = single_prompt(model=model, tokenizer=tokenizer, prompt=prompt, cuda=True, **generation_config) 393 | step_time = time.time() - s_t 394 | 395 | out_handle.write(f"\n*****Output(Time-{i+1},cost {step_time:.2f}s): ") 396 | out_handle.write(results[0]+"\n") 397 | 398 | out_handle.close() 399 | -------------------------------------------------------------------------------- /model/languages_abbreviation2fullname.txt: -------------------------------------------------------------------------------- 1 | af Afrikaans 南非荷兰语 2 | am Amharic 阿姆哈拉语 3 | an Aragonese 阿拉贡语 4 | ar Arabic 阿拉伯语 5 | as Assamese 阿萨姆语 6 | ast Asturian 阿斯图里亚斯语 7 | az Azerbaijani 阿塞拜疆语 8 | be Belarusian 白俄罗斯语 9 | bg Bulgarian 保加利亚语 10 | bn Bengali 孟加拉语 11 | bo Tibetan 藏语 12 | br Breton 布列塔尼语 13 | bs Bosnian 波斯尼亚语 14 | ca Catalan 加泰罗尼亚语 15 | cs Czech 捷克语 16 | cy Welsh 威尔士语 17 | da Danish 丹麦语 18 | de German 德语 19 | dz Dzongkha 宗喀语 20 | el Greek 希腊语 21 | en Engilish 英语 22 | eo Esperanto 世界语 23 | es Spanish 西班牙语 24 | et Estonian 爱沙尼亚语 25 | eu Basque 巴斯克语 26 | fa Persian 波斯语 27 | fi Finnish 芬兰语 28 | fr French 法语 29 | fy Western Frisian 西弗里斯兰语 30 | ga Irish 爱尔兰语 31 | gd Gaelic 盖尔语 32 | gl Galician 加利西亚语 33 | gu Gujarati 古吉拉特语 34 | ha Hausa 豪萨语 35 | he Hebrew 希伯来语 36 | hi Hindi 印地语 37 | hr Croatian 克罗地亚语 38 | hu Hungarian 匈牙利语 39 | hy Armenian 亚美尼亚语 40 | id Indonesian 印度尼西亚语 41 | ig Igbo 伊博语 42 | is Icelandic 冰岛语 43 | it Italian 意大利语 44 | ja Japanese 日语 45 | ka Georgian 格鲁吉亚语 46 | kk Kazakh 哈萨克语 47 | km Central Khmer 高棉语 48 | kn Kannada 卡纳达语 49 | ko Korean 韩语 50 | ku Kurdish 库尔德语 51 | ky Kyrgyz 吉尔吉斯语 52 | li Limburgan 林堡语 53 | lt Lithuanian 立陶宛语 54 | lv Latvian 拉脱维亚语 55 | mg Malagasy 马尔加什语 56 | mk Macedonian 马其顿语 57 | ml Malayalam 马拉雅拉姆语 58 | mo Mongolian 蒙古语 59 | mr Marathi 马拉地语 60 | ms Malay 马来语 61 | mt Maltese 马耳他语 62 | my Burmese 缅甸语 63 | nb Norwegian Bokmal 挪威博克马尔语 巴克摩挪威文 64 | ne Nepali 尼泊尔语 65 | nl Dutch 荷兰语 66 | nn Norwegian Nynorsk 尼诺斯克挪威语 新挪威语 67 | no Norwegian 挪威语 68 | oc Occitan 奥克西唐语 69 | or Oriya 奥里亚语 70 | pa Panjabi 旁遮普语 71 | pl Polish 波兰语 72 | ps Pashto 普什图语 73 | pt Portuguese 葡萄牙语 74 | ro Romanian 罗马尼亚语 75 | ru Russian 俄语 76 | rw Kinyarwanda 卢旺达语 77 | se Northern Sami 北萨米语 78 | sh Serbo-Croatian 塞尔维亚-克罗地亚语 塞克语 79 
| si Sinhala 僧伽罗语 80 | sk Slovak 斯洛伐克语 81 | sl Slovenian 斯洛文尼亚语 82 | sq Albanian 阿尔巴尼亚语 83 | sr Serbian 塞尔维亚语 84 | sv Swedish 瑞典语 85 | ta Tamil 泰米尔语 86 | te Telugu 泰卢固语 87 | tg Tajik 塔吉克语 88 | th Thai 泰语 89 | tk Turkmen 土库曼语 90 | tr Turkish 土耳其语 91 | tt Tatar 鞑靼语 92 | uk Ukrainian 乌克兰语 93 | ur Urdu 乌尔都语 94 | uy Uighur 维吾尔语 95 | uz Uzbek 乌兹别克语 96 | vi Vietnamese 越南语 97 | wa Walloon 瓦隆语 98 | xh Xhosa 科萨语 99 | yi Yiddish 意第绪语 100 | yo Yoruba 约鲁巴语 101 | zh Chinese 汉语 102 | zu Zulu 祖鲁语 -------------------------------------------------------------------------------- /model/llama/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from typing import TYPE_CHECKING 15 | 16 | from transformers.utils import ( 17 | OptionalDependencyNotAvailable, 18 | _LazyModule, 19 | is_torch_available, 20 | is_sentencepiece_available, 21 | ) 22 | 23 | 24 | _import_structure = { 25 | "configuration_llama": ["LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP", "LLaMAConfig"], 26 | } 27 | 28 | try: 29 | if not is_sentencepiece_available(): 30 | raise OptionalDependencyNotAvailable() 31 | except OptionalDependencyNotAvailable: 32 | pass 33 | else: 34 | _import_structure["tokenization_llama"] = ["LLaMATokenizer"] 35 | 36 | try: 37 | if not is_torch_available(): 38 | raise OptionalDependencyNotAvailable() 39 | except OptionalDependencyNotAvailable: 40 | pass 41 | else: 42 | _import_structure["modeling_llama"] = [ 43 | "LLaMAForCausalLM", 44 | "LLaMAModel", 45 | "LLaMAPreTrainedModel", 46 | ] 47 | 48 | 49 | if TYPE_CHECKING: 50 | from .configuration_llama import LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP, LLaMAConfig 51 | 52 | try: 53 | if not is_sentencepiece_available(): 54 | raise OptionalDependencyNotAvailable() 55 | except OptionalDependencyNotAvailable: 56 | pass 57 | else: 58 | from .tokenization_llama import LLaMATokenizer 59 | 60 | try: 61 | if not is_torch_available(): 62 | raise OptionalDependencyNotAvailable() 63 | except OptionalDependencyNotAvailable: 64 | pass 65 | else: 66 | from .modeling_llama import ( 67 | LLaMAForCausalLM, 68 | LLaMAModel, 69 | LLaMAPreTrainedModel, 70 | ) 71 | 72 | 73 | else: 74 | import sys 75 | 76 | sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) -------------------------------------------------------------------------------- /model/llama/configuration_llama.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. 3 | # 4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX 5 | # and OPT implementations in this library. 
It has been modified from its 6 | # original forms to accommodate minor architectural differences compared 7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 20 | """ LLaMA model configuration""" 21 | 22 | from transformers.configuration_utils import PretrainedConfig 23 | from transformers.utils import logging 24 | 25 | 26 | logger = logging.get_logger(__name__) 27 | 28 | LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {} 29 | 30 | 31 | class LLaMAConfig(PretrainedConfig): 32 | r""" 33 | This is the configuration class to store the configuration of a [`~LLaMAModel`]. It is used to instantiate an LLaMA 34 | model according to the specified arguments, defining the model architecture. Instantiating a configuration with the 35 | defaults will yield a similar configuration to that of the LLaMA-7B. 36 | 37 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 38 | documentation from [`PretrainedConfig`] for more information. 39 | 40 | 41 | Args: 42 | vocab_size (`int`, *optional*, defaults to 32000): 43 | Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the 44 | `inputs_ids` passed when calling [`~LLaMAModel`] or [`~TFLLaMAModel`]. 45 | hidden_size (`int`, *optional*, defaults to 4096): 46 | Dimension of the hidden representations. 47 | intermediate_size (`int`, *optional*, defaults to 11008): 48 | Dimension of the MLP representations. 49 | num_hidden_layers (`int`, *optional*, defaults to 32): 50 | Number of hidden layers in the Transformer encoder. 51 | num_attention_heads (`int`, *optional*, defaults to 32): 52 | Number of attention heads for each attention layer in the Transformer encoder. 53 | hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): 54 | The non-linear activation function (function or string) in the decoder. 55 | initializer_range (`float`, *optional*, defaults to 0.02): 56 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 57 | rms_norm_eps (`float`, *optional*, defaults to 1e-12): 58 | The epsilon used by the rms normalization layers. 59 | use_cache (`bool`, *optional*, defaults to `True`): 60 | Whether or not the model should return the last key/values attentions (not used by all models). Only 61 | relevant if `config.is_decoder=True`. 
62 | tie_word_embeddings(`bool`, *optional*, defaults to `False`): 63 | Whether to tie weight embeddings 64 | Example: 65 | 66 | ```python 67 | >>> from transformers import LLaMAModel, LLaMAConfig 68 | 69 | >>> # Initializing a LLaMA llama-7b style configuration 70 | >>> configuration = LLaMAConfig() 71 | 72 | >>> # Initializing a model from the llama-7b style configuration 73 | >>> model = LLaMAModel(configuration) 74 | 75 | >>> # Accessing the model configuration 76 | >>> configuration = model.config 77 | ```""" 78 | model_type = "llama" 79 | 80 | def __init__( 81 | self, 82 | vocab_size=32000, 83 | hidden_size=4096, 84 | intermediate_size=11008, 85 | num_hidden_layers=32, 86 | num_attention_heads=32, 87 | hidden_act="silu", 88 | initializer_range=0.02, 89 | rms_norm_eps=1e-6, 90 | use_cache=True, 91 | pad_token_id=-1, 92 | bos_token_id=0, 93 | eos_token_id=1, 94 | tie_word_embeddings=False, 95 | **kwargs, 96 | ): 97 | self.vocab_size = vocab_size 98 | self.hidden_size = hidden_size 99 | self.intermediate_size = intermediate_size 100 | self.num_hidden_layers = num_hidden_layers 101 | self.num_attention_heads = num_attention_heads 102 | self.hidden_act = hidden_act 103 | self.initializer_range = initializer_range 104 | self.rms_norm_eps = rms_norm_eps 105 | self.use_cache = use_cache 106 | super().__init__( 107 | pad_token_id=pad_token_id, 108 | bos_token_id=bos_token_id, 109 | eos_token_id=eos_token_id, 110 | tie_word_embeddings=tie_word_embeddings, 111 | **kwargs, 112 | ) -------------------------------------------------------------------------------- /model/llama/convert_llama_weights_to_hf.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | import argparse 15 | import json 16 | import os 17 | import shutil 18 | 19 | import torch 20 | 21 | 22 | """ 23 | Sample usage: 24 | 25 | ``` 26 | python src/transformers/models/llama/convert_llama_weights_to_hf.py \ 27 | --input_dir /path/to/downloaded/llama/weights --model_size 7B --output_dir /output/path 28 | ``` 29 | 30 | Thereafter, models can be loaded via: 31 | 32 | ``` 33 | tokenizer = transformers.LLaMATokenizer.from_pretrained("/output/path/tokenizer/") 34 | 35 | model = transformers.LLaMAForCausalLM.from_pretrained("/output/path/llama-7b/") 36 | ``` 37 | """ 38 | 39 | INTERMEDIATE_SIZE_MAP = { 40 | "7B": 11008, 41 | "13B": 13824, 42 | "30B": 17920, 43 | "65B": 22016, 44 | } 45 | NUM_SHARDS = { 46 | "7B": 1, 47 | "13B": 2, 48 | "30B": 4, 49 | "65B": 8, 50 | } 51 | 52 | 53 | def read_json(path): 54 | with open(path, "r") as f: 55 | return json.load(f) 56 | 57 | 58 | def write_json(text, path): 59 | with open(path, "w") as f: 60 | json.dump(text, f) 61 | 62 | 63 | def write_model(model_path, input_base_path, model_size): 64 | assert model_size in INTERMEDIATE_SIZE_MAP 65 | os.makedirs(model_path, exist_ok=True) 66 | 67 | params = read_json(os.path.join(input_base_path, "params.json")) 68 | num_shards = NUM_SHARDS[model_size] 69 | n_layers = params["n_layers"] 70 | n_heads = params["n_heads"] 71 | n_heads_per_shard = n_heads // num_shards 72 | dim = params["dim"] 73 | dims_per_head = dim // n_heads 74 | base = 10000.0 75 | inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head)) 76 | 77 | # permute for sliced rotary 78 | def permute(w): 79 | return w.view(n_heads, dim // n_heads // 2, 2, dim).transpose(1, 2).reshape(dim, dim) 80 | 81 | # Load weights 82 | if model_size == "7B": 83 | # Not shared 84 | # (The sharded implementation would also work, but this is simpler.) 
85 | loaded = torch.load(os.path.join(input_base_path, "consolidated.00.pth"), map_location="cpu") 86 | else: 87 | # Sharded 88 | loaded = [ 89 | torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu") 90 | for i in range(num_shards) 91 | ] 92 | param_count = 0 93 | index_dict = {"weight_map": {}} 94 | for layer_i in range(n_layers): 95 | filename = "pytorch_model-{:05d}-of-{:05d}.bin".format( 96 | layer_i + 1, 97 | n_layers + 1, 98 | ) 99 | if model_size == "7B": 100 | # Unsharded 101 | state_dict = { 102 | f"model.layers.{layer_i}.self_attn.q_proj.weight": permute( 103 | loaded[f"layers.{layer_i}.attention.wq.weight"] 104 | ), 105 | f"model.layers.{layer_i}.self_attn.k_proj.weight": permute( 106 | loaded[f"layers.{layer_i}.attention.wk.weight"] 107 | ), 108 | f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[f"layers.{layer_i}.attention.wv.weight"], 109 | f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"layers.{layer_i}.attention.wo.weight"], 110 | f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w1.weight"], 111 | f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w2.weight"], 112 | f"model.layers.{layer_i}.mlp.up_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w3.weight"], 113 | f"model.layers.{layer_i}.input_layernorm.weight": loaded[f"layers.{layer_i}.attention_norm.weight"], 114 | f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[f"layers.{layer_i}.ffn_norm.weight"], 115 | } 116 | else: 117 | # Sharded 118 | state_dict = { 119 | f"model.layers.{layer_i}.input_layernorm.weight": loaded[0][f"layers.{layer_i}.attention_norm.weight"], 120 | f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[0][ 121 | f"layers.{layer_i}.ffn_norm.weight" 122 | ], 123 | } 124 | state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = permute( 125 | torch.cat( 126 | [ 127 | loaded[i][f"layers.{layer_i}.attention.wq.weight"].view(n_heads_per_shard, dims_per_head, dim) 128 | for i in range(num_shards) 129 | ], 130 | dim=0, 131 | ).reshape(dim, dim) 132 | ) 133 | state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = permute( 134 | torch.cat( 135 | [ 136 | loaded[i][f"layers.{layer_i}.attention.wk.weight"].view(n_heads_per_shard, dims_per_head, dim) 137 | for i in range(num_shards) 138 | ], 139 | dim=0, 140 | ).reshape(dim, dim) 141 | ) 142 | state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat( 143 | [ 144 | loaded[i][f"layers.{layer_i}.attention.wv.weight"].view(n_heads_per_shard, dims_per_head, dim) 145 | for i in range(num_shards) 146 | ], 147 | dim=0, 148 | ).reshape(dim, dim) 149 | 150 | state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat( 151 | [loaded[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], dim=1 152 | ) 153 | state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat( 154 | [loaded[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0 155 | ) 156 | state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat( 157 | [loaded[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=1 158 | ) 159 | state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat( 160 | [loaded[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0 161 | ) 162 | 163 | state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq 164 | for k, v in 
state_dict.items(): 165 | index_dict["weight_map"][k] = filename 166 | param_count += v.numel() 167 | torch.save(state_dict, os.path.join(model_path, filename)) 168 | 169 | filename = "pytorch_model-{:05d}-of-{:05d}.bin".format( 170 | n_layers + 1, 171 | n_layers + 1, 172 | ) 173 | if model_size == "7B": 174 | # Unsharded 175 | state_dict = { 176 | "model.embed_tokens.weight": loaded["tok_embeddings.weight"], 177 | "model.norm.weight": loaded["norm.weight"], 178 | "lm_head.weight": loaded["output.weight"], 179 | } 180 | else: 181 | state_dict = { 182 | "model.norm.weight": loaded[0]["norm.weight"], 183 | "model.embed_tokens.weight": torch.cat( 184 | [loaded[i]["tok_embeddings.weight"] for i in range(num_shards)], dim=1 185 | ), 186 | "lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(num_shards)], dim=0), 187 | } 188 | 189 | for k, v in state_dict.items(): 190 | index_dict["weight_map"][k] = filename 191 | param_count += v.numel() 192 | torch.save(state_dict, os.path.join(model_path, filename)) 193 | 194 | # Write configs 195 | index_dict["metadata"] = {"total_size": param_count * 2} 196 | write_json(index_dict, os.path.join(model_path, "pytorch_model.bin.index.json")) 197 | config_out = { 198 | "architectures": ["LLaMAForCausalLM"], 199 | "bos_token_id": 0, 200 | "eos_token_id": 1, 201 | "hidden_act": "silu", 202 | "hidden_size": params["dim"], 203 | "intermediate_size": INTERMEDIATE_SIZE_MAP[model_size], 204 | "initializer_range": 0.02, 205 | "max_sequence_length": 2048, 206 | "model_type": "llama", 207 | "num_attention_heads": params["n_heads"], 208 | "num_hidden_layers": params["n_layers"], 209 | "pad_token_id": -1, 210 | "rms_norm_eps": params["norm_eps"], 211 | "torch_dtype": "float16", 212 | "transformers_version": "4.27.0.dev0", 213 | "use_cache": True, 214 | "vocab_size": 32000, 215 | } 216 | write_json( 217 | config_out, 218 | os.path.join(model_path, "config.json"), 219 | ) 220 | generation_config = { 221 | "_from_model_config": True, 222 | "bos_token_id": 0, 223 | "eos_token_id": 1, 224 | "pad_token_id": 0, 225 | "transformers_version": "4.27.0.dev0", 226 | } 227 | write_json( 228 | generation_config, 229 | os.path.join(model_path, "generation_config.json"), 230 | ) 231 | 232 | 233 | def write_tokenizer(tokenizer_path, input_tokenizer_path): 234 | os.makedirs(tokenizer_path, exist_ok=True) 235 | write_json({}, os.path.join(tokenizer_path, "special_tokens_map.json")) 236 | write_json( 237 | { 238 | "bos_token": "", 239 | "eos_token": "", 240 | "model_max_length": int(1e30), 241 | "tokenizer_class": "LLaMATokenizer", 242 | "unk_token": "", 243 | }, 244 | os.path.join(tokenizer_path, "tokenizer_config.json"), 245 | ) 246 | shutil.copyfile(input_tokenizer_path, os.path.join(tokenizer_path, "tokenizer.model")) 247 | 248 | 249 | def main(): 250 | parser = argparse.ArgumentParser() 251 | parser.add_argument( 252 | "--input_dir", 253 | help="Location of LLaMA weights, which contains tokenizer.model and model folders", 254 | ) 255 | parser.add_argument( 256 | "--model_size", 257 | choices=["7B", "13B", "30B", "65B"], 258 | ) 259 | parser.add_argument( 260 | "--output_dir", 261 | help="Location to write HF model and tokenizer", 262 | ) 263 | args = parser.parse_args() 264 | write_model( 265 | model_path=os.path.join(args.output_dir, "llama-{}".format(args.model_size).lower()), 266 | input_base_path=os.path.join(args.input_dir, args.model_size), 267 | model_size=args.model_size, 268 | ) 269 | write_tokenizer( 270 | tokenizer_path=os.path.join(args.output_dir, 
"tokenizer"), 271 | input_tokenizer_path=os.path.join(args.input_dir, "tokenizer.model"), 272 | ) 273 | 274 | 275 | if __name__ == "__main__": 276 | main() -------------------------------------------------------------------------------- /model/llama/modeling_llama.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. 3 | # 4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX 5 | # and OPT implementations in this library. It has been modified from its 6 | # original forms to accommodate minor architectural differences compared 7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 20 | """ PyTorch LLaMA model.""" 21 | import math 22 | from typing import List, Optional, Tuple, Union 23 | 24 | import torch 25 | import torch.utils.checkpoint 26 | from torch import nn 27 | from torch.nn import CrossEntropyLoss 28 | 29 | from transformers.activations import ACT2FN 30 | from transformers.modeling_outputs import ( 31 | BaseModelOutputWithPast, 32 | CausalLMOutputWithPast, 33 | ) 34 | from transformers.modeling_utils import PreTrainedModel 35 | from transformers.utils import ( 36 | add_code_sample_docstrings, 37 | add_start_docstrings, 38 | add_start_docstrings_to_model_forward, 39 | logging, 40 | replace_return_docstrings, 41 | ) 42 | from .configuration_llama import LLaMAConfig 43 | 44 | 45 | logger = logging.get_logger(__name__) 46 | 47 | _CHECKPOINT_FOR_DOC = "llama-7b" 48 | _CONFIG_FOR_DOC = "LLaMAConfig" 49 | 50 | 51 | def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0): 52 | """ 53 | Make causal mask used for bi-directional self-attention. 54 | """ 55 | bsz, tgt_len = input_ids_shape 56 | mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min)) 57 | mask_cond = torch.arange(mask.size(-1)) 58 | mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) 59 | mask = mask.to(dtype) 60 | 61 | if past_key_values_length > 0: 62 | mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1) 63 | return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) 64 | 65 | 66 | def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): 67 | """ 68 | Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 
69 | """ 70 | bsz, src_len = mask.size() 71 | tgt_len = tgt_len if tgt_len is not None else src_len 72 | 73 | expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) 74 | 75 | inverted_mask = 1.0 - expanded_mask 76 | 77 | return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) 78 | 79 | 80 | class RMSNorm(nn.Module): 81 | def __init__(self, hidden_size, eps=1e-6): 82 | """ 83 | RMSNorm is equivalent to T5LayerNorm 84 | """ 85 | super().__init__() 86 | self.weight = nn.Parameter(torch.ones(hidden_size)) 87 | self.variance_epsilon = eps 88 | 89 | def forward(self, hidden_states): 90 | variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) 91 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) 92 | 93 | # convert into half-precision if necessary 94 | if self.weight.dtype in [torch.float16, torch.bfloat16]: 95 | hidden_states = hidden_states.to(self.weight.dtype) 96 | 97 | return self.weight * hidden_states 98 | 99 | 100 | class RotaryEmbedding(torch.nn.Module): 101 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): 102 | super().__init__() 103 | inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) 104 | self.register_buffer("inv_freq", inv_freq) 105 | 106 | # Build here to make `torch.jit.trace` work. 107 | self.max_seq_len_cached = max_position_embeddings 108 | t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype) 109 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 110 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 111 | emb = torch.cat((freqs, freqs), dim=-1) 112 | self.cos_cached = emb.cos()[None, None, :, :] 113 | self.sin_cached = emb.sin()[None, None, :, :] 114 | 115 | def forward(self, x, seq_len=None): 116 | # x: [bs, num_attention_heads, seq_len, head_size] 117 | # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. 
118 | if seq_len > self.max_seq_len_cached: 119 | self.max_seq_len_cached = seq_len 120 | t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype) 121 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 122 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 123 | emb = torch.cat((freqs, freqs), dim=-1).to(x.device) 124 | self.cos_cached = emb.cos()[None, None, :, :].to(dtype=x.dtype) 125 | self.sin_cached = emb.sin()[None, None, :, :].to(dtype=x.dtype) 126 | return ( 127 | self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype, device=x.device), 128 | self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype, device=x.device), 129 | ) 130 | 131 | 132 | def rotate_half(x): 133 | """Rotates half the hidden dims of the input.""" 134 | x1 = x[..., : x.shape[-1] // 2] 135 | x2 = x[..., x.shape[-1] // 2 :] 136 | return torch.cat((-x2, x1), dim=-1) 137 | 138 | 139 | def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0): 140 | cos = cos[..., offset : q.shape[-2] + offset, :] 141 | sin = sin[..., offset : q.shape[-2] + offset, :] 142 | q_embed = (q * cos) + (rotate_half(q) * sin) 143 | k_embed = (k * cos) + (rotate_half(k) * sin) 144 | return q_embed, k_embed 145 | 146 | 147 | class LLaMAMLP(nn.Module): 148 | def __init__( 149 | self, 150 | hidden_size: int, 151 | intermediate_size: int, 152 | hidden_act: str, 153 | ): 154 | super().__init__() 155 | self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False) 156 | self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False) 157 | self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False) 158 | self.act_fn = ACT2FN[hidden_act] 159 | 160 | def forward(self, x): 161 | return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) 162 | 163 | 164 | class LLaMAAttention(nn.Module): 165 | """Multi-headed attention from 'Attention Is All You Need' paper""" 166 | 167 | def __init__( 168 | self, 169 | hidden_size: int, 170 | num_heads: int, 171 | ): 172 | super().__init__() 173 | self.hidden_size = hidden_size 174 | self.num_heads = num_heads 175 | self.head_dim = hidden_size // num_heads 176 | 177 | if (self.head_dim * num_heads) != self.hidden_size: 178 | raise ValueError( 179 | f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" 180 | f" and `num_heads`: {num_heads})." 
181 | ) 182 | self.q_proj = nn.Linear( 183 | hidden_size, 184 | num_heads * self.head_dim, 185 | bias=False, 186 | ) 187 | self.k_proj = nn.Linear( 188 | hidden_size, 189 | num_heads * self.head_dim, 190 | bias=False, 191 | ) 192 | self.v_proj = nn.Linear( 193 | hidden_size, 194 | num_heads * self.head_dim, 195 | bias=False, 196 | ) 197 | self.o_proj = nn.Linear( 198 | num_heads * self.head_dim, 199 | hidden_size, 200 | bias=False, 201 | ) 202 | self.rotary_emb = RotaryEmbedding(self.head_dim) 203 | 204 | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): 205 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() 206 | 207 | def forward( 208 | self, 209 | hidden_states: torch.Tensor, 210 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 211 | attention_mask: Optional[torch.Tensor] = None, 212 | output_attentions: bool = False, 213 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 214 | """Input shape: Batch x Time x Channel""" 215 | 216 | bsz, q_len, _ = hidden_states.size() 217 | 218 | query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 219 | key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 220 | value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 221 | 222 | kv_seq_len = key_states.shape[-2] 223 | offset = 0 224 | if past_key_value is not None: 225 | offset = past_key_value[0].shape[-2] 226 | kv_seq_len += offset 227 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 228 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, offset=offset) 229 | # [bsz, nh, t, hd] 230 | 231 | if past_key_value is not None: 232 | # reuse k, v, self_attention 233 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 234 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 235 | 236 | past_key_value = (key_states, value_states) 237 | 238 | attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) 239 | 240 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): 241 | raise ValueError( 242 | f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is" 243 | f" {attn_weights.size()}" 244 | ) 245 | 246 | if attention_mask is not None: 247 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): 248 | raise ValueError( 249 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" 250 | ) 251 | attn_weights = attn_weights + attention_mask 252 | attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) 253 | 254 | # upcast attention to fp32 255 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) 256 | attn_output = torch.matmul(attn_weights, value_states) 257 | 258 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): 259 | raise ValueError( 260 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" 261 | f" {attn_output.size()}" 262 | ) 263 | 264 | attn_output = attn_output.transpose(1, 2) 265 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) 266 | 267 | attn_output = self.o_proj(attn_output) 268 | 269 | if not output_attentions: 270 | attn_weights = None 271 | 272 | return attn_output, attn_weights, past_key_value 
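# Illustrative shape walk-through of the rotary helpers consumed by LLaMAAttention
# above (a sketch with arbitrary small sizes, not part of the forward pass):
#
#     rotary = RotaryEmbedding(dim=8, max_position_embeddings=16)
#     q = torch.randn(1, 2, 5, 8)        # [bsz, num_heads, seq_len, head_dim]
#     k = torch.randn(1, 2, 5, 8)
#     cos, sin = rotary(q, seq_len=5)    # each of shape [1, 1, 5, 8]
#     q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, offset=0)
#     assert q_rot.shape == q.shape and k_rot.shape == k.shape
#
# With a key/value cache, `offset` equals the number of cached positions, so the new
# query/key tokens pick up the cos/sin slices for their absolute positions.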
273 | 274 | 275 | class LLaMADecoderLayer(nn.Module): 276 | def __init__(self, config: LLaMAConfig): 277 | super().__init__() 278 | self.hidden_size = config.hidden_size 279 | self.self_attn = LLaMAAttention( 280 | hidden_size=self.hidden_size, 281 | num_heads=config.num_attention_heads, 282 | ) 283 | self.mlp = LLaMAMLP( 284 | hidden_size=self.hidden_size, 285 | intermediate_size=config.intermediate_size, 286 | hidden_act=config.hidden_act, 287 | ) 288 | self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 289 | self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 290 | 291 | def forward( 292 | self, 293 | hidden_states: torch.Tensor, 294 | attention_mask: Optional[torch.Tensor] = None, 295 | output_attentions: Optional[bool] = False, 296 | use_cache: Optional[bool] = False, 297 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 298 | ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: 299 | """ 300 | Args: 301 | hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` 302 | attention_mask (`torch.FloatTensor`, *optional*): attention mask of size 303 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 304 | output_attentions (`bool`, *optional*): 305 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 306 | returned tensors for more detail. 307 | use_cache (`bool`, *optional*): 308 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding 309 | (see `past_key_values`). 310 | past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states 311 | """ 312 | 313 | residual = hidden_states 314 | 315 | hidden_states = self.input_layernorm(hidden_states) 316 | 317 | # Self Attention 318 | hidden_states, self_attn_weights, present_key_value = self.self_attn( 319 | hidden_states=hidden_states, 320 | past_key_value=past_key_value, 321 | attention_mask=attention_mask, 322 | output_attentions=output_attentions, 323 | ) 324 | hidden_states = residual + hidden_states 325 | 326 | # Fully Connected 327 | residual = hidden_states 328 | hidden_states = self.post_attention_layernorm(hidden_states) 329 | hidden_states = self.mlp(hidden_states) 330 | hidden_states = residual + hidden_states 331 | 332 | outputs = (hidden_states,) 333 | 334 | if output_attentions: 335 | outputs += (self_attn_weights,) 336 | 337 | if use_cache: 338 | outputs += (present_key_value,) 339 | 340 | return outputs 341 | 342 | 343 | LLAMA_START_DOCSTRING = r""" 344 | This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the 345 | library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads 346 | etc.) 347 | 348 | This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 349 | Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage 350 | and behavior. 351 | 352 | Parameters: 353 | config ([`LLaMAConfig`]): 354 | Model configuration class with all the parameters of the model. Initializing with a config file does not 355 | load the weights associated with the model, only the configuration. Check out the 356 | [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
357 | """ 358 | 359 | 360 | @add_start_docstrings( 361 | "The bare OPT Model outputting raw hidden-states without any specific head on top.", 362 | LLAMA_START_DOCSTRING, 363 | ) 364 | class LLaMAPreTrainedModel(PreTrainedModel): 365 | config_class = LLaMAConfig 366 | base_model_prefix = "model" 367 | supports_gradient_checkpointing = True 368 | _no_split_modules = ["LLaMADecoderLayer"] 369 | _keys_to_ignore_on_load_unexpected = [r"decoder\.version"] 370 | 371 | def _init_weights(self, module): 372 | std = self.config.initializer_range 373 | if isinstance(module, nn.Linear): 374 | module.weight.data.normal_(mean=0.0, std=std) 375 | if module.bias is not None: 376 | module.bias.data.zero_() 377 | elif isinstance(module, nn.Embedding): 378 | module.weight.data.normal_(mean=0.0, std=std) 379 | if module.padding_idx is not None: 380 | module.weight.data[module.padding_idx].zero_() 381 | 382 | def _set_gradient_checkpointing(self, module, value=False): 383 | if isinstance(module, (LLaMADecoderLayer)): 384 | module.gradient_checkpointing = value 385 | 386 | 387 | LLAMA_INPUTS_DOCSTRING = r""" 388 | Args: 389 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 390 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide 391 | it. 392 | 393 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and 394 | [`PreTrainedTokenizer.__call__`] for details. 395 | 396 | [What are input IDs?](../glossary#input-ids) 397 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): 398 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 399 | 400 | - 1 for tokens that are **not masked**, 401 | - 0 for tokens that are **masked**. 402 | 403 | [What are attention masks?](../glossary#attention-mask) 404 | 405 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and 406 | [`PreTrainedTokenizer.__call__`] for details. 407 | 408 | If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see 409 | `past_key_values`). 410 | 411 | If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] 412 | and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more 413 | information on the default strategy. 414 | 415 | - 1 indicates the head is **not masked**, 416 | - 0 indicates the head is **masked**. 417 | 418 | past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): 419 | Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape 420 | `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape 421 | `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. 422 | 423 | Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention 424 | blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. 425 | 426 | If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that 427 | don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all 428 | `decoder_input_ids` of shape `(batch_size, sequence_length)`. 
429 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): 430 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This 431 | is useful if you want more control over how to convert `input_ids` indices into associated vectors than the 432 | model's internal embedding lookup matrix. 433 | use_cache (`bool`, *optional*): 434 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see 435 | `past_key_values`). 436 | output_attentions (`bool`, *optional*): 437 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned 438 | tensors for more detail. 439 | output_hidden_states (`bool`, *optional*): 440 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for 441 | more detail. 442 | return_dict (`bool`, *optional*): 443 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 444 | """ 445 | 446 | 447 | @add_start_docstrings( 448 | "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", 449 | LLAMA_START_DOCSTRING, 450 | ) 451 | class LLaMAModel(LLaMAPreTrainedModel): 452 | """ 453 | Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LLaMADecoderLayer`] 454 | 455 | Args: 456 | config: LLaMAConfig 457 | """ 458 | 459 | def __init__(self, config: LLaMAConfig): 460 | super().__init__(config) 461 | self.padding_idx = config.pad_token_id 462 | self.vocab_size = config.vocab_size 463 | 464 | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) 465 | self.layers = nn.ModuleList([LLaMADecoderLayer(config) for _ in range(config.num_hidden_layers)]) 466 | self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 467 | 468 | self.gradient_checkpointing = False 469 | # Initialize weights and apply final processing 470 | self.post_init() 471 | 472 | def get_input_embeddings(self): 473 | return self.embed_tokens 474 | 475 | def set_input_embeddings(self, value): 476 | self.embed_tokens = value 477 | 478 | # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask 479 | def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): 480 | # create causal mask 481 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 482 | combined_attention_mask = None 483 | if input_shape[-1] > 1: 484 | 485 | 486 | combined_attention_mask = _make_causal_mask( 487 | input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length 488 | ).to(inputs_embeds.device) 489 | 490 | if attention_mask is not None: 491 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 492 | expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( 493 | inputs_embeds.device 494 | ) 495 | combined_attention_mask = ( 496 | expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask 497 | ) 498 | 499 | return combined_attention_mask 500 | 501 | def forward( 502 | self, 503 | input_ids: torch.LongTensor = None, 504 | attention_mask: Optional[torch.Tensor] = None, 505 | past_key_values: Optional[List[torch.FloatTensor]] = None, 506 | inputs_embeds: Optional[torch.FloatTensor] = None, 507 | use_cache: Optional[bool] = None, 508 | output_attentions: Optional[bool] = None, 509 | 
output_hidden_states: Optional[bool] = None, 510 | return_dict: Optional[bool] = None, 511 | ) -> Union[Tuple, BaseModelOutputWithPast]: 512 | r""" 513 | Args: 514 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 515 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you 516 | provide it. 517 | 518 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and 519 | [`PreTrainedTokenizer.__call__`] for details. 520 | 521 | [What are input IDs?](../glossary#input-ids) 522 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): 523 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 524 | 525 | - 1 for tokens that are **not masked**, 526 | - 0 for tokens that are **masked**. 527 | 528 | [What are attention masks?](../glossary#attention-mask) 529 | past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): 530 | Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of 531 | shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of 532 | 533 | Contains pre-computed hidden-states (key and values in the self-attention blocks and in the 534 | cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. 535 | 536 | If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those 537 | that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of 538 | all `decoder_input_ids` of shape `(batch_size, sequence_length)`. 539 | use_cache (`bool`, *optional*): 540 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see 541 | `past_key_values`). 542 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): 543 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. 544 | This is useful if you want more control over how to convert `input_ids` indices into associated vectors 545 | than the model's internal embedding lookup matrix. 546 | output_attentions (`bool`, *optional*): 547 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 548 | returned tensors for more detail. 549 | output_hidden_states (`bool`, *optional*): 550 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors 551 | for more detail. 552 | return_dict (`bool`, *optional*): 553 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
554 | """ 555 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 556 | output_hidden_states = ( 557 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 558 | ) 559 | use_cache = use_cache if use_cache is not None else self.config.use_cache 560 | 561 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 562 | 563 | # retrieve input_ids and inputs_embeds 564 | if input_ids is not None and inputs_embeds is not None: 565 | raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") 566 | elif input_ids is not None: 567 | input_shape = input_ids.size() 568 | input_ids = input_ids.view(-1, input_shape[-1]) 569 | elif inputs_embeds is not None: 570 | input_shape = inputs_embeds.size()[:-1] 571 | else: 572 | raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") 573 | 574 | past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 575 | 576 | if inputs_embeds is None: 577 | inputs_embeds = self.embed_tokens(input_ids) 578 | 579 | # embed positions 580 | if attention_mask is None: 581 | attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.bool, device=inputs_embeds.device) 582 | 583 | 584 | 585 | attention_mask = self._prepare_decoder_attention_mask( 586 | attention_mask, input_shape, inputs_embeds, past_key_values_length 587 | ) 588 | 589 | hidden_states = inputs_embeds 590 | 591 | if self.gradient_checkpointing and self.training: 592 | if use_cache: 593 | logger.warning_once( 594 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 595 | ) 596 | use_cache = False 597 | 598 | # decoder layers 599 | all_hidden_states = () if output_hidden_states else None 600 | all_self_attns = () if output_attentions else None 601 | next_decoder_cache = () if use_cache else None 602 | 603 | for idx, decoder_layer in enumerate(self.layers): 604 | if output_hidden_states: 605 | all_hidden_states += (hidden_states,) 606 | 607 | past_key_value = past_key_values[idx] if past_key_values is not None else None 608 | 609 | if self.gradient_checkpointing and self.training: 610 | 611 | def create_custom_forward(module): 612 | def custom_forward(*inputs): 613 | # None for past_key_value 614 | return module(*inputs, output_attentions, None) 615 | 616 | return custom_forward 617 | 618 | layer_outputs = torch.utils.checkpoint.checkpoint( 619 | create_custom_forward(decoder_layer), 620 | hidden_states, 621 | attention_mask, 622 | None, 623 | ) 624 | else: 625 | layer_outputs = decoder_layer( 626 | hidden_states, 627 | attention_mask=attention_mask, 628 | past_key_value=past_key_value, 629 | output_attentions=output_attentions, 630 | use_cache=use_cache, 631 | ) 632 | 633 | hidden_states = layer_outputs[0] 634 | 635 | if use_cache: 636 | next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) 637 | 638 | if output_attentions: 639 | all_self_attns += (layer_outputs[1],) 640 | 641 | hidden_states = self.norm(hidden_states) 642 | 643 | # add hidden states from the last decoder layer 644 | if output_hidden_states: 645 | all_hidden_states += (hidden_states,) 646 | 647 | next_cache = next_decoder_cache if use_cache else None 648 | if not return_dict: 649 | return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) 650 | return BaseModelOutputWithPast( 651 | 
last_hidden_state=hidden_states, 652 | past_key_values=next_cache, 653 | hidden_states=all_hidden_states, 654 | attentions=all_self_attns, 655 | ) 656 | 657 | 658 | class LLaMAForCausalLM(LLaMAPreTrainedModel): 659 | _keys_to_ignore_on_load_missing = [r"lm_head.weight"] 660 | 661 | def __init__(self, config): 662 | super().__init__(config) 663 | self.model = LLaMAModel(config) 664 | 665 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 666 | 667 | # Initialize weights and apply final processing 668 | self.post_init() 669 | 670 | def get_input_embeddings(self): 671 | return self.model.embed_tokens 672 | 673 | def set_input_embeddings(self, value): 674 | self.model.embed_tokens = value 675 | 676 | def get_output_embeddings(self): 677 | return self.lm_head 678 | 679 | def set_output_embeddings(self, new_embeddings): 680 | self.lm_head = new_embeddings 681 | 682 | def set_decoder(self, decoder): 683 | self.model = decoder 684 | 685 | def get_decoder(self): 686 | return self.model 687 | 688 | @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) 689 | def forward( 690 | self, 691 | input_ids: torch.LongTensor = None, 692 | attention_mask: Optional[torch.Tensor] = None, 693 | past_key_values: Optional[List[torch.FloatTensor]] = None, 694 | inputs_embeds: Optional[torch.FloatTensor] = None, 695 | labels: Optional[torch.LongTensor] = None, 696 | use_cache: Optional[bool] = None, 697 | output_attentions: Optional[bool] = None, 698 | output_hidden_states: Optional[bool] = None, 699 | return_dict: Optional[bool] = None, 700 | ) -> Union[Tuple, CausalLMOutputWithPast]: 701 | r""" 702 | Args: 703 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 704 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you 705 | provide it. 706 | 707 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and 708 | [`PreTrainedTokenizer.__call__`] for details. 709 | 710 | [What are input IDs?](../glossary#input-ids) 711 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): 712 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 713 | 714 | - 1 for tokens that are **not masked**, 715 | - 0 for tokens that are **masked**. 716 | 717 | [What are attention masks?](../glossary#attention-mask) 718 | past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): 719 | Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of 720 | shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of 721 | shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional 722 | tensors are only required when the model is used as a decoder in a Sequence to Sequence model. 723 | 724 | Contains pre-computed hidden-states (key and values in the self-attention blocks and in the 725 | cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. 726 | 727 | If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those 728 | that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of 729 | all `decoder_input_ids` of shape `(batch_size, sequence_length)`. 
730 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): 731 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. 732 | This is useful if you want more control over how to convert `input_ids` indices into associated vectors 733 | than the model's internal embedding lookup matrix. 734 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 735 | Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., 736 | config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored 737 | (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 738 | use_cache (`bool`, *optional*): 739 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding 740 | (see `past_key_values`). 741 | output_attentions (`bool`, *optional*): 742 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 743 | returned tensors for more detail. 744 | output_hidden_states (`bool`, *optional*): 745 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors 746 | for more detail. 747 | return_dict (`bool`, *optional*): 748 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 749 | 750 | Returns: 751 | 752 | Example: 753 | 754 | ```python 755 | >>> from transformers import AutoTokenizer, LLaMAForCausalLM 756 | 757 | >>> model = LLaMAForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) 758 | >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) 759 | 760 | >>> prompt = "Hey, are you conscious? Can you talk to me?" 761 | >>> inputs = tokenizer(prompt, return_tensors="pt") 762 | 763 | >>> # Generate 764 | >>> generate_ids = model.generate(inputs.input_ids, max_length=30) 765 | >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] 766 | "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
767 | ```""" 768 | 769 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 770 | output_hidden_states = ( 771 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 772 | ) 773 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 774 | 775 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 776 | outputs = self.model( 777 | input_ids=input_ids, 778 | attention_mask=attention_mask, 779 | past_key_values=past_key_values, 780 | inputs_embeds=inputs_embeds, 781 | use_cache=use_cache, 782 | output_attentions=output_attentions, 783 | output_hidden_states=output_hidden_states, 784 | return_dict=return_dict, 785 | ) 786 | 787 | hidden_states = outputs[0] 788 | logits = self.lm_head(hidden_states) 789 | 790 | loss = None 791 | if labels is not None: 792 | # Shift so that tokens < n predict n 793 | shift_logits = logits[..., :-1, :].contiguous() 794 | shift_labels = labels[..., 1:].contiguous() 795 | # Flatten the tokens 796 | loss_fct = CrossEntropyLoss() 797 | loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)) 798 | 799 | if not return_dict: 800 | output = (logits,) + outputs[1:] 801 | return (loss,) + output if loss is not None else output 802 | 803 | return CausalLMOutputWithPast( 804 | loss=loss, 805 | logits=logits, 806 | past_key_values=outputs.past_key_values, 807 | hidden_states=outputs.hidden_states, 808 | attentions=outputs.attentions, 809 | ) 810 | 811 | def prepare_inputs_for_generation( 812 | self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs 813 | ): 814 | if past_key_values: 815 | input_ids = input_ids[:, -1:] 816 | 817 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 818 | if inputs_embeds is not None and past_key_values is None: 819 | model_inputs = {"inputs_embeds": inputs_embeds} 820 | else: 821 | model_inputs = {"input_ids": input_ids} 822 | 823 | model_inputs.update( 824 | { 825 | "past_key_values": past_key_values, 826 | "use_cache": kwargs.get("use_cache"), 827 | "attention_mask": attention_mask, 828 | } 829 | ) 830 | return model_inputs 831 | 832 | @staticmethod 833 | def _reorder_cache(past_key_values, beam_idx): 834 | reordered_past = () 835 | for layer_past in past_key_values: 836 | reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) 837 | return reordered_past -------------------------------------------------------------------------------- /model/llama/tokenization_llama.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. 3 | # 4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX 5 | # and OPT implementations in this library. It has been modified from its 6 | # original forms to accommodate minor architectural differences compared 7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 
11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 20 | 21 | """Tokenization classes for LLaMA.""" 22 | import os 23 | import re 24 | from shutil import copyfile 25 | from typing import Any, Dict, List, Optional, Tuple 26 | 27 | import sentencepiece as spm 28 | 29 | from transformers.tokenization_utils import PreTrainedTokenizer 30 | from transformers.utils import logging 31 | 32 | 33 | logger = logging.get_logger(__name__) 34 | 35 | VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"} 36 | 37 | PRETRAINED_VOCAB_FILES_MAP = {} 38 | 39 | def load_vocab(file_path, with_score = True): 40 | vocab = [] 41 | with open(file_path, "r") as file: 42 | for line in file.readlines(): 43 | if with_score: 44 | vocab.append(line.split("\t")[0]) 45 | else: 46 | vocab.append(line.split("\n")[0]) 47 | file.close() 48 | return vocab 49 | 50 | class LLaMATokenizer(PreTrainedTokenizer): 51 | """ 52 | Construct a LLaMA tokenizer. Based on byte-level Byte-Pair-Encoding. 53 | 54 | Args: 55 | vocab_file (`str`): 56 | Path to the vocabulary file. 57 | """ 58 | 59 | vocab_files_names = VOCAB_FILES_NAMES 60 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 61 | model_input_names = ["input_ids", "attention_mask"] 62 | 63 | def __init__( 64 | self, 65 | vocab_file, 66 | unk_token="", 67 | bos_token="", 68 | eos_token="", 69 | sp_model_kwargs: Optional[Dict[str, Any]] = None, 70 | add_bos_token=True, 71 | add_eos_token=False, 72 | **kwargs, 73 | ): 74 | self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs 75 | super().__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs) 76 | self.vocab_file = vocab_file 77 | self.add_bos_token = add_bos_token 78 | self.add_eos_token = add_eos_token 79 | self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) 80 | self.sp_model.Load(vocab_file) 81 | 82 | """ Initialisation""" 83 | 84 | @property 85 | def vocab_size(self): 86 | """Returns vocab size""" 87 | return self.sp_model.get_piece_size() 88 | 89 | @property 90 | def bos_token_id(self) -> Optional[int]: 91 | return self.sp_model.bos_id() 92 | 93 | @property 94 | def eos_token_id(self) -> Optional[int]: 95 | return self.sp_model.eos_id() 96 | 97 | # @property 98 | # def pad_token_id(self) -> Optional[int]: 99 | # return self.sp_model.pad_id() 100 | 101 | def get_vocab(self): 102 | """Returns vocab as a dict""" 103 | vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} 104 | vocab.update(self.added_tokens_encoder) 105 | return vocab 106 | 107 | def _tokenize(self, text): 108 | """Returns a tokenized string.""" 109 | return self.sp_model.encode(text, out_type=str) 110 | 111 | def _convert_token_to_id(self, token): 112 | """Converts a token (str) in an id using the vocab.""" 113 | return self.sp_model.piece_to_id(token) 114 | 115 | def _convert_id_to_token(self, index): 116 | """Converts an index (integer) in a token (str) using the vocab.""" 117 | token = self.sp_model.IdToPiece(index) 118 | return token 119 | 120 | def convert_tokens_to_string(self, tokens): 121 | """Converts a sequence of tokens (string) in a single string.""" 122 | current_sub_tokens 
= [] 123 | out_string = "" 124 | prev_is_special = False 125 | for token in tokens: 126 | # make sure that special tokens are not decoded using sentencepiece model 127 | if token in self.all_special_tokens: 128 | if not prev_is_special: 129 | out_string += " " 130 | out_string += self.sp_model.decode(current_sub_tokens) + token 131 | prev_is_special = True 132 | current_sub_tokens = [] 133 | else: 134 | current_sub_tokens.append(token) 135 | prev_is_special = False 136 | out_string += self.sp_model.decode(current_sub_tokens) 137 | return out_string.strip() 138 | 139 | def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]: 140 | """ 141 | Save the vocabulary and special tokens file to a directory. 142 | 143 | Args: 144 | save_directory (`str`): 145 | The directory in which to save the vocabulary. 146 | 147 | Returns: 148 | `Tuple(str)`: Paths to the files saved. 149 | """ 150 | if not os.path.isdir(save_directory): 151 | logger.error(f"Vocabulary path ({save_directory}) should be a directory") 152 | return 153 | out_vocab_file = os.path.join( 154 | save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] 155 | ) 156 | 157 | if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): 158 | copyfile(self.vocab_file, out_vocab_file) 159 | elif not os.path.isfile(self.vocab_file): 160 | with open(out_vocab_file, "wb") as fi: 161 | content_spiece_model = self.sp_model.serialized_model_proto() 162 | fi.write(content_spiece_model) 163 | 164 | return (out_vocab_file,) 165 | 166 | def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): 167 | if self.add_bos_token: 168 | bos_token_ids = [self.bos_token_id] 169 | else: 170 | bos_token_ids = [] 171 | 172 | output = bos_token_ids + token_ids_0 173 | 174 | if token_ids_1 is not None: 175 | output = output + token_ids_1 176 | 177 | if self.add_eos_token: 178 | output = output + [self.eos_token_id] 179 | 180 | return output 181 | 182 | def get_special_tokens_mask( 183 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False 184 | ) -> List[int]: 185 | """ 186 | Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding 187 | special tokens using the tokenizer `prepare_for_model` method. 188 | 189 | Args: 190 | token_ids_0 (`List[int]`): 191 | List of IDs. 192 | token_ids_1 (`List[int]`, *optional*): 193 | Optional second list of IDs for sequence pairs. 194 | already_has_special_tokens (`bool`, *optional*, defaults to `False`): 195 | Whether or not the token list is already formatted with special tokens for the model. 196 | 197 | Returns: 198 | `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. 199 | """ 200 | if already_has_special_tokens: 201 | return super().get_special_tokens_mask( 202 | token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True 203 | ) 204 | 205 | if token_ids_1 is None: 206 | return [1] + ([0] * len(token_ids_0)) + [1] 207 | return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1] 208 | 209 | def create_token_type_ids_from_sequences( 210 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None 211 | ) -> List[int]: 212 | """ 213 | Create a mask from the two sequences passed to be used in a sequence-pair classification task. 
T5 does not make 214 | use of token type ids, therefore a list of zeros is returned. 215 | 216 | Args: 217 | token_ids_0 (`List[int]`): 218 | List of IDs. 219 | token_ids_1 (`List[int]`, *optional*): 220 | Optional second list of IDs for sequence pairs. 221 | 222 | Returns: 223 | `List[int]`: List of zeros. 224 | """ 225 | eos = [self.eos_token_id] 226 | 227 | if token_ids_1 is None: 228 | return len(token_ids_0 + eos) * [0] 229 | return len(token_ids_0 + eos + token_ids_1 + eos) * [0] -------------------------------------------------------------------------------- /model/translate.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModelForCausalLM, AutoTokenizer 2 | import torch 3 | import time 4 | from nltk.translate.bleu_score import sentence_bleu 5 | import jieba 6 | import llama 7 | import argparse 8 | from accelerate.utils import set_seed 9 | import json 10 | import tensor_parallel as tp 11 | 12 | import os 13 | from tqdm import tqdm 14 | 15 | PROMPT_DICT = { 16 | # "prompt_instruct": ( 17 | # "以下是一个描述任务的指令,并配有一个提供详细上下文信息的输入。" 18 | # "请写一个完成该指令的适当回复。\n\n" 19 | # "### 指令:\n{instruction}\n\n### 输入:\n{input}\n\n### 回复:" 20 | # ), 21 | "prompt_input": ( 22 | "以下是一个描述任务的指令,请写一个完成该指令的适当回复。\n\n" 23 | "### 指令:\n{0}\n\n### 回复:" 24 | ), 25 | "translate_prompt":"{0}句子:“{2}”的{1}是:", 26 | "translate_instruct": "请将以下{0}句子翻译成{1}:{2}", 27 | } 28 | 29 | TYPE_DICT = { 30 | "fp16": torch.float16, 31 | "bf16": torch.bfloat16, 32 | "fp32": torch.float32, 33 | } 34 | 35 | def read_prompt_txt(file_path:str): 36 | """ 37 | read prompt from text file in line by line manner. 38 | """ 39 | file_handle = open(file_path) 40 | prompts = [] 41 | 42 | while True: 43 | line = file_handle.readline() 44 | if not line: 45 | break 46 | line = line.strip() 47 | prompts.append(line) 48 | 49 | return prompts 50 | 51 | def read_prompt_json(file_path:str): 52 | """ 53 | read prompt (dict) from json file in line by line manner. 
54 | """ 55 | file_handle = open(file_path) 56 | prompts = [] 57 | while True: 58 | line = file_handle.readline() 59 | if not line: 60 | break 61 | line = line.strip() 62 | prompts.append(json.loads(line)) 63 | return prompts 64 | 65 | FILE_TYPE2LOAD = { 66 | "txt": read_prompt_txt, 67 | "json": read_prompt_json 68 | } 69 | 70 | def abbreviation2fullname(src_abbreviation, tgt_abbreviation): 71 | with open("./languages_abbreviation2fullname.txt") as f: 72 | lines = f.readlines() 73 | f.close() 74 | language_fullname_dict = {} 75 | for line in lines: 76 | abbreviation = line.strip().split("\t")[0] 77 | english_full_name = line.strip().split("\t")[1] 78 | chinese_full_name = line.strip().split("\t")[2] 79 | language_fullname_dict.update({abbreviation : chinese_full_name}) 80 | assert language_fullname_dict[src_abbreviation]!='NONE', f'Source language abbreviation can not convert to Chinese full name, Please check the source language abbreviation in languages_abbreviation2fullname.txt' 81 | assert language_fullname_dict[tgt_abbreviation]!='NONE', f'Target language abbreviation can not convert to Chinese full name, Please check the target language abbreviation in languages_abbreviation2fullname.txt' 82 | 83 | return language_fullname_dict[src_abbreviation], language_fullname_dict[tgt_abbreviation] 84 | 85 | def cut2list(line): 86 | line_cut = jieba.cut(line, cut_all=True) 87 | line_list = [c for c in line_cut] 88 | out_list = [] 89 | for c in line_list: 90 | if len(c) == 0: 91 | continue 92 | if c == ' ': 93 | continue 94 | out_list.append(c) 95 | return out_list 96 | 97 | def single_prompt(model, tokenizer, prompt="Hello, I'm am conscious and", max_new_tokens:int=128, do_sample:bool=True, num_beams:int=1, top_k:int=50, top_p:float=0.95, no_repeat_ngram_size=6, temperature:float=0.7, cuda=True, verbose=False): 98 | input_ids = tokenizer(prompt, return_tensors="pt").input_ids 99 | if cuda: 100 | # model = model.cuda() 101 | 102 | model = tp.tensor_parallel(model) 103 | input_ids = input_ids.cuda() 104 | 105 | with torch.inference_mode(): 106 | generated_ids = model.generate(input_ids, max_new_tokens=max_new_tokens, do_sample=do_sample, num_beams=num_beams, top_k=top_k, top_p=top_p, temperature=temperature, no_repeat_ngram_size=no_repeat_ngram_size) 107 | 108 | results = tokenizer.batch_decode(generated_ids, skip_special_tokens=True, spaces_between_special_tokens=False) 109 | 110 | if verbose: 111 | print(results) 112 | return results 113 | 114 | def batch_prompt(model, tokenizer, prompts:list=["Hello, I'm am conscious and"], max_new_tokens:int=128, do_sample:bool=True, num_beams:int=1, top_k:int=50, top_p:float=0.95, temperature:float=0.7, no_repeat_ngram_size=6, cuda=True, verbose=False): 115 | tokenizer.padding_side="left" 116 | tokenizer.pad_token_id = tokenizer.eos_token_id 117 | model.config.pad_token_id = model.config.eos_token_id 118 | tokens = tokenizer(prompts, return_tensors="pt", padding=True) 119 | input_ids = tokens["input_ids"] 120 | attention_mask = tokens["attention_mask"] 121 | if cuda: 122 | # model = model.cuda() 123 | 124 | model = tp.tensor_parallel(model) 125 | input_ids = input_ids.cuda() 126 | attention_mask = attention_mask.cuda() 127 | with torch.inference_mode(): 128 | generated_ids = model.generate(input_ids,attention_mask=attention_mask, max_new_tokens=max_new_tokens, do_sample=do_sample, num_beams=num_beams, top_k=top_k, top_p=top_p, temperature=temperature, no_repeat_ngram_size=no_repeat_ngram_size) 129 | results = tokenizer.batch_decode(generated_ids, 
skip_special_tokens=True, spaces_between_special_tokens=False) 130 | if verbose: 131 | print(results) 132 | return results 133 | 134 | def back_process(str_out:str): 135 | if "\n" in str_out: 136 | str_out = str_out[:str_out.find("\n")] 137 | 138 | return str_out 139 | 140 | def single_translate(model, tokenizer, prompt="Hello, I'm am conscious and", src_lang="zh", tgt_lang="en", with_instruct:bool=True, cuda=True, verbose=False, generation_args:dict=None): 141 | # src_lang = PROMPT_DICT[src_lang] 142 | # tgt_lang = PROMPT_DICT[tgt_lang] 143 | 144 | prompt_in = None 145 | if with_instruct: 146 | prompt_in = PROMPT_DICT["prompt_input"].format(PROMPT_DICT["translate_instruct"].format(src_lang, tgt_lang, prompt)) 147 | else: 148 | prompt_in = PROMPT_DICT["translate_prompt"].format(src_lang, tgt_lang, prompt) 149 | 150 | src_len = len(prompt_in) 151 | 152 | if verbose: 153 | print(f"Translation Prompt: {prompt_in}") 154 | 155 | out_str = single_prompt( 156 | model=model, 157 | tokenizer=tokenizer, 158 | prompt=prompt_in, 159 | cuda=cuda, 160 | verbose=verbose, 161 | **generation_args, 162 | ) 163 | 164 | translate_res = out_str[0][src_len:] 165 | 166 | if verbose: 167 | print(f"Translation Result: {translate_res}") 168 | 169 | return translate_res 170 | 171 | 172 | def fp32to16(model_path,init_model): 173 | # convert fp32 model to fp16 model (in-place) 174 | model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, state_dict= torch.load(init_model) if init_model is not None else None) 175 | torch.save(model.state_dict(), init_model) 176 | 177 | # for direct load, in case the state dict needed 178 | def fp32to16_dir(model_path,init_path,tgt_dir): 179 | # convert fp32 model to fp16 model (in-place) 180 | model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, state_dict= torch.load(init_path) if init_path is not None else None) 181 | model.save_pretrained(tgt_dir) 182 | 183 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) 184 | tokenizer.save_pretrained(tgt_dir) 185 | 186 | def to_matrix(l, n): 187 | return [l[i:i+n] for i in range(0, len(l), n)] 188 | 189 | if __name__ == '__main__': 190 | parser = argparse.ArgumentParser() 191 | 192 | parser.add_argument( 193 | "--model", 194 | type=str, 195 | default=None, 196 | help="The name of model to use.", 197 | ) 198 | 199 | parser.add_argument( 200 | "--tokenizer-path", 201 | type=str, 202 | default=None, 203 | help="The path of tokenizer to use.", 204 | ) 205 | 206 | parser.add_argument( 207 | "--parameter-type", 208 | type=str, 209 | # choices=["fp16","bf16","fp32"], 210 | default="bf16", 211 | help="The type of model parameters to load.", 212 | ) 213 | 214 | parser.add_argument( 215 | "--beam-search", 216 | action='store_true', 217 | help="Whether to run beam search." 218 | ) 219 | 220 | parser.add_argument( 221 | "--with-instruct", 222 | action='store_true', 223 | help="Whether to run beam search." 
224 | ) 225 | 226 | parser.add_argument( 227 | "--num-beams", 228 | type=int, 229 | default=5 230 | ) 231 | 232 | parser.add_argument( 233 | "--no-repeat-ngram-size", 234 | type=int, 235 | default=0 236 | ) 237 | 238 | parser.add_argument( 239 | "--checkpoint", 240 | type=str, 241 | default=None 242 | ) 243 | 244 | parser.add_argument( 245 | "--prompt-file", 246 | type=str, 247 | default=None 248 | ) 249 | 250 | parser.add_argument( 251 | "--out-file", 252 | type=str, 253 | default=None 254 | ) 255 | 256 | parser.add_argument( 257 | "--seed", 258 | type=int, 259 | default=0 260 | ) 261 | 262 | parser.add_argument( 263 | "--translate", 264 | action='store_true', 265 | help="Whether to run translate." 266 | ) 267 | 268 | parser.add_argument( 269 | "--verbose", 270 | action='store_true', 271 | help="Whether to print the details in translation." 272 | ) 273 | 274 | parser.add_argument( 275 | "--source-language", 276 | type=str, 277 | default="zh", 278 | help="The source language of translation." 279 | ) 280 | 281 | parser.add_argument( 282 | "--target-language", 283 | type=str, 284 | default="en", 285 | help="The target language of translation." 286 | ) 287 | 288 | parser.add_argument( 289 | "--translate-json-skip-keys", 290 | nargs='+', 291 | default=["answer"], 292 | help="The key list to skip translation." 293 | ) 294 | 295 | parser.add_argument( 296 | "--batch-inference", 297 | action='store_true', 298 | help="Whether to run inference in batch." 299 | ) 300 | 301 | parser.add_argument( 302 | "--batch-size", 303 | type=int, 304 | default=3, 305 | help="Batch size of run inference in batch." 306 | ) 307 | 308 | parser.add_argument( 309 | "--split-translate", 310 | action='store_true', 311 | help="Whether to run translate by split on dot." 312 | ) 313 | 314 | 315 | parser.add_argument( 316 | "--gold-file", 317 | type=str, 318 | default=None 319 | ) 320 | 321 | parser.add_argument( 322 | "--times", 323 | type=int, 324 | default=3, 325 | help="Number of generation for each prompt.", 326 | ) 327 | 328 | parser.add_argument( 329 | "--max-tokens", 330 | type=int, 331 | default=256 332 | ) 333 | 334 | parser.add_argument( 335 | "--top-k", 336 | type=int, 337 | default=80, 338 | help="The configuration top k tokens in the generation of model.", 339 | ) 340 | 341 | parser.add_argument( 342 | "--top-p", 343 | type=float, 344 | default=0.95, 345 | help="The configuration top p tokens in the generation of model.", 346 | ) 347 | 348 | parser.add_argument( 349 | "--temperature", 350 | type=float, 351 | default=0.7, 352 | help="The configuration temperature in the generation of model.", 353 | ) 354 | 355 | args = parser.parse_args() 356 | set_seed(args.seed) 357 | print(args.out_file) 358 | 359 | config = llama.LLaMAConfig.from_pretrained(args.model) 360 | tokenizer = llama.LLaMATokenizer.from_pretrained(args.tokenizer_path) 361 | model = llama.LLaMAForCausalLM.from_pretrained( 362 | args.model, 363 | torch_dtype=TYPE_DICT[args.parameter_type], 364 | config=config, 365 | state_dict=torch.load(args.checkpoint) if args.checkpoint is not None else None 366 | ) 367 | 368 | generation_config = { 369 | "do_sample": not args.beam_search, 370 | "num_beams": args.num_beams if args.beam_search else 1, 371 | "max_new_tokens": args.max_tokens, 372 | "no_repeat_ngram_size": args.no_repeat_ngram_size, 373 | "top_k": args.top_k, 374 | "top_p": args.top_p, 375 | "temperature": args.temperature 376 | } 377 | 378 | # inference for translate 379 | prompt_file_type = os.path.basename(args.prompt_file).split(".")[-1] 380 | 381 
| assert prompt_file_type in FILE_TYPE2LOAD, f"Prompt file type({prompt_file_type}) is not in {FILE_TYPE2LOAD.keys()}" 382 | 383 | prompts = FILE_TYPE2LOAD[prompt_file_type](args.prompt_file) 384 | 385 | # out_handle = open(args.out_file,"w",encoding="utf-8") 386 | for attr, value in sorted(args.__dict__.items()): 387 | print(f"\t{attr}={value}") 388 | # out_handle.write(f"\t{attr}={value}") 389 | # out_handle.write("\n") 390 | 391 | assert args.source_language != args.target_language, f"Target language({args.target_language}) must be different with the source language({args.source_language})!" 392 | 393 | source_full_name, target_full_name = abbreviation2fullname(args.source_language, args.target_language) 394 | 395 | if prompt_file_type == "json": 396 | for sample in tqdm(prompts): 397 | tgt_res = {} 398 | for k in sample.keys(): 399 | if k in args.translate_json_skip_keys or len(sample[k]) == 0: 400 | tgt_res[k] = sample[k] 401 | continue 402 | 403 | tgt_res[k] = single_translate( 404 | model=model, 405 | tokenizer=tokenizer, 406 | prompt=sample[k], 407 | src_lang=source_full_name, 408 | tgt_lang=target_full_name, 409 | with_instruct=args.with_instruct, 410 | cuda=True, 411 | verbose=args.verbose, 412 | generation_args=generation_config 413 | ) 414 | with open(args.out_file, "a", encoding="utf-8") as f: 415 | print(json.dumps(tgt_res, ensure_ascii=False)) 416 | f.write(json.dumps(tgt_res, ensure_ascii=False)) 417 | f.write("\n") 418 | 419 | else: 420 | for prompt in tqdm(prompts): 421 | tgt_out = single_translate( 422 | model=model, 423 | tokenizer=tokenizer, 424 | prompt=prompt, 425 | src_lang=source_full_name, 426 | tgt_lang=target_full_name, 427 | with_instruct=args.with_instruct, 428 | cuda=True, 429 | verbose=args.verbose, 430 | generation_args=generation_config 431 | ) 432 | 433 | with open(args.out_file, "a", encoding="utf-8") as f: 434 | print(tgt_out) 435 | f.write(tgt_out) 436 | f.write("\n") 437 | -------------------------------------------------------------------------------- /pics/104langs_bleu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZNLP/BigTranslate/c54fa543a8edb73af08ff1b5fbe81c3d9c910041/pics/104langs_bleu.png -------------------------------------------------------------------------------- /pics/70langs_gpt4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZNLP/BigTranslate/c54fa543a8edb73af08ff1b5fbe81c3d9c910041/pics/70langs_gpt4.png -------------------------------------------------------------------------------- /pics/The_outline_of_Increment_pre-training.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZNLP/BigTranslate/c54fa543a8edb73af08ff1b5fbe81c3d9c910041/pics/The_outline_of_Increment_pre-training.png -------------------------------------------------------------------------------- /pics/corpus_distribution.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZNLP/BigTranslate/c54fa543a8edb73af08ff1b5fbe81c3d9c910041/pics/corpus_distribution.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.12.1 2 | transformers==4.21.0 3 | sentencepiece==0.1.97 4 | accelerate==0.16.0 5 | nltk==3.8.1 6 | jieba==0.42.1 7 | tensor_parallel==1.2.8 
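The source/target abbreviations set in translate.sh below are resolved by `abbreviation2fullname()` in model/translate.py, which reads `languages_abbreviation2fullname.txt` as tab-separated rows of abbreviation, English full name, and Chinese full name (a value of `NONE` means the name is unavailable). A minimal sketch of that lookup, using two hypothetical rows rather than the shipped file:

```python
# Rows are illustrative; languages_abbreviation2fullname.txt is the authoritative list.
sample_rows = ["en\tEnglish\t英语", "zh\tChinese\t汉语"]
abbr2chinese = {}
for line in sample_rows:
    abbreviation, english_name, chinese_name = line.strip().split("\t")
    abbr2chinese[abbreviation] = chinese_name  # the translation prompt uses the Chinese name
assert abbr2chinese["en"] != "NONE" and abbr2chinese["zh"] != "NONE"
print(abbr2chinese["en"], abbr2chinese["zh"])  # -> 英语 汉语
```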
-------------------------------------------------------------------------------- /translate.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | export SEED=0 4 | 5 | #You can find all supported language abbreviations in languages_abbreviation2fullname.txt 6 | export SRC_LANG= #SOURCE_LANGUAGE_ABBREVIATION (e.g., "en") 7 | export TGT_LANG= #TARGET_LANGUAGE_ABBREVIATION (e.g., "zh") 8 | 9 | export PROMPT_FILE= #PROMPT_FILE_PATH (supported file type: txt or json, e.g., "./example/en.txt") 10 | 11 | export SAVE_PATH= #SAVE_FILE_PATH 12 | 13 | 14 | LOG_FILE="translate_bigtrans.example.log" 15 | 16 | 17 | export INSTRUCT="True" 18 | 19 | export VERBOSE="True" #Whether to print the details in translation. 20 | 21 | export CHECKPOINT_PATH= #CHECKPOINT_PATH (e.g., /PATH2BigTrans or decapoda-research/llama-7b-hf) 22 | export TOKENIZER_PATH= #TOKENIZER_PATH (e.g., /PATH2BigTrans or decapoda-research/llama-7b-hf) 23 | 24 | 25 | # export MODEL_TYPE="bf16" 26 | export MODEL_TYPE="fp16" #The type of model parameters to load (e.g., ["fp16", "bf16", "fp32"]) 27 | 28 | export NUM_BEAMS=5 29 | 30 | export MAX_TOKENS=1024 31 | export NO_REPEAT_NGRAM_SIZE=6 32 | export LOW_TEMPERATURE=0.01 33 | 34 | export ADD_PARAMETERS="" 35 | 36 | if [ "${INSTRUCT}" != "False" ]; 37 | then 38 | ADD_PARAMETERS="--with-instruct " 39 | fi 40 | 41 | if [ "${VERBOSE}" != "False" ]; 42 | then 43 | ADD_PARAMETERS="${ADD_PARAMETERS} --verbose " 44 | fi 45 | 46 | 47 | # beam search is deterministic 48 | export OUT_TIME=1 49 | python -u model/translate.py \ 50 | --model ${CHECKPOINT_PATH} \ 51 | --tokenizer-path ${TOKENIZER_PATH} \ 52 | --prompt-file ${PROMPT_FILE} \ 53 | ${ADD_PARAMETERS} \ 54 | --out-file ${SAVE_PATH} \ 55 | --source-language ${SRC_LANG} \ 56 | --target-language ${TGT_LANG} \ 57 | --seed ${SEED} \ 58 | --beam-search \ 59 | --parameter-type ${MODEL_TYPE} \ 60 | --num-beams ${NUM_BEAMS} \ 61 | --times ${OUT_TIME} \ 62 | --max-tokens ${MAX_TOKENS} \ 63 | --no-repeat-ngram-size ${NO_REPEAT_NGRAM_SIZE} \ 64 | --temperature ${LOW_TEMPERATURE} >>${LOG_FILE} 2>&1 65 | --------------------------------------------------------------------------------
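For reference, a minimal single-GPU sketch of what translate.sh drives end to end (model/translate.py itself wraps the model with `tensor_parallel` for multi-GPU inference). The checkpoint path is a placeholder, the prompt template is the one defined in `PROMPT_DICT`, the beam-search settings follow the translate.sh defaults above, and the snippet is assumed to run from the model/ directory so that `import llama` resolves:

```python
import torch
import llama  # the bundled model/llama package

MODEL_PATH = "/PATH2BigTrans"  # placeholder, as in translate.sh

tokenizer = llama.LLaMATokenizer.from_pretrained(MODEL_PATH)
model = llama.LLaMAForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=torch.float16).cuda()

# Instruction-style prompt: PROMPT_DICT["prompt_input"] filled with PROMPT_DICT["translate_instruct"].
instruction = "请将以下英语句子翻译成汉语:Machine translation connects people who speak different languages."
prompt = (
    "以下是一个描述任务的指令,请写一个完成该指令的适当回复。\n\n"
    "### 指令:\n{0}\n\n### 回复:"
).format(instruction)

input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()
with torch.inference_mode():
    generated = model.generate(
        input_ids,
        do_sample=False,
        num_beams=5,              # NUM_BEAMS in translate.sh
        max_new_tokens=256,       # translate.py default; translate.sh raises this to 1024
        no_repeat_ngram_size=6,   # NO_REPEAT_NGRAM_SIZE in translate.sh
    )
text = tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
print(text[len(prompt):])  # drop the echoed prompt, as single_translate() does
```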