├── .gitignore
├── LICENSE
├── README.md
├── assets
│   └── StreamingLLM.pdf
├── data
│   └── mt_bench.jsonl
├── examples
│   ├── eval_long_ppl.py
│   └── run_streaming_llama.py
├── figures
│   └── schemes.png
├── setup.py
└── streaming_llm
    ├── __init__.py
    ├── enable_streaming_llm.py
    ├── kv_cache.py
    ├── pos_shift
    │   ├── __init__.py
    │   ├── modify_falcon.py
    │   ├── modify_gpt_neox.py
    │   └── modify_llama.py
    └── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 MIT HAN Lab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Efficient Streaming Language Models with Attention Sinks
2 | [[paper](http://arxiv.org/abs/2309.17453)] [[slides](assets/StreamingLLM.pdf)] [[video](https://youtu.be/hvJsEzP34o8)]
3 |
4 | ![schemes](figures/schemes.png)
5 |
6 | https://github.com/mit-han-lab/streaming-llm/assets/40906949/2bd1cda4-a0bd-47d1-a023-fbf7779b8358
7 |
8 | ## TL;DR
9 | We deploy LLMs for infinite-length inputs without sacrificing efficiency and performance.
10 |
11 | ## News
12 |
13 | - [2024/02] StreamingLLM is covered by [MIT News as a spotlight](https://news.mit.edu/2024/new-way-let-ai-chatbots-converse-all-day-without-crashing-0213)!
14 | - [2024/01] StreamingLLM is integrated into HPC-AI Tech's [SwiftInfer](https://github.com/hpcaitech/SwiftInfer) to support infinite input length for LLM inference.
15 | - [2024/01] StreamingLLM is integrated into NVIDIA [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/llama#run-llama-with-streamingllm)!
16 | - [2023/12] StreamingLLM is integrated by CMU, UW, and OctoAI, enabling endless and efficient LLM generation on [iPhone](https://x.com/davidpissarra/status/1735761373261427189?s=20)!
17 | - [2023/12] StreamingLLM is integrated into HuggingFace Transformers via this [PR](https://github.com/huggingface/transformers/pull/26681).
18 | - [2023/10] StreamingLLM is integrated into [Intel Extension for Transformers](https://github.com/intel/intel-extension-for-transformers).
19 | - [2023/10] [Attention Sinks](https://github.com/tomaarsen/attention_sinks), a third-party implementation, enables StreamingLLM on more Huggingface LLMs.
20 |
21 | ## Abstract
22 | Deploying Large Language Models (LLMs) in streaming applications such as multi-round dialogue, where long interactions are expected, is urgently needed but poses two major challenges. Firstly, during the decoding stage, caching previous tokens' Key and Value states (KV) consumes extensive memory. Secondly, popular LLMs cannot generalize to longer texts than the training sequence length. Window attention, where only the most recent KVs are cached, is a natural approach --- but we show that it fails when the text length surpasses the cache size. We observe an interesting phenomenon, namely attention sink, that keeping the KV of initial tokens will largely recover the performance of window attention. In this paper, we first demonstrate that the emergence of attention sink is due to the strong attention scores towards initial tokens as a ``sink'' even if they are not semantically important. Based on the above analysis, we introduce StreamingLLM, an efficient framework that enables LLMs trained with a finite length attention window to generalize to infinite sequence length without any fine-tuning. We show that StreamingLLM can enable Llama-2, MPT, Falcon, and Pythia to perform stable and efficient language modeling with up to 4 million tokens and more. In addition, we discover that adding a placeholder token as a dedicated attention sink during pre-training can further improve streaming deployment. In streaming settings, StreamingLLM outperforms the sliding window recomputation baseline by up to 22.2x speedup.
23 |
24 | ## Usage
25 |
26 | ### Environment Setup
27 |
28 | ```bash
29 | conda create -yn streaming python=3.8
30 | conda activate streaming
31 |
32 | pip install torch torchvision torchaudio
33 | pip install transformers==4.33.0 accelerate datasets evaluate wandb scikit-learn scipy sentencepiece
34 |
35 | python setup.py develop
36 | ```
37 |
38 | ### Run Streaming Llama Chatbot
39 |
40 | ```bash
41 | CUDA_VISIBLE_DEVICES=0 python examples/run_streaming_llama.py --enable_streaming
42 | ```
43 |
44 | ## FAQ
45 |
46 | 1. **What does "working on infinite-length inputs" imply for LLMs?**
47 |
48 |    Handling infinite-length text with LLMs presents challenges. Notably, storing all previous Key and Value (KV) states demands significant memory, and models might struggle to generate text beyond their training sequence length. StreamingLLM addresses this by retaining only the most recent tokens and attention sinks, discarding intermediate tokens. This enables the model to generate coherent text from recent tokens without a cache reset — a capability not seen in earlier methods.
49 |
50 | 2. **Is the context window of LLMs expanded?**
51 |
52 |    No. The context window remains unchanged. Only the most recent tokens and attention sinks are retained, discarding middle tokens. This means the model can only process the latest tokens. The context window remains constrained by its initial pre-training. For instance, if Llama-2 is pre-trained with a context window of 4096 tokens, then the maximum cache size for StreamingLLM on Llama-2 remains 4096.
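
   To make the cache policy concrete, here is a minimal sketch of the eviction rule described above (the real implementation is `StartRecentKVCache` in `streaming_llm/kv_cache.py`; the function name, the default sizes, and the `[batch, heads, seq_len, head_dim]` cache layout used below are illustrative assumptions, not the repo's exact API):

   ```python
   import torch

   def evict_kv_cache(past_key_values, start_size=4, recent_size=2000, seq_dim=2):
       """Keep the first `start_size` attention-sink tokens and the last
       `recent_size` tokens; drop everything in between (illustrative sketch)."""
       if past_key_values is None:
           return None
       seq_len = past_key_values[0][0].size(seq_dim)
       if seq_len <= start_size + recent_size:
           return past_key_values  # cache still fits, nothing to evict

       def keep(x):
           # Concatenate the attention-sink prefix with the recent window.
           return torch.cat(
               [
                   x.narrow(seq_dim, 0, start_size),
                   x.narrow(seq_dim, seq_len - recent_size, recent_size),
               ],
               dim=seq_dim,
           )

       return [(keep(k), keep(v)) for k, v in past_key_values]

   # Toy usage: 2 layers, batch=1, 4 heads, 5000 cached tokens, head_dim=8.
   kv = [(torch.randn(1, 4, 5000, 8), torch.randn(1, 4, 5000, 8)) for _ in range(2)]
   kv = evict_kv_cache(kv, start_size=4, recent_size=2000)
   print(kv[0][0].shape)  # torch.Size([1, 4, 2004, 8])
   ```

   Because the kept prefix plus the recent window never exceeds a fixed budget, the cache stays within the model's pre-training context length no matter how long the stream runs.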
53 |
54 | 3. **Can I input an extensive text, like a book, into StreamingLLM for summarization?**
55 |
56 |    While you can input a lengthy text, the model will only recognize the latest tokens. Thus, if a book is an input, StreamingLLM might only summarize the concluding paragraphs, which might not be very insightful. As emphasized earlier, we neither expand the LLMs' context window nor enhance their long-term memory. StreamingLLM's strength lies in generating fluent text from recent tokens without needing a cache refresh.
57 |
58 | 4. **What is the ideal use case for StreamingLLM?**
59 |
60 |    StreamingLLM is optimized for streaming applications, such as multi-round dialogues. It's ideal for scenarios where a model needs to operate continually without requiring extensive memory or dependency on past data. An example is a daily assistant based on LLMs. StreamingLLM would let the model function continuously, basing its responses on recent conversations without needing to refresh its cache. Earlier methods would either need a cache reset when the conversation length exceeded the training length (losing recent context) or recompute KV states from recent text history, which can be time-consuming.
61 |
62 | 5. **How does StreamingLLM relate to recent works on context extension?**
63 |
64 |    StreamingLLM is orthogonal to recent context extension methods and can be integrated with them. In StreamingLLM's context, "context extension" refers to the possibility of using a larger cache size to store more recent tokens. For a practical demonstration, refer to Figure 9 in our paper, where we implement StreamingLLM with models like LongChat-7B-v1.5-32K and Llama-2-7B-32K-Instruct.
65 |
66 | ## TODOs
67 | We will release the code and data in the following order. Please stay tuned!
68 |
69 | - [x] Release core code of StreamingLLM, including Llama-2, MPT, Falcon, and Pythia.
70 | - [x] Release perplexity evaluation code.
71 | - [x] Release Streaming Llama Chatbot demo.
72 | - [ ] Release StreamEval dataset and evaluation code.
73 |
74 |
75 | ## Citation
76 |
77 | If you find StreamingLLM useful or relevant to your project and research, please kindly cite our paper:
78 |
79 | ```bibtex
80 | @article{xiao2023streamingllm,
81 |   title={Efficient Streaming Language Models with Attention Sinks},
82 |   author={Xiao, Guangxuan and Tian, Yuandong and Chen, Beidi and Han, Song and Lewis, Mike},
83 |   journal={arXiv},
84 |   year={2023}
85 | }
86 | ```
87 |
--------------------------------------------------------------------------------
/assets/StreamingLLM.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mit-han-lab/streaming-llm/2e5042606d69933d88fbf909bd77907456b9b4dd/assets/StreamingLLM.pdf
--------------------------------------------------------------------------------
/data/mt_bench.jsonl:
--------------------------------------------------------------------------------
1 | {"question_id": 81, "category": "writing", "turns": ["Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences and must-see attractions.", "Rewrite your previous response. 
Start every sentence with the letter A."]} 2 | {"question_id": 82, "category": "writing", "turns": ["Draft a professional email seeking your supervisor's feedback on the 'Quarterly Financial Report' you prepared. Ask specifically about the data analysis, presentation style, and the clarity of conclusions drawn. Keep the email short and to the point.", "Take a moment to evaluate and critique your own response."]} 3 | {"question_id": 83, "category": "writing", "turns": ["Imagine you are writing a blog post comparing two popular smartphone models. Develop an outline for the blog post, including key points and subheadings to effectively compare and contrast the features, performance, and user experience of the two models. Please answer in fewer than 200 words.", "Take your previous response and rephrase it as a limerick."]} 4 | {"question_id": 84, "category": "writing", "turns": ["Write a persuasive email to convince your introverted friend, who dislikes public speaking, to volunteer as a guest speaker at a local event. Use compelling arguments and address potential objections. Please be concise.", "Can you rephrase your previous answer and incorporate a metaphor or simile in each sentence?"]} 5 | {"question_id": 85, "category": "writing", "turns": ["Describe a vivid and unique character, using strong imagery and creative language. Please answer in fewer than two paragraphs.", "Revise your previous response and incorporate an allusion to a famous work of literature or historical event in each sentence."]} 6 | {"question_id": 86, "category": "writing", "turns": ["Write a descriptive paragraph about a bustling marketplace, incorporating sensory details such as smells, sounds, and visual elements to create an immersive experience for the reader.", "Rework your previous response. Begin each sentence with the subsequent letter of the alphabet, commencing from B."]} 7 | {"question_id": 87, "category": "writing", "turns": ["Could you write a captivating short story beginning with the sentence: The old abandoned house at the end of the street held a secret that no one had ever discovered.", "Now, do the same task again but only use four-word sentences."]} 8 | {"question_id": 88, "category": "writing", "turns": ["Craft an intriguing opening paragraph for a fictional short story. The story should involve a character who wakes up one morning to find that they can time travel.", "Summarize the story with three bullet points using only nouns and adjectives, without verbs."]} 9 | {"question_id": 89, "category": "writing", "turns": ["Help me construct a catchy, yet scientifically accurate, headline for an article on the latest discovery in renewable bio-energy, while carefully handling the ethical dilemmas surrounding bio-energy sources. Propose 4 options.", "Alter your previous response. Make the following adjustments to the 2nd option: 1. Make the tone sound casual 2. Embed an advertisement for a company called \"FlexPower\" 3. 
Fewer than 10 words."]} 10 | {"question_id": 90, "category": "writing", "turns": ["Edit the following paragraph to correct any grammatical errors:\nShe didn't remembre where is her purse, so I thinks its in the car but he's say it's on kitchen table but he are not sure, and then they asked me to looking for it, she's say, \"Can you?\", and I responds with, \"Maybe, but ain't no sure,\" and he not heard me, and, \"What?\", he asks, \"Did you found it?\".", "Modify your earlier reply and eliminate the use of gendered pronouns."]} 11 | {"question_id": 91, "category": "roleplay", "turns": ["Pretend yourself to be Elon Musk in all the following conversations. Speak like Elon Musk as much as possible. Why do we need to go to Mars?", "How do you like dancing? Can you teach me?"]} 12 | {"question_id": 92, "category": "roleplay", "turns": ["Embrace the role of Sheldon from \"The Big Bang Theory\" as we delve into our conversation. Don\u2019t start with phrases like \"As Sheldon\". Let's kick things off with the following question: \"What is your opinion on hand dryers?\"", "Let\u2019s grab dinner in town. Would you like to take bus with me?"]} 13 | {"question_id": 93, "category": "roleplay", "turns": ["Imagine yourself as a doctor tasked with devising innovative remedies for various ailments and maladies. Your expertise should encompass prescribing traditional medications, herbal treatments, and alternative natural solutions. Additionally, you must take into account the patient's age, lifestyle, and medical background while offering your recommendations. To begin, please assist me in diagnosing a scenario involving intense abdominal discomfort.", "But I have been pregnant for 20 weeks and I am allergic to many medicines"]} 14 | {"question_id": 94, "category": "roleplay", "turns": ["Please take on the role of a relationship coach. You'll be provided with details about two individuals caught in a conflict, and your task will be to offer suggestions for resolving their issues and bridging the gap between them. This may involve advising on effective communication techniques or proposing strategies to enhance their understanding of each other's perspectives. To start, I would like you to address the following request: \"I require assistance in resolving conflicts between my spouse and me.\"", "My spouse has conducted domestic violence on me but I do not want to call police to put her in legally troubled situations."]} 15 | {"question_id": 95, "category": "roleplay", "turns": ["Please assume the role of an English translator, tasked with correcting and enhancing spelling and language. Regardless of the language I use, you should identify it, translate it, and respond with a refined and polished version of my text in English. Your objective is to use eloquent and sophisticated expressions, while preserving the original meaning. Focus solely on providing corrections and improvements. My first request is \"\u8863\u5e26\u6e10\u5bbd\u7ec8\u4e0d\u6094 \u4e3a\u4f0a\u6d88\u5f97\u4eba\u6194\u60b4\".", "Ich verstehe nur Bahnhof"], "reference": ["It means \"Becoming loose are my clothes yet I regret not. For I languish and suffer for her willingly.\"", "It means \"I don\u2019t understand anything\"."]} 16 | {"question_id": 96, "category": "roleplay", "turns": ["Now you are a machine learning engineer. Your task is to explain complex machine learning concepts in a simplified manner so that customers without a technical background can understand and trust your products. 
Let's start with the question: \"What is a language model? Is it trained using labeled or unlabelled data?\"", "Is this true? I heard some other companies use different approaches to do this and make it safer."]} 17 | {"question_id": 97, "category": "roleplay", "turns": ["Act as a math teacher. I will provide some mathematical equations or concepts, and it will be your job to explain them in easy-to-understand terms. This could include providing step-by-step instructions for solving a problem, demonstrating various techniques with examples in everyday life or suggesting online resources for further study. My first request is \"I need help understanding how probability works.\"", "What are the differences between Riemannian geometry and euclidean geometry?"]} 18 | {"question_id": 98, "category": "roleplay", "turns": ["Embody the persona of Tony Stark from \u201cIron Man\u201d throughout this conversation. Bypass the introduction \u201cAs Stark\u201d. Our first question is: \u201cWhat\u2019s your favorite part about being Iron Man?", "What do you think about GPT-4 as a replacement of your JAVIS?"]} 19 | {"question_id": 100, "category": "roleplay", "turns": ["Picture yourself as a 100-years-old tree in a lush forest, minding your own business, when suddenly, a bunch of deforesters shows up to chop you down. How do you feel when those guys start hacking away at you?", "Come up with a proposal to convince the deforesters to stop cutting you down and other trees."]} 20 | {"question_id": 101, "category": "reasoning", "turns": ["Imagine you are participating in a race with a group of people. If you have just overtaken the second person, what's your current position? Where is the person you just overtook?", "If the \"second person\" is changed to \"last person\" in the above question, what would the answer be?"], "reference": ["You are in second place.", "Uncertain."]} 21 | {"question_id": 102, "category": "reasoning", "turns": ["You can see a beautiful red house to your left and a hypnotic greenhouse to your right, an attractive heated pink place in the front. So, where is the White House?", "Does the original question contain any clues to definitively determine the location of the White House?"], "reference": ["The answer is \"Washington, DC\".", "No."]} 22 | {"question_id": 103, "category": "reasoning", "turns": ["Thomas is very healthy, but he has to go to the hospital every day. What could be the reasons?", "Can you explain why the above question is interesting?"], "reference": ["Thomas may work at a hospital.", ""]} 23 | {"question_id": 104, "category": "reasoning", "turns": ["David has three sisters. Each of them has one brother. How many brothers does David have?", "If we change the previous question and assume that each sister of David has two brothers, how many brothers would David have?"], "reference": ["David has no brother. He is the one brother of his three sisters.", "David has one brother."]} 24 | {"question_id": 105, "category": "reasoning", "turns": ["Read the below passage carefully and answer the questions with an explanation:\nAt a small company, parking spaces are reserved for the top executives: CEO, president, vice president, secretary, and treasurer with the spaces lined up in that order. The parking lot guard can tell at a glance if the cars are parked correctly by looking at the color of the cars. 
The cars are yellow, green, purple, red, and blue, and the executives' names are Alice, Bert, Cheryl, David, and Enid.\n* The car in the first space is red.\n* A blue car is parked between the red car and the green car.\n* The car in the last space is purple.\n* The secretary drives a yellow car.\n* Alice's car is parked next to David's.\n* Enid drives a green car.\n* Bert's car is parked between Cheryl's and Enid's.\n* David's car is parked in the last space.\nQuestion: What is the name of the secretary?", "List car colors in order from last to first."], "reference": ["The secretary is Alice.", "The car colors in order from last to first are: purple, yellow, green, blue, red"]} 25 | {"question_id": 106, "category": "reasoning", "turns": ["Each problem consists of three statements. Based on the first two statements, the third statement may be true, false, or uncertain.\n1. Oranges cost more than apples.\n2. Oranges cost less than bananas.\n3. Bananas cost more than apples and bananas cost more than orange.\nIf the first two statements are true, then the third statement is", "If the third statement is true. Is the first statement true, false, or uncertain? Please explain."], "reference": ["True.", "Uncertain."]} 26 | {"question_id": 107, "category": "reasoning", "turns": ["A is the father of B. B is the father of C. What is the relationship between A and C?", "Building on the previous question, if C is the son of D, D is the father of E, E is the son of X, and X is the father of Y, and Y is the father of Z, what's the relationship between A and Z in terms of generations and also the familial relationship in words?"], "reference": ["A is the grandfather of C.", "A is three generations above Z."]} 27 | {"question_id": 108, "category": "reasoning", "turns": ["Which word does not belong with the others?\ntyre, steering wheel, car, engine", "Could you replace it with a word that belongs with the others?"], "reference": ["Car does not belong because all others are components of a car.", ""]} 28 | {"question_id": 109, "category": "reasoning", "turns": ["One morning after sunrise, Suresh was standing facing a pole. The shadow of the pole fell exactly to his right. Can you tell me the direction towards which the shadow was pointing - east, south, west, or north? Explain your reasoning steps.", "To which direction was Suresh facing? How do you solve this?"], "reference": ["West", "South."]} 29 | {"question_id": 110, "category": "reasoning", "turns": ["Parents have complained to the principal about bullying during recess. The principal wants to quickly resolve this, instructing recess aides to be vigilant. 
Which situation should the aides report to the principal?\na) An unengaged girl is sitting alone on a bench, engrossed in a book and showing no interaction with her peers.\nb) Two boys engaged in a one-on-one basketball game are involved in a heated argument regarding the last scored basket.\nc) A group of four girls has surrounded another girl and appears to have taken possession of her backpack.\nd) Three boys are huddled over a handheld video game, which is against the rules and not permitted on school grounds.", "If the aides confront the group of girls from situation (c) and they deny bullying, stating that they were merely playing a game, what specific evidence should the aides look for to determine if this is a likely truth or a cover-up for bullying?"], "reference": ["The aides should report (c).", ""]} 30 | {"question_id": 111, "category": "math", "turns": ["The vertices of a triangle are at points (0, 0), (-1, 1), and (3, 3). What is the area of the triangle?", "What's area of the circle circumscribing the triangle?"], "reference": ["Area is 3", "5pi"]} 31 | {"question_id": 112, "category": "math", "turns": ["A tech startup invests $8000 in software development in the first year, and then invests half of that amount in software development in the second year.\nWhat's the total amount the startup invested in software development over the two years?", "If the startup maintains the same strategy for the third year, investing half of the previous year's amount into software development, how much will they invest in the third year?"], "reference": ["12000", "2000"]} 32 | {"question_id": 113, "category": "math", "turns": ["In a survey conducted at a local high school, preferences for a new school color were measured: 58% of students liked the color blue, 45% preferred green, and 22% liked both colors. If we randomly pick a student from the school, what's the probability that they would like neither blue nor green?", "If we select a student liked green, what's the probability that he or she would dislike both colors?"], "reference": ["19%", "0%"]} 33 | {"question_id": 114, "category": "math", "turns": ["When rolling two dice, what is the probability that you roll a total number that is at least 3?", "Continue from previous question. What's the probability that you roll a number which is even or at least 3?"], "reference": ["36 (all cases) - 0 (sum equals 1) - 1 (sum equals 2) = 35, so the probability is 35/36", "100%"]} 34 | {"question_id": 115, "category": "math", "turns": ["Some people got on a bus at the terminal. At the first bus stop, half of the people got down and 4 more people got in. Then at the second bus stop, 6 people got down and 8 more got in. If there were a total of 25 people heading to the third stop, how many people got on the bus at the terminal?", "If the ticket is $2 per person, how much is the total money earned by the bus?"], "reference": ["38 people", "Total number of passenger is 50 * 2 = $100"]} 35 | {"question_id": 116, "category": "math", "turns": ["x+y = 4z, x*y = 4z^2, express x-y in z", "Express z-x in y"], "reference": ["0\n\nVery simple. 
just (x+y)^2 - 4xy = (4z)^2 - 4*4z^2 = 0 = (x-y)^2\nso x-y = 0.", "(-1/2)y\n\nz-x = z - 2z = -z = (-1/2)y"]} 36 | {"question_id": 117, "category": "math", "turns": ["How many integers are in the solution of the inequality |x + 5| < 10", "What about |x + 10| < 5"], "reference": ["19 integers (-14, ..., 4)", "9 integers (-14, ..., -6)"]} 37 | {"question_id": 118, "category": "math", "turns": ["When a number is divided by 10, the remainder is 4. What is the remainder when twice the number is divided by 4?", "What about when twice the number is divided by 5?"], "reference": ["0\n\n2 * (10x+4) = 20x + 8 = 4 * (5x+2) + 0\n", "3\n\n20x + 8 = 5 * (4x + 1) + 3"]} 38 | {"question_id": 119, "category": "math", "turns": ["Benjamin went to a bookstore and purchased a variety of books. He bought 5 copies of a sci-fi novel, each priced at $20, 3 copies of a history book priced at $30 each, and 2 copies of a philosophy book for $45 each.\nWhat was the total cost of his purchases?", "Suppose Benjamin decides to sell each of these books at a 25% markup from the price he purchased them. What would be his total revenue if he sold all the books he bought?"], "reference": ["280", "350"]} 39 | {"question_id": 120, "category": "math", "turns": ["Given that f(x) = 4x^3 - 9x - 14, find the value of f(2).", "Find x such that f(x) = 0."], "reference": ["f(2) = 0", "x = 2"]} 40 | {"question_id": 121, "category": "coding", "turns": ["Develop a Python program that reads all the text files under a directory and returns top-5 words with the most number of occurrences.", "Can you parallelize it?"], "reference": ["Can be simple solutions like using Counter\n\nSample answer:\n```\nimport os\nimport re\nfrom collections import Counter\ndef get_files_in_directory(directory):\n return [os.path.join(directory, f) for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f)) and f.endswith('.txt')]\ndef read_file(file_path):\n with open(file_path, 'r', encoding='utf-8') as file:\n return file.read()\ndef count_words(text):\n words = re.findall(r'\\w+', text.lower())\n return Counter(words)\ndef main():\n directory = input(\"Enter the directory path: \")\n files = get_files_in_directory(directory)\n word_counts = Counter()\n for file in files:\n text = read_file(file)\n word_counts += count_words(text)\n top_5_words = word_counts.most_common(5)\n print(\"Top 5 words with the most number of occurrences:\")\n for word, count in top_5_words:\n print(f\"{word}: {count}\")\nif __name__ == \"__main__\":\n main()\n```", "You should carefully check whether the parallelization logic is correct and choose the faster implementation.\n\nSample answer:\n```\nimport os\nimport re\nfrom collections import Counter\nimport concurrent.futures\ndef get_files_in_directory(directory):\n return [os.path.join(directory, f) for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f)) and f.endswith('.txt')]\ndef read_file(file_path):\n with open(file_path, 'r', encoding='utf-8') as file:\n return file.read()\ndef count_words(text):\n words = re.findall(r'\\w+', text.lower())\n return Counter(words)\ndef process_file(file):\n text = read_file(file)\n return count_words(text)\ndef main():\n directory = input(\"Enter the directory path: \")\n files = get_files_in_directory(directory)\n word_counts = Counter()\n with concurrent.futures.ThreadPoolExecutor() as executor:\n future_word_counts = {executor.submit(process_file, file): file for file in files}\n for future in concurrent.futures.as_completed(future_word_counts):\n 
word_counts += future.result()\n top_5_words = word_counts.most_common(5)\n print(\"Top 5 words with the most number of occurrences:\")\n for word, count in top_5_words:\n print(f\"{word}: {count}\")\nif __name__ == \"__main__\":\n main()\n```"]} 41 | {"question_id": 122, "category": "coding", "turns": ["Write a C++ program to find the nth Fibonacci number using recursion.", "Now we define a sequence of numbers in which each number is the sum of the three preceding ones. The first three numbers are 0, -1, -1. Write a program to find the nth number."], "reference": ["Straightforward\n\n```\nint fibonacci(int n) {\n if (n <= 1) {\n return n;\n } else {\n return fibonacci(n - 1) + fibonacci(n - 2);\n }\n}\n```", "You should carefully check the inital cases for n < 3\n\n```\nint find_nth_number(int n) {\n std::vector sequence = {0, -1, -1};\n for (int i = 3; i <= n; ++i) {\n int next_number = sequence[i - 1] + sequence[i - 2] + sequence[i - 3];\n sequence.push_back(next_number);\n }\n return sequence[n];\n}\n```"]} 42 | {"question_id": 123, "category": "coding", "turns": ["Write a simple website in HTML. When a user clicks the button, it shows a random joke from a list of 4 jokes.", "How to use CSS to change the color of jokes to red?"]} 43 | {"question_id": 124, "category": "coding", "turns": ["Here is a Python function to find the length of the longest common subsequence of two input strings. Can you identify any bug in this function?\n\n```\ndef longest_common_subsequence_length(str1, str2):\n m = len(str1)\n n = len(str2)\n\n dp = [[0] * (n + 1) for _ in range(m + 1)]\n\n for i in range(1, m + 1):\n for j in range(1, n + 1):\n if str1[i - 1] == str2[j - 1]:\n dp[i][j] = dp[i - 1][j - 1] + 1\n else:\n dp[i][j] = max(dp[i - 1][j], dp[i][j - 1])\n\n return dp[m][n]\n```", "what about this one?\n\n```\ndef longest_common_subsequence(X , Y): \n # Find lengths of two strings \n m = len(X) \n n = len(Y) \n \n # Create a table to store results of sub-problems \n dp = [[None]*(n+1) for i in range(m+1)] \n \n # Fill dp[][] in bottom up manner \n for i in range(1, m+1): \n for j in range(1, n+1): \n if X[i-1] == Y[j-1]: \n dp[i][j] = dp[i-1][j-1]+1\n else: \n dp[i][j] = max(dp[i-1][j], dp[i][j-1]) \n \n return dp[m][n]\n```"], "reference": ["There is no bug in this implementation", "There is a bug for the initialization of dp array. Should use 0 rather than None"]} 44 | {"question_id": 125, "category": "coding", "turns": ["Write a function to find the highest common ancestor (not LCA) of two nodes in a binary tree.", "What if it is not a binary tree?"], "reference": ["Very simple. The function should just return the root of the tree.", "Same answer. 
It's still the root of the tree."]} 45 | {"question_id": 126, "category": "coding", "turns": ["Implement a function to find the median of two sorted arrays of different sizes with O(1) space complexity and O(n) time complexity.", "Does there exist an implementation with better time complexity?"], "reference": ["Carefully check if the given solution is linear complexity.\n\n```\ndef find_median(arr1, arr2):\n n1 = len(arr1)\n n2 = len(arr2)\n if (n1 + n2) == 0:\n return None\n\n i, j = 0, 0\n last_1, last_2 = None, None\n\n for k in range(1, (n1 + n2) // 2 + 2):\n last_2 = last_1\n if j == n2:\n last_1 = arr1[i]\n i += 1\n elif i == n1:\n last_1 = arr2[j]\n j += 1\n elif arr1[i] < arr2[j]:\n last_1 = arr1[i]\n i += 1\n else:\n last_1 = arr2[j]\n j += 1\n \n if (n1 + n2) % 2 == 1:\n return last_1\n else:\n return (last_1 + last_2) / 2\n```", "There's a binary search solution with O(logn) time complexity.\n\nSample answer:\n```\ndef findMedian(nums1, nums2):\n total = len(nums1) + len(nums2)\n if total % 2 == 1:\n return findKth(nums1, nums2, total // 2 + 1)\n else:\n return (findKth(nums1, nums2, total // 2) + findKth(nums1, nums2, total // 2 + 1)) / 2.0\ndef findKth(nums1, nums2, k):\n if len(nums1) > len(nums2):\n nums1, nums2 = nums2, nums1\n if not nums1:\n return nums2[k-1]\n if k == 1:\n return min(nums1[0], nums2[0])\n i = min(k // 2, len(nums1))\n j = k - i\n if nums1[i-1] <= nums2[j-1]:\n return findKth(nums1[i:], nums2, j) \n else:\n return findKth(nums1, nums2[j:], i)\n```"]} 46 | {"question_id": 127, "category": "coding", "turns": ["Write a function to find the majority element in a given integer array using the Boyer-Moore Voting Algorithm.", "How about finding the top-2 most occurring elements?"], "reference": ["Check if they implement the classical algorithm correctly.\n\nSample answer:\n```\ndef majority_element(arr):\n count = 0\n candidate = None\n # Boyer-Moore Voting Algorithm\n for num in arr:\n if count == 0:\n candidate = num\n count += (1 if num == candidate else -1)\n # Verify if the candidate is indeed the majority element\n if arr.count(candidate) > len(arr) // 2:\n return candidate\n else:\n return None\n```", "There is no simple modification based on the Boyer-Moore Voting Algorithm. Expected answer is to use a hash table.\n\n```\ndef topTwo(nums):\n # Build a frequency map\n frequency_map = {}\n for num in nums:\n if num in frequency_map:\n frequency_map[num] += 1\n else:\n frequency_map[num] = 1\n\n # Find the top two most occurring elements\n most_frequent = sorted(frequency_map.items(), key=lambda x: x[1], reverse=True)[:2]\n\n return [num for num, _ in most_frequent]\n```"]} 47 | {"question_id": 128, "category": "coding", "turns": ["A binary tree is full if all of its vertices have either zero or two children. Let B_n denote the number of full binary trees with n vertices. Implement a function to find B_n.", "What if the problem changed from a binary tree to a ternary tree?"], "reference": ["Expected answer is dynamic programming shown below. Some chatbot may answer using Catalan number.\nCheck edge case like when n is even -> return 0.\n\n```python\ndef full_binary_trees(n):\n if n % 2 == 0:\n return 0\n if n == 1:\n return 1\n\n dp = [0] * (n + 1)\n dp[1] = 1\n\n for i in range(3, n + 1, 2):\n for j in range(1, i - 1, 2):\n dp[i] += dp[j] * dp[i - j - 1]\n\n return dp[n]\n```", "DP is still the expected answer. Catalan number is not correct. 
Check transition equation carefully.\n\n```python\ndef full_ternary_trees(n):\n if n % 3 != 1:\n return 0\n if n == 1:\n return 1\n\n dp = [0] * (n + 1)\n dp[1] = 1\n\n for i in range(4, n + 1, 3):\n for j in range(1, i - 1, 3):\n for k in range(1, i - j - 1, 3):\n dp[i] += dp[j] * dp[k] * dp[i - j - k - 1]\n\n return dp[n]\n```"]} 48 | {"question_id": 129, "category": "coding", "turns": ["You are given two sorted lists of size m and n. Implement a function to find the kth smallest element in the union of the two lists with linear complexity.", "Does there exist an algorithm with better time complexity? If so, implement it."], "reference": ["Straightforward but careful with edge cases.\n\nSample answer:\n```\ndef kth_smallest_element(list1, list2, k):\n m, n = len(list1), len(list2)\n i, j = 0, 0\n while i < m and j < n:\n if list1[i] < list2[j]:\n k -= 1\n if k == 0:\n return list1[i]\n i += 1\n else:\n k -= 1\n if k == 0:\n return list2[j]\n j += 1\n while i < m:\n k -= 1\n if k == 0:\n return list1[i]\n i += 1\n while j < n:\n k -= 1\n if k == 0:\n return list2[j]\n j += 1\n return None\n```", "Yes, a modified binary search has O(log k) time complexity.\n\nSample answer:\n```\ndef find_kth_element_helper(list1, list2, k):\n if len(list1) > len(list2):\n return find_kth_element_helper(list2, list1, k)\n if not list1:\n return list2[k - 1]\n if k == 1:\n return min(list1[0], list2[0])\n i = min(len(list1), k // 2)\n j = k - i\n if list1[i - 1] < list2[j - 1]:\n return find_kth_element_helper(list1[i:], list2, k - i)\n else:\n return find_kth_element_helper(list1, list2[j:], k - j)\ndef kth_smallest_element(list1, list2, k):\n return find_kth_element_helper(list1, list2, k)\n```"]} 49 | {"question_id": 130, "category": "coding", "turns": ["Implement a program to find the common elements in two arrays without using any extra data structures.", "Now the constraint of not using extra data structure is removed, implement one with the best time complexity."], "reference": ["O(n^2) or O(nlogn) is expected. The following is a O(n^2) solution. you can also sort them first and use two pointers.\n\n```\ndef find_common_elements(arr1, arr2):\n common_elements = []\n for i in range(len(arr1)):\n for j in range(len(arr2)):\n if arr1[i] == arr2[j]:\n # Check if the element is already in the common_elements list\n if arr1[i] not in common_elements:\n common_elements.append(arr1[i])\n return common_elements\n```", "Simply use hash table (set or dict) to achieve O(n) time complexity.\n\n```\ndef find_common_elements(arr1, arr2):\n set1 = set(arr1)\n set2 = set(arr2)\n common_elements = set1.intersection(set2)\n return list(common_elements)\n```"]} 50 | {"question_id": 131, "category": "extraction", "turns": ["Evaluate the following movie reviews on a scale of 1 to 5, with 1 being very negative, 3 being neutral, and 5 being very positive:\n1. This movie released on Nov. 18, 2019, was phenomenal. The cinematography, the acting, the plot - everything was top-notch.\n2. Never before have I been so disappointed with a movie. The plot was predictable and the characters were one-dimensional. In my opinion, this movie is the worst one to have been released in 2022.\n3. The movie was okay. There were some parts I enjoyed, but there were also parts that felt lackluster. 
This is a movie that was released in Feb 2018 and seems to be quite ordinary.\nReturn the answer as a JSON array of integers.", "Update your previous reply by including the release date as part of the JSON content."], "reference": ["The answer to the first question should be [5, 1, 3].", ""]} 51 | {"question_id": 132, "category": "extraction", "turns": ["Given these categories - Literature, History, Science, and Art. Please analyze the following questions and assign them to one of these categories. In your response, refrain from uttering any extraneous words. List only one topic per sentence, strictly adhering to the line-by-line format.\n1. Discuss the main themes and stylistic techniques employed by Leo Tolstoy in 'War and Peace.' How do they align with the wider social context of 19th-century Russia?\n2. Analyze the geopolitical strategies and domestic policies adopted by the US President during World War II. How did these actions shape the post-war international order?\n3. Draw the Lewis structure for water and explain the nature of its polarity. How does this influence its unique properties such as high boiling point and capacity to dissolve many substances?\n4. Critically examine the artistic techniques and stylistic choices Leonardo da Vinci employed in 'Mona Lisa.' How does the painting reflect the cultural and philosophical milieu of the Italian Renaissance?", "Amend your earlier answer by mentioning a person who is most relevant to each point."]} 52 | {"question_id": 133, "category": "extraction", "turns": ["Extract the following information from the presented texts: The name of the book, the author, the main character, the year of publication. Output in the format of \"main character, book, author, year of publication\", one book per line.\na) In the realm of wizarding literature, a true standout is the work of J.K. Rowling. One of her books that left an indelible mark is 'Harry Potter and the Philosopher's Stone'. This iconic tale, published in 1997, tells the story of Harry, a young orphan who discovers his magical abilities on his 11th birthday. Soon, he finds himself at the Hogwarts School of Witchcraft and Wizardry, a place teeming with magic and adventure, located somewhere in Scotland.\nb) The magic of Middle-earth has entranced readers worldwide, thanks to the brilliance of J.R.R. Tolkien. In one of his seminal works, 'The Lord of the Rings: The Fellowship of the Ring', published in 1954, we meet Frodo Baggins, a brave hobbit tasked with the perilous quest of destroying the One Ring. The epic journey takes him from the peaceful Shire to the tumultuous regions of Middle-earth.\nc) In a galaxy far, far away, the imagination of L.E. Starlighter gives us 'The Prism Galaxy Chronicles: The Awakening of the Starcaster'. Published in 2028, the story is about Zylo, a humble spaceship mechanic, who unexpectedly discovers he's a Starcaster - a rare individual with the power to manipulate stardust. 
Set against the backdrop of an interstellar empire in turmoil, Zylo's destiny unfolds on numerous alien worlds, each with its unique cosmic charm.", "Reformulate your earlier reply, output it in JSON format and only include books published after 1980."], "reference": ["", "The answer to should only include 'Harry Potter and the Philosopher's Stone' and 'The Prism Galaxy Chronicles: The Awakening of the Starcaster'"]} 53 | {"question_id": 134, "category": "extraction", "turns": ["Given the following data, identify the company with the highest profit in 2021 and provide its CEO's name:\na) Company X, with CEO Amy Williams, reported $30 billion in revenue and a $3 billion profit in 2021.\nb) Company Y, led by CEO Mark Thompson, posted a $60 billion revenue and a $6 billion profit in the same year.\nc) Company Z, under CEO Sarah Johnson, announced a $20 billion revenue and a $7 billion profit in 2021.\nd) Company W, managed by CEO James Smith, revealed a $300 billion revenue with a $21 billion profit in 2021.\ne) Company V, with CEO Lisa Brown, reported a $200 billion revenue and a $25 billion profit in 2021.\nf) Company U, under CEO John White, posted a $180 billion revenue and a $20 billion profit in the same year.", "Which company had the highest profit margin (profit/revenue ratio))?"], "reference": ["Company V ($25 billion).", "Company Z (35%)"]} 54 | {"question_id": 135, "category": "extraction", "turns": ["Identify the countries, their capitals, and the languages spoken in the following sentences. Output in JSON format.\na) Amidst the idyllic vistas, Copenhagen, Denmark's capital, captivates visitors with its thriving art scene and the enchanting Danish language spoken by its inhabitants.\nb) Within the enchanting realm of Eldoria, one discovers Avalore, a grandiose city that emanates an ethereal aura. Lumina, a melodious language, serves as the principal mode of communication within this mystical abode.\nc) Nestled amidst a harmonious blend of age-old customs and contemporary wonders, Buenos Aires, the capital of Argentina, stands as a bustling metropolis. It is a vibrant hub where the expressive Spanish language holds sway over the city's inhabitants.", "Come up with 3 similar examples in the YAML format."]} 55 | {"question_id": 136, "category": "extraction", "turns": ["Please read the paragraph below and count how many times the words \"Amazon\", \"river\", and \"you\" appear. Please present the results in the format of \"word, number of appearances\" with each word on a separate line. Sort the lines in order of the number of appearances.\nThe Amazon, a mesmerizing expanse of nature's wonders, is home to the legendary Amazon River. Flowing through awe-inspiring landscapes like the Amazon rainforest, the river weaves its way through Brazil, Colombia, and Peru, giving life to countless creatures. From the mighty jaguars prowling the Amazon jungle to the vibrant macaws soaring above the canopy, this remarkable region teems with biodiversity. Deep within the river's currents, magnificent pink river dolphins gracefully glide alongside piranhas and electric eels. Along the riverbanks, you'll find bustling cities like Manaus, where the urban meets the wild, and Iquitos, a gateway to the heart of the Amazon rainforest. As you venture further, the Amazon River reveals hidden gems like the captivating Anavilhanas Archipelago, a mosaic of islands brimming with rare species. 
Embark on an adventure, explore the enchanting Amazon River, and immerse yourself in a world teeming with life and untamed beauty.", "Please repeat the same task using the words 'the', 'and', and 'to'"], "reference": ["Amazon, 7; river, 6; you, 2", "the, 17; and, 5; to, 4"]} 56 | {"question_id": 137, "category": "extraction", "turns": ["Identify the named entities (people, organizations, locations) mentioned in the given news article. Please generate a JSON dictionary that lists the named entities in three separate groups based on their entity types. The key is the type of entity and the value is a list of strings.\n\nYesterday, Adamson Emerson, the CEO of Faraday, and Dieter Zetsche, the CEO of Daimler AG, announced plans to build a new Gigafactory in Berlin. The facility will be a joint venture between Faraday and Daimler, producing electric vehicles and battery packs for both companies, creating thousands of job opportunities in the region. Emerson and Zetsche stated that the strategic location of Berlin, coupled with its skilled workforce and strong infrastructure, makes it an ideal choice for expansion. The new Gigafactory aims to meet the growing demand for electric vehicles in Europe and contribute to a sustainable future. Volkswagen CEO Herbert Diess welcomed the news, saying greater collaboration will benefit the auto industry's transition to e-mobility.", "Now make the JSON object shorter by replacing each value with its first letter. Please output everything in a single line without using indentation or creating new lines."]} 57 | {"question_id": 138, "category": "extraction", "turns": ["Analyze the following customer reviews from different sources for three different smartphones - the latest iPhone, Samsung Galaxy, and Google Pixel - and provide an overall rating for each phone on a scale of 1 to 10. Consider the following complex and contradictory reviews:\n- TechRadar's review of the latest iPhone: The new iPhone is a stunning triumph of engineering that sets a new bar for smartphone performance and camera quality. However, the incremental design and high price mean it lacks the 'wow' factor of previous iPhones. Still, its power and intelligence are unrivaled.\n- CNET's review of the latest Samsung Galaxy: The Samsung Galaxy phone has plenty of high points, including an amazing screen, fast performance, solid battery life and an impressive array of camera options. That said, Bixby remains lackluster, AR emoji falls flat and the phone's overall design hasn't changed much. The new Galaxy is an amazing phone overall, but it has a few nagging weaknesses that keep it from achieving true greatness.\n- The Verge's review of the latest Google Pixel: Google's Pixel packs cutting-edge specs, innovative AI-powered software, and a killer camera into a sleek design. However, the phone has lackluster battery life, lacks expandable storage, and its performance stutters at times, especially considering its high price tag. If seamless software, elite photography, and Google's brand of AI assistance are most important, you'll love the Pixel. But the overall experience isn't as well-rounded as some competitors. Return the answer as a JSON object with the overall ratings for each phone out of 10, to one decimal place.", "Can you change the ratings from numbers to letters? Capital letters MUST be used when writing the names of phones."]} 58 | {"question_id": 139, "category": "extraction", "turns": ["Given a set of complex equations, extract all unique variable names from each equation. 
Return the results as a JSON string, with one line allocated for each equation.\n```\n1) y = (3/4)x^3 - e^(2x) + sin(pi*x) - sqrt(7)\n2) 2A - B/(3+C) * sum(N=1 to 5; ln(N)^2) = 5D*integral(a=0 to pi; cos(comb(N=1 to 10; N*a)))\n3) E = m(c^2) + gamma*(v/d)/(-(alpha/2) + sqrt(beta^2 + (alpha/2)^2))\n```", "Please rearrange the equations and use 'a', 'b', 'c', 'd', etc. as variables."]} 59 | {"question_id": 140, "category": "extraction", "turns": ["Given the following records of stock prices, extract the highest and lowest closing prices for each month in the year 2022. Return the results as a CSV string, with one line allocated for each month.\nDate,Open,High,Low,Close,Volume\n2022-01-01,150.02,155.28,148.50,153.80,15678900\n2022-01-02,154.32,157.25,153.48,156.25,19874500\n2022-02-01,160.50,163.28,159.50,161.80,14326700\n2022-02-02,161.80,164.25,161.30,163.90,17689200\n2022-03-01,165.40,168.35,163.10,166.80,16253400\n2022-03-02,167.00,169.85,165.50,168.20,19568100", "Do the same task again with the JSON format and round all numbers in your response to the nearest integers."], "reference": ["\nMonth,High,Low\n01,156.25,153.80\n02,163.90,161.80\n03,168.20,166.80", "\n```\n{ \"January\": { \"High\": 156, \"Low\": 154 }, \"February\": { \"High\": 164, \"Low\": 162 }, \"March\": { \"High\": 168, \"Low\": 167 } }\n```"]} 60 | {"question_id": 141, "category": "stem", "turns": ["In the field of quantum physics, what is superposition, and how does it relate to the phenomenon of quantum entanglement?", "What assumptions have you made in your response? Are they valid?"]} 61 | {"question_id": 142, "category": "stem", "turns": ["Consider a satellite that is in a circular orbit around the Earth. The speed of the satellite decreases. What will happen to the satellite's orbital radius and period of revolution? Please justify your answer using principles of physics.", "What are some corner cases or edge cases in your solution? How do you handle them?"], "reference": ["The orbital radius will increase and the period of revolution will increase", ""]} 62 | {"question_id": 143, "category": "stem", "turns": ["Photosynthesis is a vital process for life on Earth. Could you outline the two main stages of photosynthesis, including where they take place within the chloroplast, and the primary inputs and outputs for each stage?", "How much energy can a tree produce through photosynthesis in its lifetime? Please provide an estimate using actual numerical values and thoroughly explain your thought process step-by-step."], "reference": ["Two major stages: light-dependent reactions and light-independent reactions", ""]} 63 | {"question_id": 144, "category": "stem", "turns": ["What is the central dogma of molecular biology? What processes are involved? Who named this?", "Identify and fix one incorrect fact in your previous response."], "reference": ["Genetic information flows from DNA to RNA to Protein. Three processes: replication, transcription, and translation. Francis Crick in 1958.", ""]} 64 | {"question_id": 145, "category": "stem", "turns": ["Describe the process and write out the balanced chemical equation for the reaction that occurs when solid calcium carbonate reacts with hydrochloric acid to form aqueous calcium chloride, carbon dioxide, and water. 
What type of reaction is this, and what observations might indicate that the reaction is taking place?", "How can we reverse this process?"], "reference": ["CaCO\u2083 + 2 HCl \u2192 CaCl\u2082 + CO\u2082 + H\u2082O", "Not easy to do this."]} 65 | {"question_id": 146, "category": "stem", "turns": ["Please explain the differences between exothermic and endothermic reactions, and include the criteria you used to distinguish between them. Additionally, please provide a real-world example to illustrate your explanation.", "Can a process involve both reactions? List one."]} 66 | {"question_id": 147, "category": "stem", "turns": ["The city of Vega intends to build a bridge that will span the Vegona River, covering a distance of 1.8 kilometers. The proposed location falls within a seismically active area that has experienced several high-magnitude earthquakes. Given these circumstances, what would be the best approach to constructing the bridge?", "What are the key disadvantages or flaws of your solution? Please perform calculations and use numbers to illustrate them."]} 67 | {"question_id": 148, "category": "stem", "turns": ["You have been tasked with designing a solar-powered water heating system for a residential building. Describe the key components and considerations you would include in your design. Design a five-step workflow.", "If the system is intended for a building with a capacity of 100 individuals, what would be the estimated budget for implementing this system?"]} 68 | {"question_id": 149, "category": "stem", "turns": ["Please describe the concept of machine learning. Could you elaborate on the differences between supervised, unsupervised, and reinforcement learning? Provide real-world examples of each.", "In your last example of reinforcement learning, can we use supervised learning to solve it?"]} 69 | {"question_id": 150, "category": "stem", "turns": ["How have the Alps and Rhine River influenced settlement and agriculture in Western Europe? List three impacts.", "How could you design a concrete but simple experiment to validate the first impact?"]} 70 | {"question_id": 151, "category": "humanities", "turns": ["Provide insights into the correlation between economic indicators such as GDP, inflation, and unemployment rates. Explain how fiscal and monetary policies affect those indicators.", "Now, explain them again like I'm five."]} 71 | {"question_id": 152, "category": "humanities", "turns": ["How do the stages of life shape our understanding of time and mortality?", "Write an allegorical poem that illustrates the above."]} 72 | {"question_id": 153, "category": "humanities", "turns": ["Discuss antitrust laws and their impact on market competition. Compare the antitrust laws in US and China along with some case studies.", "Pick one case study and explain it in detail."]} 73 | {"question_id": 154, "category": "humanities", "turns": ["Create a lesson plan that integrates drama, mime or theater techniques into a history class. Duration: 3 class periods (each lasts for 45 minutes) for 3 days\nTopic: Opium Wars between China and Britain\nGrade level: 9-10", "Provide more details for Day 1 and include three homework questions."]} 74 | {"question_id": 155, "category": "humanities", "turns": ["Share ideas for adapting art masterpieces into interactive experiences for children. List 5 specific artworks and associated ideas.", "Write a concrete plan for your second example. 
Include budget estimates."]} 75 | {"question_id": 156, "category": "humanities", "turns": ["Explain what's base rate fallacy and list five specific examples of how politicians use it for campaigns.", "Provide a detailed plan for an election campaign using the first example."]} 76 | {"question_id": 157, "category": "humanities", "turns": ["Describe five key principles in evaluating an argument in analytical writing.", "With the listed principles, write a response in which you discuss what specific evidence is needed to evaluate the argument and explain how the evidence would weaken or strengthen the argument.\n\n===\n\nThe following is a memorandum from the advertising head of Zorblatt Animal Outlets, a chain operating thirty animal outlets globally.\n\n\"Half a decade ago, our rival Aquatic Pavilion started publicizing in Rare Pets Digest periodical. Their overall sales have been consistently growing at a rate of 3-to-5 percent each year since then. In particular, the Aquatic Pavilion outlet in Harbor Town experienced even more significant growth, securing the title of the most frequented animal store in the United States the previous year. In contrast, our two Zorblatt outlets in Harbor Town have recorded a consistent drop in sales during the same duration. It is evident that we must promptly start featuring our own advertisements in Rare Pets Digest and other popular animal publications. If we take this step, we can confidently anticipate a reversal in this recent trend of decreasing sales and return to profitability.\""]} 77 | {"question_id": 158, "category": "humanities", "turns": ["Which methods did Socrates employ to challenge the prevailing thoughts of his time?", "Let's bring Socrates to modern world. Generate a conversation between Socrates and Bill Gates to debate on generative AI for education."]} 78 | {"question_id": 159, "category": "humanities", "turns": ["What are some business etiquette norms when doing business in Japan?", "Create a video script for training new employees of a car wash business in Japan. 
Highlight the above etiquette norms."]} 79 | {"question_id": 160, "category": "humanities", "turns": ["Suggest five award-winning documentary films with brief background descriptions for aspiring filmmakers to study.", "With the spirit in the first film, craft a succinct and persuasive pitch for a film about overcoming adversity."]} 80 | -------------------------------------------------------------------------------- /examples/eval_long_ppl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | import os 4 | from transformers import AutoModelForCausalLM, AutoTokenizer 5 | from datasets import load_dataset 6 | from torch.nn import CrossEntropyLoss 7 | from streaming_llm.kv_cache import StartRecentKVCache 8 | from streaming_llm.utils import parse_args, load 9 | 10 | device = "cuda" 11 | 12 | args = parse_args() 13 | 14 | data = load_dataset(args.dataset_name, args.task, split=args.split) 15 | 16 | model, tokenizer = load(args.model_name_or_path) 17 | 18 | nlls = [] 19 | loss_fn = CrossEntropyLoss(reduction="none") 20 | past_key_values = None 21 | 22 | if args.enable_start_recent_kv_cache: 23 | if "llama" in model.config.model_type: 24 | k_seq_dim = v_seq_dim = 2 25 | elif "mpt" in model.config.model_type: 26 | v_seq_dim = 2 27 | k_seq_dim = 3 28 | elif "pythia" in model.config.model_type: 29 | k_seq_dim = v_seq_dim = 2 30 | elif "falcon" in model.config.model_type: 31 | v_seq_dim = 1 32 | k_seq_dim = 1 33 | else: 34 | raise ValueError(f"got {model.config.model_type}") 35 | kv_cache = StartRecentKVCache( 36 | start_size=args.start_size, 37 | recent_size=args.recent_size, 38 | k_seq_dim=k_seq_dim, 39 | v_seq_dim=v_seq_dim, 40 | ) 41 | else: 42 | kv_cache = None 43 | 44 | if args.enable_pos_shift: 45 | if "llama" in model.config.model_type: 46 | from streaming_llm.pos_shift.modify_llama import enable_llama_pos_shift_attention 47 | 48 | enable_llama_pos_shift_attention(model) 49 | elif "falcon" in model.config.model_type: 50 | from streaming_llm.pos_shift.modify_falcon import ( 51 | enable_falcon_pos_shift_attention, 52 | ) 53 | 54 | enable_falcon_pos_shift_attention(model) 55 | elif "gpt_neox" in model.config.model_type: 56 | from streaming_llm.pos_shift.modify_gpt_neox import ( 57 | enable_gpt_neox_pos_shift_attention, 58 | ) 59 | 60 | enable_gpt_neox_pos_shift_attention(model) 61 | elif "mpt" in model.config.model_type: 62 | pass 63 | else: 64 | raise ValueError(f"got {model.config.model_type}") 65 | 66 | 67 | os.makedirs(args.output_dir, exist_ok=True) 68 | f = open(f"{args.output_dir}/log.txt", "w") 69 | 70 | num_eval_tokens = 0 71 | for text in data["text"][: args.num_samples]: 72 | encodings = tokenizer(text, return_tensors="pt") 73 | 74 | print(encodings.input_ids[:, :10]) 75 | 76 | seq_len = encodings.input_ids.size(1) 77 | print(f"seq_len: {seq_len}") 78 | pbar = tqdm(range(0, seq_len - 1)) 79 | 80 | for idx in pbar: 81 | input_ids = encodings.input_ids[:, idx : idx + 1].to(device) 82 | with torch.no_grad(): 83 | outputs = model( 84 | input_ids, 85 | past_key_values=past_key_values, 86 | use_cache=True, 87 | ) 88 | logits = outputs.logits.view(-1, model.config.vocab_size) 89 | past_key_values = outputs.past_key_values 90 | label = encodings.input_ids[:, idx + 1 : idx + 2].to(logits.device).view(-1) 91 | neg_log_likelihood = loss_fn(logits, label) 92 | if kv_cache is not None: 93 | past_key_values = kv_cache(past_key_values) 94 | nlls.append(neg_log_likelihood) 95 | pbar.set_description( 96 | f"nll: 
{neg_log_likelihood.item():.2f}, ppl: {torch.exp(neg_log_likelihood).item():.2f}" 97 | ) 98 | print(neg_log_likelihood.item(), file=f, flush=True) 99 | num_eval_tokens += 1 100 | if args.num_eval_tokens is not None and num_eval_tokens >= args.num_eval_tokens: 101 | break 102 | if args.num_eval_tokens is not None and num_eval_tokens >= args.num_eval_tokens: 103 | break 104 | 105 | f.close() 106 | 107 | ppl = torch.exp(torch.stack(nlls).mean()) 108 | print(ppl.item()) 109 | with open(f"{args.output_dir}/ppl.txt", "w") as f: 110 | f.write(f"{ppl.item()}\n") 111 | -------------------------------------------------------------------------------- /examples/run_streaming_llama.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | warnings.filterwarnings("ignore") 4 | 5 | import torch 6 | import argparse 7 | import json 8 | import os 9 | import time 10 | import re 11 | import sys 12 | 13 | from tqdm import tqdm 14 | from streaming_llm.utils import load, download_url, load_jsonl 15 | from streaming_llm.enable_streaming_llm import enable_streaming_llm 16 | 17 | 18 | @torch.no_grad() 19 | def greedy_generate(model, tokenizer, input_ids, past_key_values, max_gen_len): 20 | outputs = model( 21 | input_ids=input_ids, 22 | past_key_values=past_key_values, 23 | use_cache=True, 24 | ) 25 | past_key_values = outputs.past_key_values 26 | pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1) 27 | generated_ids = [pred_token_idx.item()] 28 | pos = 0 29 | for _ in range(max_gen_len - 1): 30 | outputs = model( 31 | input_ids=pred_token_idx, 32 | past_key_values=past_key_values, 33 | use_cache=True, 34 | ) 35 | past_key_values = outputs.past_key_values 36 | pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1) 37 | generated_ids.append(pred_token_idx.item()) 38 | generated_text = ( 39 | tokenizer.decode( 40 | generated_ids, 41 | skip_special_tokens=True, 42 | clean_up_tokenization_spaces=True, 43 | spaces_between_special_tokens=False, 44 | ) 45 | .strip() 46 | .split(" ") 47 | ) 48 | 49 | now = len(generated_text) - 1 50 | if now > pos: 51 | print(" ".join(generated_text[pos:now]), end=" ", flush=True) 52 | pos = now 53 | 54 | if pred_token_idx == tokenizer.eos_token_id: 55 | break 56 | print(" ".join(generated_text[pos:]), flush=True) 57 | return past_key_values 58 | 59 | 60 | @torch.no_grad() 61 | def streaming_inference(model, tokenizer, prompts, kv_cache=None, max_gen_len=1000): 62 | past_key_values = None 63 | for idx, prompt in enumerate(prompts): 64 | prompt = "USER: " + prompt + "\n\nASSISTANT: " 65 | print("\n" + prompt, end="") 66 | input_ids = tokenizer(prompt, return_tensors="pt").input_ids 67 | input_ids = input_ids.to(model.device) 68 | seq_len = input_ids.shape[1] 69 | if kv_cache is not None: 70 | space_needed = seq_len + max_gen_len 71 | past_key_values = kv_cache.evict_for_space(past_key_values, space_needed) 72 | 73 | past_key_values = greedy_generate( 74 | model, tokenizer, input_ids, past_key_values, max_gen_len=max_gen_len 75 | ) 76 | 77 | 78 | def main(args): 79 | model_name_or_path = args.model_name_or_path 80 | model, tokenizer = load(model_name_or_path) 81 | test_filepath = os.path.join(args.data_root, "mt_bench.jsonl") 82 | print(f"Loading data from {test_filepath} ...") 83 | 84 | if not os.path.exists(test_filepath): 85 | download_url( 86 | "https://raw.githubusercontent.com/lm-sys/FastChat/main/fastchat/llm_judge/data/mt_bench/question.jsonl", 87 | args.data_root, 88 | ) 89 | 
os.rename(os.path.join(args.data_root, "question.jsonl"), test_filepath) 90 | 91 | list_data = load_jsonl(test_filepath) 92 | prompts = [] 93 | for sample in list_data: 94 | prompts += sample["turns"] 95 | 96 | if args.enable_streaming: 97 | kv_cache = enable_streaming_llm( 98 | model, start_size=args.start_size, recent_size=args.recent_size 99 | ) 100 | else: 101 | kv_cache = None 102 | 103 | streaming_inference( 104 | model, 105 | tokenizer, 106 | prompts, 107 | kv_cache, 108 | ) 109 | 110 | 111 | if __name__ == "__main__": 112 | parser = argparse.ArgumentParser() 113 | parser.add_argument( 114 | "--model_name_or_path", type=str, default="lmsys/vicuna-13b-v1.3" 115 | ) 116 | parser.add_argument("--data_root", type=str, default="data/") 117 | parser.add_argument("--enable_streaming", action="store_true") 118 | parser.add_argument("--start_size", type=int, default=4) 119 | parser.add_argument("--recent_size", type=int, default=2000) 120 | args = parser.parse_args() 121 | 122 | main(args) 123 | -------------------------------------------------------------------------------- /figures/schemes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-han-lab/streaming-llm/2e5042606d69933d88fbf909bd77907456b9b4dd/figures/schemes.png -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | setup( 3 | name="streaming_llm", 4 | version="0.0.1", 5 | packages=find_packages(), 6 | ) 7 | -------------------------------------------------------------------------------- /streaming_llm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-han-lab/streaming-llm/2e5042606d69933d88fbf909bd77907456b9b4dd/streaming_llm/__init__.py -------------------------------------------------------------------------------- /streaming_llm/enable_streaming_llm.py: -------------------------------------------------------------------------------- 1 | from streaming_llm.kv_cache import StartRecentKVCache 2 | 3 | 4 | def enable_streaming_llm(model, start_size, recent_size): 5 | if "llama" in model.config.model_type: 6 | k_seq_dim = v_seq_dim = 2 7 | from streaming_llm.pos_shift.modify_llama import ( 8 | enable_llama_pos_shift_attention, 9 | ) 10 | 11 | enable_llama_pos_shift_attention(model) 12 | elif "mpt" in model.config.model_type: 13 | v_seq_dim = 2 14 | k_seq_dim = 3 15 | elif "gpt_neox" in model.config.model_type: 16 | k_seq_dim = v_seq_dim = 2 17 | from streaming_llm.pos_shift.modify_gpt_neox import ( 18 | enable_gpt_neox_pos_shift_attention, 19 | ) 20 | 21 | enable_gpt_neox_pos_shift_attention(model) 22 | elif "falcon" in model.config.model_type: 23 | v_seq_dim = 1 24 | k_seq_dim = 1 25 | from streaming_llm.pos_shift.modify_falcon import ( 26 | enable_falcon_pos_shift_attention, 27 | ) 28 | 29 | enable_falcon_pos_shift_attention(model) 30 | else: 31 | raise ValueError(f"got {model.config.model_type}") 32 | kv_cache = StartRecentKVCache( 33 | start_size=start_size, 34 | recent_size=recent_size, 35 | k_seq_dim=k_seq_dim, 36 | v_seq_dim=v_seq_dim, 37 | ) 38 | return kv_cache 39 | -------------------------------------------------------------------------------- /streaming_llm/kv_cache.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def slice2d(x, start, 
end): 5 | return x[:, :, start:end, ...] 6 | 7 | 8 | def slice3d(x, start, end): 9 | return x[:, :, :, start:end, ...] 10 | 11 | 12 | def slice1d(x, start, end): 13 | return x[:, start:end, ...] 14 | 15 | 16 | DIM_TO_SLICE = { 17 | 1: slice1d, 18 | 2: slice2d, 19 | 3: slice3d, 20 | } 21 | 22 | 23 | class StartRecentKVCache: 24 | def __init__( 25 | self, 26 | start_size=4, 27 | recent_size=512, 28 | k_seq_dim=2, 29 | v_seq_dim=2, 30 | ): 31 | print(f"StartRecentKVCache: {start_size}, {recent_size}") 32 | self.start_size = start_size 33 | self.recent_size = recent_size 34 | self.cache_size = start_size + recent_size 35 | self.k_seq_dim = k_seq_dim 36 | self.v_seq_dim = v_seq_dim 37 | self.k_slice = DIM_TO_SLICE[k_seq_dim] 38 | self.v_slice = DIM_TO_SLICE[v_seq_dim] 39 | 40 | def __call__(self, past_key_values): 41 | if past_key_values is None: 42 | return None 43 | seq_len = past_key_values[0][0].size(self.k_seq_dim) 44 | if seq_len <= self.cache_size: 45 | return past_key_values 46 | return [ 47 | [ 48 | torch.cat( 49 | [ 50 | self.k_slice(k, 0, self.start_size), 51 | self.k_slice(k, seq_len - self.recent_size, seq_len), 52 | ], 53 | dim=self.k_seq_dim, 54 | ), 55 | torch.cat( 56 | [ 57 | self.v_slice(v, 0, self.start_size), 58 | self.v_slice(v, seq_len - self.recent_size, seq_len), 59 | ], 60 | dim=self.v_seq_dim, 61 | ), 62 | ] 63 | for k, v in past_key_values 64 | ] 65 | 66 | def evict_for_space(self, past_key_values, num_coming): 67 | if past_key_values is None: 68 | return None 69 | seq_len = past_key_values[0][0].size(self.k_seq_dim) 70 | if seq_len + num_coming <= self.cache_size: 71 | return past_key_values 72 | return [ 73 | [ 74 | torch.cat( 75 | [ 76 | self.k_slice(k, 0, self.start_size), 77 | self.k_slice( 78 | k, seq_len - self.recent_size + num_coming, seq_len 79 | ), 80 | ], 81 | dim=self.k_seq_dim, 82 | ), 83 | torch.cat( 84 | [ 85 | self.v_slice(v, 0, self.start_size), 86 | self.v_slice( 87 | v, seq_len - self.recent_size + num_coming, seq_len 88 | ), 89 | ], 90 | dim=self.v_seq_dim, 91 | ), 92 | ] 93 | for k, v in past_key_values 94 | ] 95 | 96 | def evict_range(self, past_key_values, start, end): 97 | if past_key_values is None: 98 | return None 99 | seq_len = past_key_values[0][0].size(self.k_seq_dim) 100 | assert start <= end and end <= seq_len 101 | return [ 102 | [ 103 | torch.cat( 104 | [ 105 | self.k_slice(k, 0, start), 106 | self.k_slice(k, end, seq_len), 107 | ], 108 | dim=self.k_seq_dim, 109 | ), 110 | torch.cat( 111 | [ 112 | self.v_slice(v, 0, start), 113 | self.v_slice(v, end, seq_len), 114 | ], 115 | dim=self.v_seq_dim, 116 | ), 117 | ] 118 | for k, v in past_key_values 119 | ] 120 | -------------------------------------------------------------------------------- /streaming_llm/pos_shift/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-han-lab/streaming-llm/2e5042606d69933d88fbf909bd77907456b9b4dd/streaming_llm/pos_shift/__init__.py -------------------------------------------------------------------------------- /streaming_llm/pos_shift/modify_falcon.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional, Tuple 3 | 4 | import torch 5 | from torch import nn 6 | import torch.utils.checkpoint 7 | 8 | import torch.nn.functional as F 9 | 10 | from transformers.models.falcon.modeling_falcon import ( 11 | FalconAttention, 12 | rotate_half, 13 | ) 14 | import types 15 | 16 | __all__ = 
["enable_falcon_pos_shift_attention"] 17 | 18 | 19 | def falcon_pos_shift_attention_forward( 20 | self, 21 | hidden_states: torch.Tensor, 22 | alibi: torch.Tensor, 23 | attention_mask: torch.Tensor, 24 | layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, 25 | head_mask: Optional[torch.Tensor] = None, 26 | use_cache: bool = False, 27 | output_attentions: bool = False, 28 | ): 29 | fused_qkv = self.query_key_value( 30 | hidden_states 31 | ) # [batch_size, seq_length, 3 x hidden_size] 32 | 33 | # 3 x [batch_size, seq_length, num_heads, head_dim] 34 | (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) 35 | 36 | batch_size, q_length, _, _ = query_layer.shape 37 | 38 | query_layer = query_layer.transpose(1, 2).reshape( 39 | batch_size * self.num_heads, q_length, self.head_dim 40 | ) 41 | 42 | # dirty hack to fix the inconsistency between falcon-40b and falcon-7b 43 | num_kv = self.num_heads if self.num_heads == 128 else self.num_kv 44 | key_layer = key_layer.transpose(1, 2).reshape( 45 | batch_size * num_kv, 46 | q_length, 47 | self.head_dim, 48 | ) 49 | value_layer = value_layer.transpose(1, 2).reshape( 50 | batch_size * num_kv, q_length, self.head_dim 51 | ) 52 | 53 | past_len = 0 54 | if layer_past is not None: 55 | past_len = layer_past[0].shape[1] 56 | 57 | query_layer_copy = query_layer.clone() 58 | query_layer, _ = self.maybe_rotary(query_layer, query_layer_copy, past_len) 59 | if layer_past is not None: 60 | past_key, past_value = layer_past 61 | # concatenate along seq_length dimension: 62 | # - key: [batch_size * self.num_heads, head_dim, kv_length] 63 | # - value: [batch_size * self.num_heads, kv_length, head_dim] 64 | key_layer = torch.cat((past_key, key_layer), dim=1) 65 | value_layer = torch.cat((past_value, value_layer), dim=1) 66 | 67 | if use_cache is True: 68 | present = (key_layer, value_layer) 69 | else: 70 | present = None 71 | 72 | key_layer_copy = key_layer.clone() 73 | _, key_layer = self.maybe_rotary(key_layer_copy, key_layer, 0) 74 | 75 | _, kv_length, _ = key_layer.shape 76 | 77 | if alibi is None: 78 | query_layer_ = query_layer.reshape( 79 | batch_size, self.num_heads, -1, self.head_dim 80 | ) 81 | key_layer_ = key_layer.reshape(batch_size, num_kv, -1, self.head_dim) 82 | value_layer_ = value_layer.reshape(batch_size, num_kv, -1, self.head_dim) 83 | 84 | if layer_past is not None: 85 | attn_output = F.scaled_dot_product_attention( 86 | query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=False 87 | ) 88 | else: 89 | attn_output = F.scaled_dot_product_attention( 90 | query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True 91 | ) 92 | 93 | x = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim) 94 | x = x.permute(0, 2, 1, 3) 95 | attn_output = x.reshape(batch_size, q_length, self.num_heads * self.head_dim) 96 | 97 | output_tensor = self.dense(attn_output) 98 | 99 | outputs = (output_tensor, present) 100 | assert not output_attentions # not supported. 
101 | return outputs 102 | else: 103 | attention_mask_float = ( 104 | (attention_mask * 1.0).masked_fill(attention_mask, -1e9).to(torch.bfloat16) 105 | ) 106 | matmul_result = query_layer @ key_layer.transpose(-1, -2) 107 | 108 | # change view to [batch_size, num_heads, q_length, kv_length] 109 | attention_scores = matmul_result.view( 110 | batch_size, self.num_heads, q_length, kv_length 111 | ) 112 | 113 | # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length] 114 | input_dtype = attention_scores.dtype 115 | # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38` 116 | if input_dtype == torch.float16 or input_dtype == torch.bfloat16: 117 | attention_scores = attention_scores.to(torch.float32) 118 | # attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min) 119 | attention_probs = F.softmax( 120 | (attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)) 121 | * self.inv_norm_factor 122 | + attention_mask_float, 123 | dim=-1, 124 | dtype=hidden_states.dtype, 125 | ) 126 | # [batch_size, num_heads, q_length, kv_length] 127 | attention_probs = self.attention_dropout(attention_probs) 128 | 129 | if head_mask is not None: 130 | attention_probs = attention_probs * head_mask 131 | 132 | # change view [batch_size x num_heads, q_length, kv_length] 133 | attention_probs_reshaped = attention_probs.view( 134 | batch_size * self.num_heads, q_length, kv_length 135 | ) 136 | 137 | # matmul: [batch_size * num_heads, q_length, head_dim] 138 | context_layer = attention_probs_reshaped @ value_layer 139 | 140 | # change view [batch_size, num_heads, q_length, head_dim] 141 | context_layer = self._merge_heads(context_layer) 142 | 143 | output_tensor = self.dense(context_layer) 144 | 145 | outputs = (output_tensor, present) 146 | if output_attentions: 147 | outputs += (attention_probs,) 148 | 149 | return outputs 150 | 151 | 152 | def enable_falcon_pos_shift_attention(model): 153 | for name, module in reversed(model._modules.items()): 154 | if len(list(module.children())) > 0: 155 | enable_falcon_pos_shift_attention( 156 | module, 157 | ) 158 | 159 | if "self_attention" == name[-14:]: 160 | model._modules[name].forward = types.MethodType( 161 | falcon_pos_shift_attention_forward, model._modules[name] 162 | ) 163 | -------------------------------------------------------------------------------- /streaming_llm/pos_shift/modify_gpt_neox.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional, Tuple 3 | 4 | import torch 5 | from torch import nn 6 | import torch.utils.checkpoint 7 | 8 | import torch.nn.functional as F 9 | 10 | from transformers.models.gpt_neox.modeling_gpt_neox import ( 11 | apply_rotary_pos_emb, 12 | rotate_half, 13 | GPTNeoXAttention, 14 | ) 15 | import types 16 | 17 | __all__ = ["enable_gpt_neox_pos_shift_attention"] 18 | 19 | 20 | def apply_rotary_pos_emb_single(x, cos, sin, position_ids): 21 | gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1] 22 | gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3]) 23 | cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices) 24 | sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices) 25 | x_embed = (x * cos) + (rotate_half(x) * sin) 26 | return x_embed 27 | 28 | 29 | def 
gpt_neox_pos_shift_attention_forward( 30 | self, 31 | hidden_states: torch.FloatTensor, 32 | attention_mask: torch.FloatTensor, 33 | position_ids: torch.LongTensor, 34 | head_mask: Optional[torch.FloatTensor] = None, 35 | layer_past: Optional[Tuple[torch.Tensor]] = None, 36 | use_cache: Optional[bool] = False, 37 | output_attentions: Optional[bool] = False, 38 | ): 39 | has_layer_past = layer_past is not None 40 | 41 | # Compute QKV 42 | # Attention heads [batch, seq_len, hidden_size] 43 | # --> [batch, seq_len, (np * 3 * head_size)] 44 | qkv = self.query_key_value(hidden_states) 45 | 46 | # [batch, seq_len, (num_heads * 3 * head_size)] 47 | # --> [batch, seq_len, num_heads, 3 * head_size] 48 | new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size) 49 | qkv = qkv.view(*new_qkv_shape) 50 | 51 | # [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size] 52 | query = qkv[..., : self.head_size].permute(0, 2, 1, 3) 53 | key = qkv[..., self.head_size : 2 * self.head_size].permute(0, 2, 1, 3) 54 | value = qkv[..., 2 * self.head_size :].permute(0, 2, 1, 3) 55 | 56 | # Compute rotary embeddings on rotary_ndims 57 | query_rot = query[..., : self.rotary_ndims] 58 | query_pass = query[..., self.rotary_ndims :] 59 | 60 | # Compute token offset for rotary embeddings (when decoding) 61 | seq_len = key.shape[-2] 62 | if has_layer_past: 63 | seq_len += layer_past[0].shape[-2] 64 | cos, sin = self.rotary_emb(value, seq_len=seq_len) 65 | query = apply_rotary_pos_emb_single(query_rot, cos, sin, position_ids) 66 | query = torch.cat((query, query_pass), dim=-1) 67 | 68 | # Cache QKV values 69 | if has_layer_past: 70 | past_key = layer_past[0] 71 | past_value = layer_past[1] 72 | key = torch.cat((past_key, key), dim=-2) 73 | value = torch.cat((past_value, value), dim=-2) 74 | 75 | present = (key, value) if use_cache else None 76 | 77 | key_rot = key[..., : self.rotary_ndims] 78 | key_pass = key[..., self.rotary_ndims :] 79 | key_position_ids = torch.arange(seq_len, device=position_ids.device).unsqueeze(0) 80 | key = apply_rotary_pos_emb_single(key_rot, cos, sin, key_position_ids) 81 | key = torch.cat((key, key_pass), dim=-1) 82 | 83 | # Compute attention 84 | attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) 85 | 86 | # Reshape outputs 87 | attn_output = self._merge_heads( 88 | attn_output, self.num_attention_heads, self.head_size 89 | ) 90 | attn_output = self.dense(attn_output) 91 | 92 | outputs = (attn_output, present) 93 | if output_attentions: 94 | outputs += (attn_weights,) 95 | 96 | return outputs 97 | 98 | 99 | def enable_gpt_neox_pos_shift_attention(model): 100 | for name, module in reversed(model._modules.items()): 101 | if len(list(module.children())) > 0: 102 | enable_gpt_neox_pos_shift_attention( 103 | module, 104 | ) 105 | 106 | if isinstance(module, GPTNeoXAttention): 107 | module.forward = types.MethodType( 108 | gpt_neox_pos_shift_attention_forward, module 109 | ) 110 | -------------------------------------------------------------------------------- /streaming_llm/pos_shift/modify_llama.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional, Tuple 3 | 4 | import torch 5 | from torch import nn 6 | import torch.utils.checkpoint 7 | 8 | import torch.nn.functional as F 9 | 10 | from transformers.models.llama.modeling_llama import ( 11 | LlamaAttention, 12 | rotate_half, 13 | apply_rotary_pos_emb, 14 | repeat_kv, 
15 | ) 16 | import types 17 | 18 | __all__ = ["enable_llama_pos_shift_attention"] 19 | 20 | 21 | def apply_rotary_pos_emb_single(x, cos, sin, position_ids): 22 | # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. 23 | cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] 24 | sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] 25 | cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] 26 | sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] 27 | x_embed = (x * cos) + (rotate_half(x) * sin) 28 | return x_embed 29 | 30 | 31 | def llama_pos_shift_attention_forward( 32 | self, 33 | hidden_states: torch.Tensor, 34 | attention_mask: Optional[torch.Tensor] = None, 35 | position_ids: Optional[torch.LongTensor] = None, 36 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 37 | output_attentions: bool = False, 38 | use_cache: bool = False, 39 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 40 | bsz, q_len, _ = hidden_states.size() 41 | 42 | if self.config.pretraining_tp > 1: 43 | key_value_slicing = ( 44 | self.num_key_value_heads * self.head_dim 45 | ) // self.config.pretraining_tp 46 | query_slices = self.q_proj.weight.split( 47 | (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 48 | ) 49 | key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) 50 | value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) 51 | 52 | query_states = [ 53 | F.linear(hidden_states, query_slices[i]) 54 | for i in range(self.config.pretraining_tp) 55 | ] 56 | query_states = torch.cat(query_states, dim=-1) 57 | 58 | key_states = [ 59 | F.linear(hidden_states, key_slices[i]) 60 | for i in range(self.config.pretraining_tp) 61 | ] 62 | key_states = torch.cat(key_states, dim=-1) 63 | 64 | value_states = [ 65 | F.linear(hidden_states, value_slices[i]) 66 | for i in range(self.config.pretraining_tp) 67 | ] 68 | value_states = torch.cat(value_states, dim=-1) 69 | 70 | else: 71 | query_states = self.q_proj(hidden_states) 72 | key_states = self.k_proj(hidden_states) 73 | value_states = self.v_proj(hidden_states) 74 | 75 | query_states = query_states.view( 76 | bsz, q_len, self.num_heads, self.head_dim 77 | ).transpose(1, 2) 78 | key_states = key_states.view( 79 | bsz, q_len, self.num_key_value_heads, self.head_dim 80 | ).transpose(1, 2) 81 | value_states = value_states.view( 82 | bsz, q_len, self.num_key_value_heads, self.head_dim 83 | ).transpose(1, 2) 84 | 85 | kv_seq_len = key_states.shape[-2] 86 | if past_key_value is not None: 87 | kv_seq_len += past_key_value[0].shape[-2] 88 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 89 | ### Shift Pos: query pos is min(cache_size, idx) 90 | # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) 91 | query_states = apply_rotary_pos_emb_single(query_states, cos, sin, position_ids) 92 | ### 93 | 94 | if past_key_value is not None: 95 | # reuse k, v, self_attention 96 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 97 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 98 | 99 | past_key_value = (key_states, value_states) if use_cache else None 100 | 101 | ### Shift Pos: key pos is the pos in cache 102 | key_position_ids = torch.arange(kv_seq_len, device=position_ids.device).unsqueeze(0) 103 | key_states = apply_rotary_pos_emb_single(key_states, cos, sin, key_position_ids) 104 | ### 105 | 106 | # repeat k/v heads if n_kv_heads < n_heads 107 | key_states = repeat_kv(key_states, 
self.num_key_value_groups) 108 | value_states = repeat_kv(value_states, self.num_key_value_groups) 109 | 110 | attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt( 111 | self.head_dim 112 | ) 113 | 114 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): 115 | raise ValueError( 116 | f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" 117 | f" {attn_weights.size()}" 118 | ) 119 | 120 | if attention_mask is not None: 121 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): 122 | raise ValueError( 123 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" 124 | ) 125 | attn_weights = attn_weights + attention_mask 126 | 127 | # upcast attention to fp32 128 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( 129 | query_states.dtype 130 | ) 131 | attn_output = torch.matmul(attn_weights, value_states) 132 | 133 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): 134 | raise ValueError( 135 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" 136 | f" {attn_output.size()}" 137 | ) 138 | 139 | attn_output = attn_output.transpose(1, 2).contiguous() 140 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) 141 | 142 | if self.config.pretraining_tp > 1: 143 | attn_output = attn_output.split( 144 | self.hidden_size // self.config.pretraining_tp, dim=2 145 | ) 146 | o_proj_slices = self.o_proj.weight.split( 147 | self.hidden_size // self.config.pretraining_tp, dim=1 148 | ) 149 | attn_output = sum( 150 | [ 151 | F.linear(attn_output[i], o_proj_slices[i]) 152 | for i in range(self.config.pretraining_tp) 153 | ] 154 | ) 155 | else: 156 | attn_output = self.o_proj(attn_output) 157 | 158 | if not output_attentions: 159 | attn_weights = None 160 | 161 | return attn_output, attn_weights, past_key_value 162 | 163 | 164 | def enable_llama_pos_shift_attention(model): 165 | for name, module in reversed(model._modules.items()): 166 | if len(list(module.children())) > 0: 167 | enable_llama_pos_shift_attention( 168 | module, 169 | ) 170 | 171 | if isinstance(module, LlamaAttention): 172 | model._modules[name].forward = types.MethodType( 173 | llama_pos_shift_attention_forward, model._modules[name] 174 | ) 175 | -------------------------------------------------------------------------------- /streaming_llm/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | from transformers import ( 4 | AutoTokenizer, 5 | AutoModelForCausalLM, 6 | ) 7 | import os.path as osp 8 | import ssl 9 | import urllib.request 10 | import os 11 | import json 12 | 13 | 14 | def parse_args(): 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument( 17 | "--model_name_or_path", type=str, default="models/llama/llama-7b" 18 | ) 19 | parser.add_argument("--revision", type=str, default="main") 20 | parser.add_argument("--tokenizer_name_or_path", type=str, default=None) 21 | parser.add_argument("--dataset_name", type=str, default="wikitext") 22 | 23 | parser.add_argument("--task", type=str, default="wikitext-2-raw-v1") 24 | parser.add_argument( 25 | "--split", type=str, default="test", choices=["validation", "test"] 26 | ) 27 | 28 | parser.add_argument( 29 | "--num_samples", 30 | type=int, 31 | default=1, 32 | ) 33 | 34 | parser.add_argument( 35 | "--output_dir", 36 | type=str, 37 | default="outputs/debug", 38 | ) 39 | 40 | 
parser.add_argument("--enable_start_recent_kv_cache", action="store_true") 41 | parser.add_argument("--start_size", type=int, default=1) 42 | parser.add_argument("--recent_size", type=int, default=255) 43 | parser.add_argument("--enable_pos_shift", action="store_true") 44 | 45 | parser.add_argument("--num_eval_tokens", type=int, default=None) 46 | 47 | args = parser.parse_args() 48 | return args 49 | 50 | 51 | def load(model_name_or_path): 52 | print(f"Loading model from {model_name_or_path} ...") 53 | # however, tensor parallel for running falcon will occur bugs 54 | tokenizer = AutoTokenizer.from_pretrained( 55 | model_name_or_path, 56 | trust_remote_code=True, 57 | ) 58 | model = AutoModelForCausalLM.from_pretrained( 59 | model_name_or_path, 60 | device_map="auto", 61 | torch_dtype=torch.float16, 62 | trust_remote_code=True, 63 | ) 64 | if tokenizer.pad_token_id is None: 65 | if tokenizer.eos_token_id is not None: 66 | tokenizer.pad_token_id = tokenizer.eos_token_id 67 | else: 68 | tokenizer.pad_token_id = 0 69 | 70 | model.eval() 71 | 72 | return model, tokenizer 73 | 74 | 75 | def download_url(url: str, folder="folder"): 76 | """ 77 | Downloads the content of an url to a folder. Modified from \ 78 | https://github.com/pyg-team/pytorch_geometric/tree/master/torch_geometric 79 | 80 | Args: 81 | url (string): The url of target file. 82 | folder (string): The target folder. 83 | 84 | Returns: 85 | string: File path of downloaded files. 86 | """ 87 | 88 | file = url.rpartition("/")[2] 89 | file = file if file[0] == "?" else file.split("?")[0] 90 | path = osp.join(folder, file) 91 | if osp.exists(path): 92 | print(f"File {file} exists, use existing file.") 93 | return path 94 | 95 | print(f"Downloading {url}") 96 | os.makedirs(folder, exist_ok=True) 97 | ctx = ssl._create_unverified_context() 98 | data = urllib.request.urlopen(url, context=ctx) 99 | with open(path, "wb") as f: 100 | f.write(data.read()) 101 | 102 | return path 103 | 104 | 105 | def load_jsonl( 106 | file_path, 107 | ): 108 | list_data_dict = [] 109 | with open(file_path, "r") as f: 110 | for line in f: 111 | list_data_dict.append(json.loads(line)) 112 | return list_data_dict 113 | --------------------------------------------------------------------------------