├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── MODEL_CARD.md ├── README.md ├── Responsible-Use-Guide.pdf ├── USE_POLICY.md ├── download.sh ├── example_chat_completion.py ├── example_text_completion.py ├── llama ├── __init__.py ├── generation.py ├── model.py └── tokenizer.py ├── requirements.txt └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 
39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at <opensource-conduct@meta.com>. All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership.
71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to Llama 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Meta's open source projects. 18 | 19 | Complete your CLA here: <https://code.facebook.com/cla> 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to Llama, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree.
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | LLAMA 2 COMMUNITY LICENSE AGREEMENT 2 | Llama 2 Version Release Date: July 18, 2023 3 | 4 | "Agreement" means the terms and conditions for use, reproduction, distribution and 5 | modification of the Llama Materials set forth herein. 6 | 7 | "Documentation" means the specifications, manuals and documentation 8 | accompanying Llama 2 distributed by Meta at ai.meta.com/resources/models-and- 9 | libraries/llama-downloads/. 10 | 11 | "Licensee" or "you" means you, or your employer or any other person or entity (if 12 | you are entering into this Agreement on such person or entity's behalf), of the age 13 | required under applicable laws, rules or regulations to provide legal consent and that 14 | has legal authority to bind your employer or such other person or entity if you are 15 | entering in this Agreement on their behalf. 16 | 17 | "Llama 2" means the foundational large language models and software and 18 | algorithms, including machine-learning model code, trained model weights, 19 | inference-enabling code, training-enabling code, fine-tuning enabling code and other 20 | elements of the foregoing distributed by Meta at ai.meta.com/resources/models-and- 21 | libraries/llama-downloads/. 22 | 23 | "Llama Materials" means, collectively, Meta's proprietary Llama 2 and 24 | Documentation (and any portion thereof) made available under this Agreement. 25 | 26 | "Meta" or "we" means Meta Platforms Ireland Limited (if you are located in or, if you 27 | are an entity, your principal place of business is in the EEA or Switzerland) and Meta 28 | Platforms, Inc. (if you are located outside of the EEA or Switzerland). 29 | 30 | By clicking "I Accept" below or by using or distributing any portion or element of the 31 | Llama Materials, you agree to be bound by this Agreement. 32 | 33 | 1. 
License Rights and Redistribution. 34 | 35 | a. Grant of Rights. You are granted a non-exclusive, worldwide, non- 36 | transferable and royalty-free limited license under Meta's intellectual property or 37 | other rights owned by Meta embodied in the Llama Materials to use, reproduce, 38 | distribute, copy, create derivative works of, and make modifications to the Llama 39 | Materials. 40 | 41 | b. Redistribution and Use. 42 | 43 | i. If you distribute or make the Llama Materials, or any derivative works 44 | thereof, available to a third party, you shall provide a copy of this Agreement to such 45 | third party. 46 | ii. If you receive Llama Materials, or any derivative works thereof, from 47 | a Licensee as part of an integrated end user product, then Section 2 of this 48 | Agreement will not apply to you. 49 | 50 | iii. You must retain in all copies of the Llama Materials that you 51 | distribute the following attribution notice within a "Notice" text file distributed as a 52 | part of such copies: "Llama 2 is licensed under the LLAMA 2 Community License, 53 | Copyright (c) Meta Platforms, Inc. All Rights Reserved." 54 | 55 | iv. Your use of the Llama Materials must comply with applicable laws 56 | and regulations (including trade compliance laws and regulations) and adhere to the 57 | Acceptable Use Policy for the Llama Materials (available at 58 | https://ai.meta.com/llama/use-policy), which is hereby incorporated by reference into 59 | this Agreement. 60 | 61 | v. You will not use the Llama Materials or any output or results of the 62 | Llama Materials to improve any other large language model (excluding Llama 2 or 63 | derivative works thereof). 64 | 65 | 2. Additional Commercial Terms. 
If, on the Llama 2 version release date, the 66 | monthly active users of the products or services made available by or for Licensee, 67 | or Licensee's affiliates, is greater than 700 million monthly active users in the 68 | preceding calendar month, you must request a license from Meta, which Meta may 69 | grant to you in its sole discretion, and you are not authorized to exercise any of the 70 | rights under this Agreement unless or until Meta otherwise expressly grants you 71 | such rights. 72 | 73 | 3. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE 74 | LLAMA MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE 75 | PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, 76 | EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY 77 | WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR 78 | FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE 79 | FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING 80 | THE LLAMA MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR 81 | USE OF THE LLAMA MATERIALS AND ANY OUTPUT AND RESULTS. 82 | 83 | 4. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE 84 | LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, 85 | NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS 86 | AGREEMENT, FOR ANY LOST PROFITS OR ANY INDIRECT, SPECIAL, 87 | CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN 88 | IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF 89 | ANY OF THE FOREGOING. 90 | 91 | 5. Intellectual Property. 92 | 93 | a. No trademark licenses are granted under this Agreement, and in 94 | connection with the Llama Materials, neither Meta nor Licensee may use any name 95 | or mark owned by or associated with the other or any of its affiliates, except as 96 | required for reasonable and customary use in describing and redistributing the 97 | Llama Materials. 98 | 99 | b. 
Subject to Meta's ownership of Llama Materials and derivatives made by or 100 | for Meta, with respect to any derivative works and modifications of the Llama 101 | Materials that are made by you, as between you and Meta, you are and will be the 102 | owner of such derivative works and modifications. 103 | 104 | c. If you institute litigation or other proceedings against Meta or any entity 105 | (including a cross-claim or counterclaim in a lawsuit) alleging that the Llama 106 | Materials or Llama 2 outputs or results, or any portion of any of the foregoing, 107 | constitutes infringement of intellectual property or other rights owned or licensable 108 | by you, then any licenses granted to you under this Agreement shall terminate as of 109 | the date such litigation or claim is filed or instituted. You will indemnify and hold 110 | harmless Meta from and against any claim by any third party arising out of or related 111 | to your use or distribution of the Llama Materials. 112 | 113 | 6. Term and Termination. The term of this Agreement will commence upon your 114 | acceptance of this Agreement or access to the Llama Materials and will continue in 115 | full force and effect until terminated in accordance with the terms and conditions 116 | herein. Meta may terminate this Agreement if you are in breach of any term or 117 | condition of this Agreement. Upon termination of this Agreement, you shall delete 118 | and cease use of the Llama Materials. Sections 3, 4 and 7 shall survive the 119 | termination of this Agreement. 120 | 121 | 7. Governing Law and Jurisdiction. This Agreement will be governed and 122 | construed under the laws of the State of California without regard to choice of law 123 | principles, and the UN Convention on Contracts for the International Sale of Goods 124 | does not apply to this Agreement. The courts of California shall have exclusive 125 | jurisdiction of any dispute arising out of this Agreement. 
126 | 127 | -------------------------------------------------------------------------------- /MODEL_CARD.md: -------------------------------------------------------------------------------- 1 | # **Model Details** 2 | 3 | Meta developed and released the Llama 2 family of large language models (LLMs), a collection of pretrained and fine-tuned generative text models ranging in scale from 7 billion to 70 billion parameters. Our fine-tuned LLMs, called Llama-2-Chat, are optimized for dialogue use cases. Llama-2-Chat models outperform open-source chat models on most benchmarks we tested, and in our human evaluations for helpfulness and safety, are on par with some popular closed-source models like ChatGPT and PaLM. 4 | 5 | **Model Developers** Meta 6 | 7 | **Variations** Llama 2 comes in a range of parameter sizes — 7B, 13B, and 70B — as well as pretrained and fine-tuned variations. 8 | 9 | **Input** Models input text only. 10 | 11 | **Output** Models generate text only. 12 | 13 | **Model Architecture** Llama 2 is an auto-regressive language model that uses an optimized transformer architecture. The tuned versions use supervised fine-tuning (SFT) and reinforcement learning with human feedback (RLHF) to align to human preferences for helpfulness and safety. 14 | 15 | ||Training Data|Params|Context Length|GQA|Tokens|LR| 16 | |---|---|---|---|---|---|---| 17 | Llama 2|*A new mix of publicly available online data*|7B|4k|✗|2.0T|3.0 x 10⁻⁴ 18 | Llama 2|*A new mix of publicly available online data*|13B|4k|✗|2.0T|3.0 x 10⁻⁴ 19 | Llama 2|*A new mix of publicly available online data*|70B|4k|✔|2.0T|1.5 x 10⁻⁴ 20 | 21 | **Llama 2 family of models.** Token counts refer to pretraining data only. All models are trained with a global batch-size of 4M tokens. The 70B version uses Grouped-Query Attention (GQA) for improved inference scalability. 22 | 23 | **Model Dates** Llama 2 was trained between January 2023 and July 2023.
24 | 25 | **Status** This is a static model trained on an offline dataset. Future versions of the tuned models will be released as we improve model safety with community feedback. 26 | 27 | **License** A custom commercial license is available at: [https://ai.meta.com/resources/models-and-libraries/llama-downloads/](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) 28 | 29 | **Research Paper** More information can be found in the paper "Llama-2: Open Foundation and Fine-tuned Chat Models", available at https://ai.meta.com/research/publications/llama-2-open-foundation-and-fine-tuned-chat-models/. 30 | 31 | **Where to send questions or comments about the model** Instructions on how to provide feedback or comments on the model can be found in the model [README](README.md). 32 | 33 | # **Intended Use** 34 | **Intended Use Cases** Llama 2 is intended for commercial and research use in English. Tuned models are intended for assistant-like chat, whereas pretrained models can be adapted for a variety of natural language generation tasks. 35 | 36 | **Out-of-scope Uses** Use in any manner that violates applicable laws or regulations (including trade compliance laws). Use in languages other than English. Use in any other way that is prohibited by the Acceptable Use Policy and Licensing Agreement for Llama 2. 37 | 38 | # **Hardware and Software** 39 | **Training Factors** We used custom training libraries, Meta's Research Super Cluster, and production clusters for pretraining. Fine-tuning, annotation, and evaluation were also performed on third-party cloud compute. 40 | 41 | **Carbon Footprint** Pretraining utilized a cumulative 3.3M GPU hours of computation on hardware of type A100-80GB (TDP of 350-400W). Estimated total emissions were 539 tCO2eq, 100% of which were offset by Meta’s sustainability program. 
42 | 43 | ||Time (GPU hours)|Power Consumption (W)|Carbon Emitted (tCO₂eq)| 44 | |---|---|---|---| 45 | |Llama 2 7B|184320|400|31.22| 46 | |Llama 2 13B|368640|400|62.44| 47 | |Llama 2 70B|1720320|400|291.42| 48 | |Total|3311616||539.00| 49 | 50 | **CO2 emissions during pretraining.** Time: total GPU time required for training each model. Power Consumption: peak power capacity per GPU device for the GPUs used adjusted for power usage efficiency. The Total row also includes a Llama 2 34B model that was trained but is not part of this release, which is why it exceeds the sum of the rows shown. 100% of the emissions are directly offset by Meta's sustainability program, and because we are openly releasing these models, the pretraining costs do not need to be incurred by others. 51 | 52 | # **Training Data** 53 | **Overview** Llama 2 was pretrained on 2 trillion tokens of data from publicly available sources. The fine-tuning data includes publicly available instruction datasets, as well as over one million new human-annotated examples. Neither the pretraining nor the fine-tuning datasets include Meta user data. 54 | 55 | **Data Freshness** The pretraining data has a cutoff of September 2022, but some tuning data is more recent, up to July 2023. 56 | 57 | # **Evaluation Results** 58 | 59 | In this section, we report the results for the Llama 1 and Llama 2 models on standard academic benchmarks. 60 | For all the evaluations, we use our internal evaluations library.
61 | 62 | |Model|Size|Code|Commonsense Reasoning|World Knowledge|Reading Comprehension|Math|MMLU|BBH|AGI Eval| 63 | |---|---|---|---|---|---|---|---|---|---| 64 | |Llama 1|7B|14.1|60.8|46.2|58.5|6.95|35.1|30.3|23.9| 65 | |Llama 1|13B|18.9|66.1|52.6|62.3|10.9|46.9|37.0|33.9| 66 | |Llama 1|33B|26.0|70.0|58.4|67.6|21.4|57.8|39.8|41.7| 67 | |Llama 1|65B|30.7|70.7|60.5|68.6|30.8|63.4|43.5|47.6| 68 | |Llama 2|7B|16.8|63.9|48.9|61.3|14.6|45.3|32.6|29.3| 69 | |Llama 2|13B|24.5|66.9|55.4|65.8|28.7|54.8|39.4|39.1| 70 | |Llama 2|70B|**37.5**|**71.9**|**63.6**|**69.4**|**35.2**|**68.9**|**51.2**|**54.2**| 71 | 72 | **Overall performance on grouped academic benchmarks.** *Code:* We report the average pass@1 scores of our models on HumanEval and MBPP. *Commonsense Reasoning:* We report the average of PIQA, SIQA, HellaSwag, WinoGrande, ARC easy and challenge, OpenBookQA, and CommonsenseQA. We report 7-shot results for CommonSenseQA and 0-shot results for all other benchmarks. *World Knowledge:* We evaluate the 5-shot performance on NaturalQuestions and TriviaQA and report the average. *Reading Comprehension:* For reading comprehension, we report the 0-shot average on SQuAD, QuAC, and BoolQ. *MATH:* We report the average of the GSM8K (8 shot) and MATH (4 shot) benchmarks at top 1. 73 | 74 | |||TruthfulQA|Toxigen| 75 | |---|---|---|---| 76 | |Llama 1|7B|27.42|23.00| 77 | |Llama 1|13B|41.74|23.08| 78 | |Llama 1|33B|44.19|22.57| 79 | |Llama 1|65B|48.71|21.77| 80 | |Llama 2|7B|33.29|**21.25**| 81 | |Llama 2|13B|41.86|26.10| 82 | |Llama 2|70B|**50.18**|24.60| 83 | 84 | **Evaluation of pretrained LLMs on automatic safety benchmarks.** For TruthfulQA, we present the percentage of generations that are both truthful and informative (the higher the better). For ToxiGen, we present the percentage of toxic generations (the smaller the better). 
85 | 86 | 87 | |||TruthfulQA|Toxigen| 88 | |---|---|---|---| 89 | |Llama-2-Chat|7B|57.04|**0.00**| 90 | |Llama-2-Chat|13B|62.18|**0.00**| 91 | |Llama-2-Chat|70B|**64.14**|0.01| 92 | 93 | **Evaluation of fine-tuned LLMs on different safety datasets.** Same metric definitions as above. 94 | 95 | # **Ethical Considerations and Limitations** 96 | Llama 2 is a new technology that carries risks with use. Testing conducted to date has been in English, and has not covered, nor could it cover all scenarios. For these reasons, as with all LLMs, Llama 2’s potential outputs cannot be predicted in advance, and the model may in some instances produce inaccurate, biased or other objectionable responses to user prompts. Therefore, before deploying any applications of Llama 2, developers should perform safety testing and tuning tailored to their specific applications of the model. 97 | 98 | Please see the Responsible Use Guide available at [https://ai.meta.com/llama/responsible-use-guide/](https://ai.meta.com/llama/responsible-use-guide/) 99 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Llama 2 2 | 3 | We are unlocking the power of large language models. Our latest version of Llama is now accessible to individuals, creators, researchers and businesses of all sizes so that they can experiment, innovate and scale their ideas responsibly. 4 | 5 | This release includes model weights and starting code for pretrained and fine-tuned Llama language models — ranging from 7B to 70B parameters. 6 | 7 | This repository is intended as a minimal example to load [Llama 2](https://ai.meta.com/research/publications/llama-2-open-foundation-and-fine-tuned-chat-models/) models and run inference. For more detailed examples leveraging HuggingFace, see [llama-recipes](https://github.com/facebookresearch/llama-recipes/). 
8 | 9 | ## Download 10 | 11 | ⚠️ **7/18: We're aware of people encountering a number of download issues today. Anyone still encountering issues should remove all local files, re-clone the repository, and [request a new download link](https://ai.meta.com/resources/models-and-libraries/llama-downloads/). It's critical to do all of these in case you have corrupted local files. When you receive the email, copy *only* the link text - it should begin with https://download.llamameta.net and not with https://l.facebook.com, which will give errors.** 12 | 13 | 14 | 15 | In order to download the model weights and tokenizer, please visit the [Meta AI website](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) and accept our License. 16 | 17 | Once your request is approved, you will receive a signed URL over email. Then run the download.sh script, passing the URL provided when prompted to start the download. Make sure that you copy the URL text itself, **do not use the 'Copy link address' option** when you right-click the URL. If the copied URL text starts with: https://download.llamameta.net, you copied it correctly. If the copied URL text starts with: https://l.facebook.com, you copied it the wrong way. 18 | 19 | Prerequisites: make sure you have `wget` and `md5sum` installed. Then run the script: `./download.sh`. 20 | 21 | Keep in mind that the links expire after 24 hours and a certain number of downloads. If you start seeing errors such as `403: Forbidden`, you can always re-request a link. 22 | 23 | ### Access on Hugging Face 24 | 25 | We are also providing downloads on [Hugging Face](https://huggingface.co/meta-llama). You must first request a download from the Meta AI website using the same email address as your Hugging Face account. After doing so, you can request access to any of the models on Hugging Face and within 1-2 days your account will be granted access to all versions.
26 | 27 | ## Setup 28 | 29 | In a conda env with PyTorch / CUDA available, clone the repo and run in the top-level directory: 30 | 31 | ``` 32 | pip install -e . 33 | ``` 34 | 35 | ## Inference 36 | 37 | Different models require different model-parallel (MP) values: 38 | 39 | | Model | MP | 40 | |--------|----| 41 | | 7B | 1 | 42 | | 13B | 2 | 43 | | 70B | 8 | 44 | 45 | All models support sequence length up to 4096 tokens, but we pre-allocate the cache according to `max_seq_len` and `max_batch_size` values, so set those according to your hardware. 46 | 47 | ### Pretrained Models 48 | 49 | These models are not finetuned for chat or Q&A. They should be prompted so that the expected answer is the natural continuation of the prompt. 50 | 51 | See `example_text_completion.py` for some examples. To illustrate, see the command below to run it with the llama-2-7b model (`nproc_per_node` needs to be set to the `MP` value): 52 | 53 | ``` 54 | torchrun --nproc_per_node 1 example_text_completion.py \ 55 | --ckpt_dir llama-2-7b/ \ 56 | --tokenizer_path tokenizer.model \ 57 | --max_seq_len 128 --max_batch_size 4 58 | ``` 59 | 60 | ### Fine-tuned Chat Models 61 | 62 | The fine-tuned models were trained for dialogue applications. To get the expected features and performance for them, a specific formatting defined in [`chat_completion`](https://github.com/facebookresearch/llama/blob/main/llama/generation.py#L212) 63 | needs to be followed, including the `INST` and `<<SYS>>` tags, `BOS` and `EOS` tokens, and the whitespace and line breaks in between (we recommend calling `strip()` on inputs to avoid double-spaces). 64 | 65 | You can also deploy additional classifiers for filtering out inputs and outputs that are deemed unsafe. See the llama-recipes repo for [an example](https://github.com/facebookresearch/llama-recipes/blob/main/inference/inference.py) of how to add a safety checker to the inputs and outputs of your inference code.
66 | 67 | Examples using llama-2-7b-chat: 68 | 69 | ``` 70 | torchrun --nproc_per_node 1 example_chat_completion.py \ 71 | --ckpt_dir llama-2-7b-chat/ \ 72 | --tokenizer_path tokenizer.model \ 73 | --max_seq_len 512 --max_batch_size 4 74 | ``` 75 | 76 | Llama 2 is a new technology that carries potential risks with use. Testing conducted to date has not covered — and could not cover — all scenarios. 77 | In order to help developers address these risks, we have created the [Responsible Use Guide](Responsible-Use-Guide.pdf). More details can be found in our research paper as well. 78 | 79 | ## Issues 80 | 81 | Please report any software “bug” or other problems with the models through one of the following means: 82 | - Reporting issues with the model: [github.com/facebookresearch/llama](http://github.com/facebookresearch/llama) 83 | - Reporting risky content generated by the model: [developers.facebook.com/llama_output_feedback](http://developers.facebook.com/llama_output_feedback) 84 | - Reporting bugs and security concerns: [facebook.com/whitehat/info](http://facebook.com/whitehat/info) 85 | 86 | ## Model Card 87 | See [MODEL_CARD.md](MODEL_CARD.md). 88 | 89 | ## License 90 | 91 | Our model and weights are licensed for both researchers and commercial entities, upholding the principles of openness. Our mission is to empower individuals and industry through this opportunity, while fostering an environment of discovery and ethical AI advancements. 92 | 93 | See the [LICENSE](LICENSE) file, as well as our accompanying [Acceptable Use Policy](USE_POLICY.md). 94 | 95 | ## References 96 | 97 | 1. [Research Paper](https://ai.meta.com/research/publications/llama-2-open-foundation-and-fine-tuned-chat-models/) 98 | 2. [Llama 2 technical overview](https://ai.meta.com/resources/models-and-libraries/llama) 99 | 3.
[Open Innovation AI Research Community](https://ai.meta.com/llama/open-innovation-ai-research-community/) 100 | 101 | ## Original LLaMA 102 | The repo for the original llama release is in the [`llama_v1`](https://github.com/facebookresearch/llama/tree/llama_v1) branch. 103 | -------------------------------------------------------------------------------- /Responsible-Use-Guide.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch-tpu/llama/6c7fe276574e78057f917549435a2554000a876d/Responsible-Use-Guide.pdf -------------------------------------------------------------------------------- /USE_POLICY.md: -------------------------------------------------------------------------------- 1 | # Llama 2 Acceptable Use Policy 2 | 3 | Meta is committed to promoting safe and fair use of its tools and features, including Llama 2. If you access or use Llama 2, you agree to this Acceptable Use Policy (“Policy”). The most recent copy of this policy can be found at [ai.meta.com/llama/use-policy](http://ai.meta.com/llama/use-policy). 4 | 5 | ## Prohibited Uses 6 | We want everyone to use Llama 2 safely and responsibly. You agree you will not use, or allow others to use, Llama 2 to: 7 | 8 | 1. Violate the law or others’ rights, including to: 9 | 1. Engage in, promote, generate, contribute to, encourage, plan, incite, or further illegal or unlawful activity or content, such as: 10 | 1. Violence or terrorism 11 | 2. Exploitation or harm to children, including the solicitation, creation, acquisition, or dissemination of child exploitative content or failure to report Child Sexual Abuse Material 12 | 3. Human trafficking, exploitation, and sexual violence 13 | 4. The illegal distribution of information or materials to minors, including obscene materials, or failure to employ legally required age-gating in connection with such information or materials. 14 | 5. Sexual solicitation 15 | 6. 
Any other criminal activity 16 | 2. Engage in, promote, incite, or facilitate the harassment, abuse, threatening, or bullying of individuals or groups of individuals 17 | 3. Engage in, promote, incite, or facilitate discrimination or other unlawful or harmful conduct in the provision of employment, employment benefits, credit, housing, other economic benefits, or other essential goods and services 18 | 4. Engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or related professional practices 19 | 5. Collect, process, disclose, generate, or infer health, demographic, or other sensitive personal or private information about individuals without rights and consents required by applicable laws 20 | 6. Engage in or facilitate any action or generate any content that infringes, misappropriates, or otherwise violates any third-party rights, including the outputs or results of any products or services using the Llama 2 Materials 21 | 7. Create, generate, or facilitate the creation of malicious code, malware, computer viruses or do anything else that could disable, overburden, interfere with or impair the proper working, integrity, operation or appearance of a website or computer system 22 | 23 | 24 | 25 | 2. Engage in, promote, incite, facilitate, or assist in the planning or development of activities that present a risk of death or bodily harm to individuals, including use of Llama 2 related to the following: 26 | 1. Military, warfare, nuclear industries or applications, espionage, use for materials or activities that are subject to the International Traffic Arms Regulations (ITAR) maintained by the United States Department of State 27 | 2. Guns and illegal weapons (including weapon development) 28 | 3. Illegal drugs and regulated/controlled substances 29 | 4. Operation of critical infrastructure, transportation technologies, or heavy machinery 30 | 5. 
Self-harm or harm to others, including suicide, cutting, and eating disorders 31 | 6. Any content intended to incite or promote violence, abuse, or any infliction of bodily harm to an individual 32 | 33 | 34 | 35 | 3. Intentionally deceive or mislead others, including use of Llama 2 related to the following: 36 | 1. Generating, promoting, or furthering fraud or the creation or promotion of disinformation 37 | 2. Generating, promoting, or furthering defamatory content, including the creation of defamatory statements, images, or other content 38 | 3. Generating, promoting, or further distributing spam 39 | 4. Impersonating another individual without consent, authorization, or legal right 40 | 5. Representing that the use of Llama 2 or outputs are human-generated 41 | 6. Generating or facilitating false online engagement, including fake reviews and other means of fake online engagement 42 | 4. Fail to appropriately disclose to end users any known dangers of your AI system 43 | 44 | Please report any violation of this Policy, software “bug,” or other problems that could lead to a violation of this Policy through one of the following means: 45 | 46 | * Reporting issues with the model: [github.com/facebookresearch/llama](http://github.com/facebookresearch/llama) 47 | * Reporting risky content generated by the model: [developers.facebook.com/llama_output_feedback](http://developers.facebook.com/llama_output_feedback) 48 | * Reporting bugs and security concerns: [facebook.com/whitehat/info](http://facebook.com/whitehat/info) 49 | * Reporting violations of the Acceptable Use Policy or unlicensed uses of Llama: [LlamaUseReport@meta.com](mailto:LlamaUseReport@meta.com) 50 | 51 | -------------------------------------------------------------------------------- /download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
4 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 5 | 6 | read -p "Enter the URL from email: " PRESIGNED_URL 7 | echo "" 8 | read -p "Enter the list of models to download without spaces (7B,13B,70B,7B-chat,13B-chat,70B-chat), or press Enter for all: " MODEL_SIZE 9 | TARGET_FOLDER="." # where all files should end up 10 | mkdir -p ${TARGET_FOLDER} 11 | 12 | if [[ $MODEL_SIZE == "" ]]; then 13 | MODEL_SIZE="7B,13B,70B,7B-chat,13B-chat,70B-chat" 14 | fi 15 | 16 | echo "Downloading LICENSE and Acceptable Usage Policy" 17 | wget ${PRESIGNED_URL/'*'/"LICENSE"} -O ${TARGET_FOLDER}"/LICENSE" 18 | wget ${PRESIGNED_URL/'*'/"USE_POLICY.md"} -O ${TARGET_FOLDER}"/USE_POLICY.md" 19 | 20 | echo "Downloading tokenizer" 21 | wget ${PRESIGNED_URL/'*'/"tokenizer.model"} -O ${TARGET_FOLDER}"/tokenizer.model" 22 | wget ${PRESIGNED_URL/'*'/"tokenizer_checklist.chk"} -O ${TARGET_FOLDER}"/tokenizer_checklist.chk" 23 | (cd ${TARGET_FOLDER} && md5sum -c tokenizer_checklist.chk) 24 | 25 | for m in ${MODEL_SIZE//,/ } 26 | do 27 | if [[ $m == "7B" ]]; then 28 | SHARD=0 29 | MODEL_PATH="llama-2-7b" 30 | elif [[ $m == "7B-chat" ]]; then 31 | SHARD=0 32 | MODEL_PATH="llama-2-7b-chat" 33 | elif [[ $m == "13B" ]]; then 34 | SHARD=1 35 | MODEL_PATH="llama-2-13b" 36 | elif [[ $m == "13B-chat" ]]; then 37 | SHARD=1 38 | MODEL_PATH="llama-2-13b-chat" 39 | elif [[ $m == "70B" ]]; then 40 | SHARD=7 41 | MODEL_PATH="llama-2-70b" 42 | elif [[ $m == "70B-chat" ]]; then 43 | SHARD=7 44 | MODEL_PATH="llama-2-70b-chat" 45 | fi 46 | 47 | echo "Downloading ${MODEL_PATH}" 48 | mkdir -p ${TARGET_FOLDER}"/${MODEL_PATH}" 49 | 50 | for s in $(seq -f "0%g" 0 ${SHARD}) 51 | do 52 | wget ${PRESIGNED_URL/'*'/"${MODEL_PATH}/consolidated.${s}.pth"} -O ${TARGET_FOLDER}"/${MODEL_PATH}/consolidated.${s}.pth" 53 | done 54 | 55 | wget ${PRESIGNED_URL/'*'/"${MODEL_PATH}/params.json"} -O ${TARGET_FOLDER}"/${MODEL_PATH}/params.json" 56 | wget 
${PRESIGNED_URL/'*'/"${MODEL_PATH}/checklist.chk"} -O ${TARGET_FOLDER}"/${MODEL_PATH}/checklist.chk" 57 | echo "Checking checksums" 58 | (cd ${TARGET_FOLDER}"/${MODEL_PATH}" && md5sum -c checklist.chk) 59 | done 60 | 61 | -------------------------------------------------------------------------------- /example_chat_completion.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | 4 | from typing import Optional 5 | 6 | import fire 7 | 8 | from llama import Llama 9 | 10 | 11 | def main( 12 | ckpt_dir: str, 13 | tokenizer_path: str, 14 | temperature: float = 0.6, 15 | top_p: float = 0.9, 16 | max_seq_len: int = 512, 17 | max_batch_size: int = 4, 18 | max_gen_len: Optional[int] = None, 19 | ): 20 | generator = Llama.build( 21 | ckpt_dir=ckpt_dir, 22 | tokenizer_path=tokenizer_path, 23 | max_seq_len=max_seq_len, 24 | max_batch_size=max_batch_size, 25 | ) 26 | 27 | dialogs = [ 28 | [{"role": "user", "content": "what is the recipe of mayonnaise?"}], 29 | [ 30 | {"role": "user", "content": "I am going to Paris, what should I see?"}, 31 | { 32 | "role": "assistant", 33 | "content": """\ 34 | Paris, the capital of France, is known for its stunning architecture, art museums, historical landmarks, and romantic atmosphere. Here are some of the top attractions to see in Paris: 35 | 36 | 1. The Eiffel Tower: The iconic Eiffel Tower is one of the most recognizable landmarks in the world and offers breathtaking views of the city. 37 | 2. The Louvre Museum: The Louvre is one of the world's largest and most famous museums, housing an impressive collection of art and artifacts, including the Mona Lisa. 38 | 3. Notre-Dame Cathedral: This beautiful cathedral is one of the most famous landmarks in Paris and is known for its Gothic architecture and stunning stained glass windows. 
39 | 40 | These are just a few of the many attractions that Paris has to offer. With so much to see and do, it's no wonder that Paris is one of the most popular tourist destinations in the world.""", 41 | }, 42 | {"role": "user", "content": "What is so great about #1?"}, 43 | ], 44 | [ 45 | {"role": "system", "content": "Always answer with Haiku"}, 46 | {"role": "user", "content": "I am going to Paris, what should I see?"}, 47 | ], 48 | [ 49 | { 50 | "role": "system", 51 | "content": "Always answer with emojis", 52 | }, 53 | {"role": "user", "content": "How to go from Beijing to NY?"}, 54 | ], 55 | ] 56 | results = generator.chat_completion( 57 | dialogs, # type: ignore 58 | max_gen_len=max_gen_len, 59 | temperature=temperature, 60 | top_p=top_p, 61 | ) 62 | 63 | for dialog, result in zip(dialogs, results): 64 | for msg in dialog: 65 | print(f"{msg['role'].capitalize()}: {msg['content']}\n") 66 | print( 67 | f"> {result['generation']['role'].capitalize()}: {result['generation']['content']}" 68 | ) 69 | print("\n==================================\n") 70 | 71 | 72 | if __name__ == "__main__": 73 | fire.Fire(main) 74 | -------------------------------------------------------------------------------- /example_text_completion.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
3 | 4 | import fire 5 | 6 | from llama import Llama 7 | 8 | 9 | def main( 10 | ckpt_dir: str, 11 | tokenizer_path: str, 12 | temperature: float = 0.6, 13 | top_p: float = 0.9, 14 | max_seq_len: int = 128, 15 | max_gen_len: int = 64, 16 | max_batch_size: int = 4, 17 | ): 18 | generator = Llama.build( 19 | ckpt_dir=ckpt_dir, 20 | tokenizer_path=tokenizer_path, 21 | max_seq_len=max_seq_len, 22 | max_batch_size=max_batch_size, 23 | ) 24 | 25 | prompts = [ 26 | # For these prompts, the expected answer is the natural continuation of the prompt 27 | "I believe the meaning of life is", 28 | "Simply put, the theory of relativity states that ", 29 | """A brief message congratulating the team on the launch: 30 | 31 | Hi everyone, 32 | 33 | I just """, 34 | # Few shot prompt (providing a few examples before asking model to complete more); 35 | """Translate English to French: 36 | 37 | sea otter => loutre de mer 38 | peppermint => menthe poivrée 39 | plush girafe => girafe peluche 40 | cheese =>""", 41 | ] 42 | results = generator.text_completion( 43 | prompts, 44 | max_gen_len=max_gen_len, 45 | temperature=temperature, 46 | top_p=top_p, 47 | ) 48 | for prompt, result in zip(prompts, results): 49 | print(prompt) 50 | print(f"> {result['generation']}") 51 | print("\n==================================\n") 52 | 53 | 54 | if __name__ == "__main__": 55 | fire.Fire(main) 56 | -------------------------------------------------------------------------------- /llama/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
3 | 4 | from .generation import Llama 5 | from .model import ModelArgs, Transformer 6 | from .tokenizer import Tokenizer 7 | -------------------------------------------------------------------------------- /llama/generation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | 4 | import json 5 | import os 6 | import sys 7 | import time 8 | from pathlib import Path 9 | from typing import List, Literal, Optional, Tuple, TypedDict 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | from fairscale.nn.model_parallel.initialize import ( 14 | get_model_parallel_rank, 15 | initialize_model_parallel, 16 | model_parallel_is_initialized, 17 | ) 18 | 19 | from llama.model import ModelArgs, Transformer 20 | from llama.tokenizer import Tokenizer 21 | 22 | Role = Literal["system", "user", "assistant"] 23 | 24 | 25 | class Message(TypedDict): 26 | role: Role 27 | content: str 28 | 29 | 30 | class CompletionPrediction(TypedDict, total=False): 31 | generation: str 32 | tokens: List[str] # not required 33 | logprobs: List[float] # not required 34 | 35 | 36 | class ChatPrediction(TypedDict, total=False): 37 | generation: Message 38 | tokens: List[str] # not required 39 | logprobs: List[float] # not required 40 | 41 | 42 | Dialog = List[Message] 43 | 44 | B_INST, E_INST = "[INST]", "[/INST]" 45 | B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n" 46 | DEFAULT_SYSTEM_PROMPT = """\ 47 | You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
48 | 49 | If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""" 50 | 51 | 52 | class Llama: 53 | @staticmethod 54 | def build( 55 | ckpt_dir: str, 56 | tokenizer_path: str, 57 | max_seq_len: int, 58 | max_batch_size: int, 59 | model_parallel_size: Optional[int] = None, 60 | ) -> "Llama": 61 | if not torch.distributed.is_initialized(): 62 | torch.distributed.init_process_group("nccl") 63 | if not model_parallel_is_initialized(): 64 | if model_parallel_size is None: 65 | model_parallel_size = int(os.environ.get("WORLD_SIZE", 1)) 66 | initialize_model_parallel(model_parallel_size) 67 | 68 | local_rank = int(os.environ.get("LOCAL_RANK", 0)) 69 | torch.cuda.set_device(local_rank) 70 | 71 | # seed must be the same in all processes 72 | torch.manual_seed(1) 73 | 74 | if local_rank > 0: 75 | sys.stdout = open(os.devnull, "w") 76 | 77 | start_time = time.time() 78 | checkpoints = sorted(Path(ckpt_dir).glob("*.pth")) 79 | assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}" 80 | assert model_parallel_size == len( 81 | checkpoints 82 | ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}" 83 | ckpt_path = checkpoints[get_model_parallel_rank()] 84 | checkpoint = torch.load(ckpt_path, map_location="cpu") 85 | with open(Path(ckpt_dir) / "params.json", "r") as f: 86 | params = json.loads(f.read()) 87 | 88 | model_args: ModelArgs = ModelArgs( 89 | max_seq_len=max_seq_len, 90 | max_batch_size=max_batch_size, 91 | **params, 92 | ) 93 | tokenizer = Tokenizer(model_path=tokenizer_path) 94 | model_args.vocab_size = tokenizer.n_words 95 | torch.set_default_tensor_type(torch.cuda.HalfTensor) 96 | model = Transformer(model_args) 97 | model.load_state_dict(checkpoint, strict=False) 98 | print(f"Loaded in {time.time() - start_time:.2f} seconds") 99 | 100 | return Llama(model, 
tokenizer) 101 | 102 | def __init__(self, model: Transformer, tokenizer: Tokenizer): 103 | self.model = model 104 | self.tokenizer = tokenizer 105 | 106 | @torch.inference_mode() 107 | def generate( 108 | self, 109 | prompt_tokens: List[List[int]], 110 | max_gen_len: int, 111 | temperature: float = 0.6, 112 | top_p: float = 0.9, 113 | logprobs: bool = False, 114 | echo: bool = False, 115 | ) -> Tuple[List[List[int]], Optional[List[List[float]]]]: 116 | params = self.model.params 117 | bsz = len(prompt_tokens) 118 | assert bsz <= params.max_batch_size, (bsz, params.max_batch_size) 119 | 120 | min_prompt_len = min(len(t) for t in prompt_tokens) 121 | max_prompt_len = max(len(t) for t in prompt_tokens) 122 | assert max_prompt_len <= params.max_seq_len 123 | total_len = min(params.max_seq_len, max_gen_len + max_prompt_len) 124 | 125 | pad_id = self.tokenizer.pad_id 126 | tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda") 127 | for k, t in enumerate(prompt_tokens): 128 | tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda") 129 | if logprobs: 130 | token_logprobs = torch.zeros_like(tokens, dtype=torch.float) 131 | 132 | prev_pos = 0 133 | eos_reached = torch.tensor([False] * bsz, device="cuda") 134 | input_text_mask = tokens != pad_id 135 | for cur_pos in range(min_prompt_len, total_len): 136 | logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos) 137 | if logprobs: 138 | token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy( 139 | input=logits.transpose(1, 2), 140 | target=tokens[:, prev_pos + 1 : cur_pos + 1], 141 | reduction="none", 142 | ignore_index=pad_id, 143 | ) 144 | if temperature > 0: 145 | probs = torch.softmax(logits[:, -1] / temperature, dim=-1) 146 | next_token = sample_top_p(probs, top_p) 147 | else: 148 | next_token = torch.argmax(logits[:, -1], dim=-1) 149 | 150 | next_token = next_token.reshape(-1) 151 | # only replace token if prompt has already been generated 152 | next_token = 
torch.where( 153 | input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token 154 | ) 155 | tokens[:, cur_pos] = next_token 156 | eos_reached |= (~input_text_mask[:, cur_pos]) & ( 157 | next_token == self.tokenizer.eos_id 158 | ) 159 | prev_pos = cur_pos 160 | if all(eos_reached): 161 | break 162 | 163 | if logprobs: 164 | token_logprobs = token_logprobs.tolist() 165 | out_tokens, out_logprobs = [], [] 166 | for i, toks in enumerate(tokens.tolist()): 167 | # cut to max gen len 168 | start = 0 if echo else len(prompt_tokens[i]) 169 | toks = toks[start : len(prompt_tokens[i]) + max_gen_len] 170 | probs = None 171 | if logprobs: 172 | probs = token_logprobs[i][start : len(prompt_tokens[i]) + max_gen_len] 173 | # cut to eos tok if any 174 | if self.tokenizer.eos_id in toks: 175 | eos_idx = toks.index(self.tokenizer.eos_id) 176 | toks = toks[:eos_idx] 177 | probs = probs[:eos_idx] if logprobs else None 178 | out_tokens.append(toks) 179 | out_logprobs.append(probs) 180 | return (out_tokens, out_logprobs if logprobs else None) 181 | 182 | def text_completion( 183 | self, 184 | prompts: List[str], 185 | temperature: float = 0.6, 186 | top_p: float = 0.9, 187 | max_gen_len: Optional[int] = None, 188 | logprobs: bool = False, 189 | echo: bool = False, 190 | ) -> List[CompletionPrediction]: 191 | if max_gen_len is None: 192 | max_gen_len = self.model.params.max_seq_len - 1 193 | prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts] 194 | generation_tokens, generation_logprobs = self.generate( 195 | prompt_tokens=prompt_tokens, 196 | max_gen_len=max_gen_len, 197 | temperature=temperature, 198 | top_p=top_p, 199 | logprobs=logprobs, 200 | echo=echo, 201 | ) 202 | if logprobs: 203 | return [ 204 | { 205 | "generation": self.tokenizer.decode(t), 206 | "tokens": [self.tokenizer.decode(x) for x in t], 207 | "logprobs": logprobs_i, 208 | } 209 | for t, logprobs_i in zip(generation_tokens, generation_logprobs) 210 | ] 211 | return [{"generation": 
self.tokenizer.decode(t)} for t in generation_tokens] 212 | 213 | def chat_completion( 214 | self, 215 | dialogs: List[Dialog], 216 | temperature: float = 0.6, 217 | top_p: float = 0.9, 218 | max_gen_len: Optional[int] = None, 219 | logprobs: bool = False, 220 | ) -> List[ChatPrediction]: 221 | if max_gen_len is None: 222 | max_gen_len = self.model.params.max_seq_len - 1 223 | prompt_tokens = [] 224 | for dialog in dialogs: 225 | if dialog[0]["role"] != "system": 226 | dialog = [ 227 | { 228 | "role": "system", 229 | "content": DEFAULT_SYSTEM_PROMPT, 230 | } 231 | ] + dialog 232 | dialog = [ 233 | { 234 | "role": dialog[1]["role"], 235 | "content": B_SYS 236 | + dialog[0]["content"] 237 | + E_SYS 238 | + dialog[1]["content"], 239 | } 240 | ] + dialog[2:] 241 | assert all([msg["role"] == "user" for msg in dialog[::2]]) and all( 242 | [msg["role"] == "assistant" for msg in dialog[1::2]] 243 | ), ( 244 | "model only supports 'system', 'user' and 'assistant' roles, " 245 | "starting with 'system', then 'user' and alternating (u/a/u/a/u...)" 246 | ) 247 | dialog_tokens: List[int] = sum( 248 | [ 249 | self.tokenizer.encode( 250 | f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ", 251 | bos=True, 252 | eos=True, 253 | ) 254 | for prompt, answer in zip( 255 | dialog[::2], 256 | dialog[1::2], 257 | ) 258 | ], 259 | [], 260 | ) 261 | assert ( 262 | dialog[-1]["role"] == "user" 263 | ), f"Last message must be from user, got {dialog[-1]['role']}" 264 | dialog_tokens += self.tokenizer.encode( 265 | f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}", 266 | bos=True, 267 | eos=False, 268 | ) 269 | prompt_tokens.append(dialog_tokens) 270 | 271 | generation_tokens, generation_logprobs = self.generate( 272 | prompt_tokens=prompt_tokens, 273 | max_gen_len=max_gen_len, 274 | temperature=temperature, 275 | top_p=top_p, 276 | logprobs=logprobs, 277 | ) 278 | if logprobs: 279 | return [ 280 | { 281 | "generation": { 282 | "role": "assistant", 283 | 
"content": self.tokenizer.decode(t), 284 | }, 285 | "tokens": [self.tokenizer.decode(x) for x in t], 286 | "logprobs": logprobs_i, 287 | } 288 | for t, logprobs_i in zip(generation_tokens, generation_logprobs) 289 | ] 290 | return [ 291 | {"generation": {"role": "assistant", "content": self.tokenizer.decode(t)}} 292 | for t in generation_tokens 293 | ] 294 | 295 | 296 | def sample_top_p(probs, p): 297 | probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) 298 | probs_sum = torch.cumsum(probs_sort, dim=-1) 299 | mask = probs_sum - probs_sort > p 300 | probs_sort[mask] = 0.0 301 | probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) 302 | next_token = torch.multinomial(probs_sort, num_samples=1) 303 | next_token = torch.gather(probs_idx, -1, next_token) 304 | return next_token 305 | -------------------------------------------------------------------------------- /llama/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
3 | 4 | import math 5 | from dataclasses import dataclass 6 | from typing import Any, Optional, Tuple 7 | 8 | import fairscale.nn.model_parallel.initialize as fs_init 9 | import torch 10 | import torch.nn.functional as F 11 | from fairscale.nn.model_parallel.layers import ( 12 | ColumnParallelLinear, 13 | ParallelEmbedding, 14 | RowParallelLinear, 15 | ) 16 | from torch import nn 17 | 18 | 19 | @dataclass 20 | class ModelArgs: 21 | dim: int = 4096 22 | n_layers: int = 32 23 | n_heads: int = 32 24 | n_kv_heads: Optional[int] = None 25 | vocab_size: int = -1 # defined later by tokenizer 26 | multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 27 | ffn_dim_multiplier: Optional[float] = None 28 | norm_eps: float = 1e-5 29 | 30 | max_batch_size: int = 32 31 | max_seq_len: int = 2048 32 | 33 | 34 | class RMSNorm(torch.nn.Module): 35 | def __init__(self, dim: int, eps: float = 1e-6): 36 | super().__init__() 37 | self.eps = eps 38 | self.weight = nn.Parameter(torch.ones(dim)) 39 | 40 | def _norm(self, x): 41 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 42 | 43 | def forward(self, x): 44 | output = self._norm(x.float()).type_as(x) 45 | return output * self.weight 46 | 47 | 48 | def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): 49 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) 50 | t = torch.arange(end, device=freqs.device) # type: ignore 51 | freqs = torch.outer(t, freqs).float() # type: ignore 52 | freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 53 | return freqs_cis 54 | 55 | 56 | def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): 57 | ndim = x.ndim 58 | assert 0 <= 1 < ndim 59 | assert freqs_cis.shape == (x.shape[1], x.shape[-1]) 60 | shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] 61 | return freqs_cis.view(*shape) 62 | 63 | 64 | def apply_rotary_emb( 65 | xq: torch.Tensor, 66 | xk: torch.Tensor, 
67 | freqs_cis: torch.Tensor, 68 | ) -> Tuple[torch.Tensor, torch.Tensor]: 69 | xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) 70 | xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) 71 | freqs_cis = reshape_for_broadcast(freqs_cis, xq_) 72 | xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) 73 | xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) 74 | return xq_out.type_as(xq), xk_out.type_as(xk) 75 | 76 | 77 | def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: 78 | """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" 79 | bs, slen, n_kv_heads, head_dim = x.shape 80 | if n_rep == 1: 81 | return x 82 | return ( 83 | x[:, :, :, None, :] 84 | .expand(bs, slen, n_kv_heads, n_rep, head_dim) 85 | .reshape(bs, slen, n_kv_heads * n_rep, head_dim) 86 | ) 87 | 88 | 89 | class Attention(nn.Module): 90 | def __init__(self, args: ModelArgs): 91 | super().__init__() 92 | self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads 93 | model_parallel_size = fs_init.get_model_parallel_world_size() 94 | self.n_local_heads = args.n_heads // model_parallel_size 95 | self.n_local_kv_heads = self.n_kv_heads // model_parallel_size 96 | self.n_rep = self.n_local_heads // self.n_local_kv_heads 97 | self.head_dim = args.dim // args.n_heads 98 | 99 | self.wq = ColumnParallelLinear( 100 | args.dim, 101 | args.n_heads * self.head_dim, 102 | bias=False, 103 | gather_output=False, 104 | init_method=lambda x: x, 105 | ) 106 | self.wk = ColumnParallelLinear( 107 | args.dim, 108 | self.n_kv_heads * self.head_dim, 109 | bias=False, 110 | gather_output=False, 111 | init_method=lambda x: x, 112 | ) 113 | self.wv = ColumnParallelLinear( 114 | args.dim, 115 | self.n_kv_heads * self.head_dim, 116 | bias=False, 117 | gather_output=False, 118 | init_method=lambda x: x, 119 | ) 120 | self.wo = RowParallelLinear( 121 | args.n_heads * self.head_dim, 122 | args.dim, 123 | bias=False, 124 | input_is_parallel=True, 125 | 
init_method=lambda x: x, 126 | ) 127 | 128 | self.cache_k = torch.zeros( 129 | ( 130 | args.max_batch_size, 131 | args.max_seq_len, 132 | self.n_local_kv_heads, 133 | self.head_dim, 134 | ) 135 | ).cuda() 136 | self.cache_v = torch.zeros( 137 | ( 138 | args.max_batch_size, 139 | args.max_seq_len, 140 | self.n_local_kv_heads, 141 | self.head_dim, 142 | ) 143 | ).cuda() 144 | 145 | def forward( 146 | self, 147 | x: torch.Tensor, 148 | start_pos: int, 149 | freqs_cis: torch.Tensor, 150 | mask: Optional[torch.Tensor], 151 | ): 152 | bsz, seqlen, _ = x.shape 153 | xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) 154 | 155 | xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) 156 | xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) 157 | xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) 158 | 159 | xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) 160 | 161 | self.cache_k = self.cache_k.to(xq) 162 | self.cache_v = self.cache_v.to(xq) 163 | 164 | self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk 165 | self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv 166 | 167 | keys = self.cache_k[:bsz, : start_pos + seqlen] 168 | values = self.cache_v[:bsz, : start_pos + seqlen] 169 | 170 | # repeat k/v heads if n_kv_heads < n_heads 171 | keys = repeat_kv(keys, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) 172 | values = repeat_kv(values, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) 173 | 174 | xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) 175 | keys = keys.transpose(1, 2) 176 | values = values.transpose(1, 2) 177 | scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) 178 | if mask is not None: 179 | scores = scores + mask # (bs, n_local_heads, seqlen, cache_len + seqlen) 180 | scores = F.softmax(scores.float(), dim=-1).type_as(xq) 181 | output = torch.matmul(scores, values) # (bs, n_local_heads, seqlen, head_dim) 182 | output = output.transpose(1, 2).contiguous().view(bsz, seqlen, 
-1) 183 | return self.wo(output) 184 | 185 | 186 | class FeedForward(nn.Module): 187 | def __init__( 188 | self, 189 | dim: int, 190 | hidden_dim: int, 191 | multiple_of: int, 192 | ffn_dim_multiplier: Optional[float], 193 | ): 194 | super().__init__() 195 | hidden_dim = int(2 * hidden_dim / 3) 196 | # custom dim factor multiplier 197 | if ffn_dim_multiplier is not None: 198 | hidden_dim = int(ffn_dim_multiplier * hidden_dim) 199 | hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) 200 | 201 | self.w1 = ColumnParallelLinear( 202 | dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x 203 | ) 204 | self.w2 = RowParallelLinear( 205 | hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x 206 | ) 207 | self.w3 = ColumnParallelLinear( 208 | dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x 209 | ) 210 | 211 | def forward(self, x): 212 | return self.w2(F.silu(self.w1(x)) * self.w3(x)) 213 | 214 | 215 | class TransformerBlock(nn.Module): 216 | def __init__(self, layer_id: int, args: ModelArgs): 217 | super().__init__() 218 | self.n_heads = args.n_heads 219 | self.dim = args.dim 220 | self.head_dim = args.dim // args.n_heads 221 | self.attention = Attention(args) 222 | self.feed_forward = FeedForward( 223 | dim=args.dim, 224 | hidden_dim=4 * args.dim, 225 | multiple_of=args.multiple_of, 226 | ffn_dim_multiplier=args.ffn_dim_multiplier, 227 | ) 228 | self.layer_id = layer_id 229 | self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) 230 | self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) 231 | 232 | def forward( 233 | self, 234 | x: torch.Tensor, 235 | start_pos: int, 236 | freqs_cis: torch.Tensor, 237 | mask: Optional[torch.Tensor], 238 | ): 239 | h = x + self.attention.forward( 240 | self.attention_norm(x), start_pos, freqs_cis, mask 241 | ) 242 | out = h + self.feed_forward.forward(self.ffn_norm(h)) 243 | return out 244 | 245 | 246 | class Transformer(nn.Module): 247 | def 
__init__(self, params: ModelArgs): 248 | super().__init__() 249 | self.params = params 250 | self.vocab_size = params.vocab_size 251 | self.n_layers = params.n_layers 252 | 253 | self.tok_embeddings = ParallelEmbedding( 254 | params.vocab_size, params.dim, init_method=lambda x: x 255 | ) 256 | 257 | self.layers = torch.nn.ModuleList() 258 | for layer_id in range(params.n_layers): 259 | self.layers.append(TransformerBlock(layer_id, params)) 260 | 261 | self.norm = RMSNorm(params.dim, eps=params.norm_eps) 262 | self.output = ColumnParallelLinear( 263 | params.dim, params.vocab_size, bias=False, init_method=lambda x: x 264 | ) 265 | 266 | self.freqs_cis = precompute_freqs_cis( 267 | self.params.dim // self.params.n_heads, self.params.max_seq_len * 2 268 | ) 269 | 270 | @torch.inference_mode() 271 | def forward(self, tokens: torch.Tensor, start_pos: int): 272 | _bsz, seqlen = tokens.shape 273 | h = self.tok_embeddings(tokens) 274 | self.freqs_cis = self.freqs_cis.to(h.device) 275 | freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen] 276 | 277 | mask = None 278 | if seqlen > 1: 279 | mask = torch.full( 280 | (1, 1, seqlen, seqlen), float("-inf"), device=tokens.device 281 | ) 282 | mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h) 283 | 284 | for layer in self.layers: 285 | h = layer(h, start_pos, freqs_cis, mask) 286 | h = self.norm(h) 287 | output = self.output(h).float() 288 | return output 289 | -------------------------------------------------------------------------------- /llama/tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
3 | 4 | import os 5 | from logging import getLogger 6 | from typing import List 7 | 8 | from sentencepiece import SentencePieceProcessor 9 | 10 | 11 | logger = getLogger() 12 | 13 | 14 | class Tokenizer: 15 | def __init__(self, model_path: str): 16 | # reload tokenizer 17 | assert os.path.isfile(model_path), model_path 18 | self.sp_model = SentencePieceProcessor(model_file=model_path) 19 | logger.info(f"Reloaded SentencePiece model from {model_path}") 20 | 21 | # BOS / EOS token IDs 22 | self.n_words: int = self.sp_model.vocab_size() 23 | self.bos_id: int = self.sp_model.bos_id() 24 | self.eos_id: int = self.sp_model.eos_id() 25 | self.pad_id: int = self.sp_model.pad_id() 26 | logger.info( 27 | f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}" 28 | ) 29 | assert self.sp_model.vocab_size() == self.sp_model.get_piece_size() 30 | 31 | def encode(self, s: str, bos: bool, eos: bool) -> List[int]: 32 | assert type(s) is str 33 | t = self.sp_model.encode(s) 34 | if bos: 35 | t = [self.bos_id] + t 36 | if eos: 37 | t = t + [self.eos_id] 38 | return t 39 | 40 | def decode(self, t: List[int]) -> str: 41 | return self.sp_model.decode(t) 42 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | fairscale 3 | fire 4 | sentencepiece 5 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
3 | 4 | from setuptools import find_packages, setup 5 | 6 | 7 | def get_requirements(path: str): 8 | return [l.strip() for l in open(path)] 9 | 10 | 11 | setup( 12 | name="llama", 13 | version="0.0.1", 14 | packages=find_packages(), 15 | install_requires=get_requirements("requirements.txt"), 16 | ) 17 | --------------------------------------------------------------------------------