├── .github ├── dependabot.yaml └── workflows │ ├── codeql.yml │ ├── main.yml │ ├── pr.yml │ ├── run-tests.yml │ └── weekly.yml ├── .gitignore ├── CODE_OF_CONDUCT.md ├── LICENSE ├── README.md ├── SECURITY.md ├── SUPPORT.md ├── dataset_generation ├── create_train_test_split.py ├── gen_synthetic_data.py └── generate_kb_embeddings.py ├── datasets ├── .gitattributes ├── datasetcard_enron.md ├── datasetcard_synthetic.md ├── enron.json └── synthetic.json ├── eval_acc.ipynb ├── experiments ├── Makefile ├── eval.py ├── output_scorer.py ├── output_scorer_open_ended.py └── train.py ├── pyproject.toml ├── src └── kblam │ ├── gpt_session.py │ ├── kb_encoder.py │ ├── models │ ├── kblam_config.py │ ├── kblam_processor.py │ ├── llama3_model.py │ └── phi3_model.py │ └── utils │ ├── convert.py │ ├── data_utils.py │ ├── eval_utils.py │ ├── model_utils.py │ └── train_utils.py └── tests ├── sample_data.json ├── test_dataset.json ├── test_dataset_construction.py └── test_kb_encoder.py /.github/dependabot.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: github-actions 4 | directory: "/" 5 | schedule: 6 | interval: weekly 7 | target-branch: main 8 | groups: 9 | github-actions: 10 | patterns: 11 | - "*" 12 | commit-message: 13 | prefix: ":arrow_up: dep-bump" 14 | include: scope 15 | 16 | - package-ecosystem: pip 17 | directory: "/" # location of package manifests 18 | schedule: 19 | interval: weekly 20 | target-branch: main 21 | groups: 22 | security-packages: 23 | applies-to: security-updates 24 | patterns: 25 | - "*" 26 | dependency-packages: 27 | applies-to: version-updates 28 | patterns: 29 | - "*" 30 | commit-message: 31 | prefix: ":arrow_up: dep-bump" 32 | include: scope -------------------------------------------------------------------------------- /.github/workflows/codeql.yml: -------------------------------------------------------------------------------- 1 | # For most projects, this workflow file will not need changing; you simply need 2 | # to commit it to your repository. 3 | # 4 | # You may wish to alter this file to override the set of languages analyzed, 5 | # or to provide custom queries or build logic. 6 | # 7 | # ******** NOTE ******** 8 | # We have attempted to detect the languages in your repository. Please check 9 | # the `language` matrix defined below to confirm you have the correct set of 10 | # supported CodeQL languages. 11 | # 12 | name: CodeQL 13 | 14 | on: 15 | workflow_call: 16 | 17 | jobs: 18 | analyze: 19 | name: analyze 20 | runs-on: ubuntu-latest 21 | permissions: 22 | actions: read 23 | contents: read 24 | security-events: write 25 | 26 | strategy: 27 | fail-fast: false 28 | matrix: 29 | language: [ python ] 30 | # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ] 31 | # Learn more about CodeQL language support at https://aka.ms/codeql-docs/language-support 32 | 33 | steps: 34 | - name: checkout-repository 35 | uses: actions/checkout@v4 36 | 37 | # Initializes the CodeQL tools for scanning. 38 | - name: initialize-codeql 39 | uses: github/codeql-action/init@v3 40 | with: 41 | languages: ${{ matrix.language }} 42 | # If you wish to specify custom queries, you can do so here or in a config file. 43 | # By default, queries listed here will override any specified in a config file. 44 | # Prefix the list here with "+" to use these queries and those in the config file. 
45 | 46 | # Details on CodeQL's query packs refer to : https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs 47 | # queries: security-extended,security-and-quality 48 | 49 | 50 | # Autobuild attempts to build any compiled languages (C/C++, C#, Go, or Java). 51 | # If this step fails, then you should remove it and run the build manually (see below) 52 | - name: autobuild 53 | uses: github/codeql-action/autobuild@v3 54 | 55 | # ℹ️ Command-line programs to run using the OS shell. 56 | # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun 57 | 58 | # If the Autobuild fails above, remove it and uncomment the following three lines. 59 | # modify them (or add more) to build your code if your project, please refer to the EXAMPLE below for guidance. 60 | 61 | # - run: | 62 | # echo "Run, Build Application using script" 63 | # ./location_of_script_within_repo/buildscript.sh 64 | 65 | - name: codeql-analysis 66 | uses: github/codeql-action/analyze@v3 67 | with: 68 | category: "/language:${{matrix.language}}" -------------------------------------------------------------------------------- /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | name: main 2 | 3 | on: 4 | workflow_dispatch: 5 | push: 6 | branches: [ main ] 7 | tags: 8 | - "v[0-9]+.[0-9]+.[0-9]+" 9 | 10 | concurrency: 11 | group: ${{ github.workflow }}-${{ github.ref }} 12 | cancel-in-progress: true 13 | 14 | jobs: 15 | run-tests: 16 | uses: ./.github/workflows/run-tests.yml 17 | 18 | codeql: 19 | uses: ./.github/workflows/codeql.yml 20 | secrets: inherit 21 | permissions: 22 | contents: read 23 | actions: read 24 | security-events: write 25 | -------------------------------------------------------------------------------- /.github/workflows/pr.yml: -------------------------------------------------------------------------------- 1 | name: pr 2 | 3 | on: 4 | pull_request: 5 | branches: [ main ] 6 | 7 | concurrency: 8 | group: ${{ github.workflow }}-${{ github.ref }} 9 | cancel-in-progress: true 10 | 11 | jobs: 12 | run-tests: 13 | uses: ./.github/workflows/run-tests.yml 14 | 15 | codeql: 16 | uses: ./.github/workflows/codeql.yml 17 | secrets: inherit 18 | permissions: 19 | contents: read 20 | actions: read 21 | security-events: write 22 | -------------------------------------------------------------------------------- /.github/workflows/run-tests.yml: -------------------------------------------------------------------------------- 1 | name: run-tests 2 | 3 | on: 4 | workflow_call: 5 | 6 | jobs: 7 | run-tests: 8 | runs-on: ubuntu-latest 9 | name: run-tests 10 | steps: 11 | - uses: actions/checkout@v4 12 | 13 | - name: Set up Python 14 | uses: actions/setup-python@v5 15 | with: 16 | python-version: "3.10" 17 | 18 | - name: Install dependencies 19 | run: | 20 | python -m pip install -U pip 21 | pip install -U pdm 22 | pdm install --dev 23 | 24 | # To be added back gradually 25 | # - name: ruff-check 26 | # run: | 27 | # pdm run ruff check --no-fix --output-format=github 28 | 29 | # - name: ruff-format 30 | # run: | 31 | # pdm run ruff format --check 32 | 33 | # - name: pyright 34 | # run: | 35 | # pdm run pyright 36 | 37 | # - name: codespell 38 | # run: | 39 | # pdm run codespell 40 | 41 | - name: pytest 42 | run: | 43 | pdm run coverage run -m pytest 44 | 45 | - name: coverage 46 | run: | 47 | pdm run coverage 
xml -o .tmp/reports/coverage.xml --include="src/*" -------------------------------------------------------------------------------- /.github/workflows/weekly.yml: -------------------------------------------------------------------------------- 1 | name: weekly 2 | 3 | on: 4 | workflow_dispatch: 5 | schedule: 6 | - cron: '42 23 * * 4' 7 | 8 | jobs: 9 | codeql: 10 | uses: ./.github/workflows/codeql.yml 11 | secrets: inherit 12 | permissions: 13 | contents: read 14 | actions: read 15 | security-events: write -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | 164 | 165 | heatmap* -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 |
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # KBLaM - Knowledge Base Augmented Language Models [ICLR 2025]
2 | 
3 | This repo contains the official implementation of [KBLaM: Knowledge Base Augmented Language Models](https://arxiv.org/abs/2410.10450).
4 | 
5 | Authors: Xi Wang, Liana Mikaelyan, Taketomo Isazawa, Mathew Salvaris, James Hensman.
6 | 
7 | KBLaM is a new method for augmenting LLMs with external knowledge.
8 | Unlike Retrieval-Augmented Generation, KBLaM eliminates external
9 | retrieval modules, and unlike in-context learning, its computational overhead scales linearly with KB size rather than quadratically.
10 | 
11 | ## Supported Models
12 | The following models from the Hugging Face Hub are currently supported:
13 | - [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)
14 | - [meta-llama/Llama-3.2-1B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct)
15 | - [Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)
16 | 
17 | To add support for new model types, you will need to update the model processing scripts to incorporate an adapter similar to `llama3_model.py` in `src/kblam/models`.
18 | 
19 | ## Setting up
20 | 
21 | Install the kblam package with
22 | 
23 | ```
24 | pip install -e .
25 | ```
26 | 
27 | To use Llama models, you will need to generate a token from Hugging Face and use it to log in:
28 | 
29 | ```
30 | pip install huggingface_hub
31 | huggingface-cli login
32 | ```
33 | 
34 | The experiments in the paper can be replicated by running the scripts in `./experiments`.
35 | 
36 | 
37 | ## Dataset Construction
38 | 
39 | To run the synthetic dataset construction, you will need a valid Azure OpenAI endpoint.
40 | 
41 | To construct a synthetic KB and question-answer pairs, use `dataset_generation/gen_synthetic_data.py`.
42 | 
43 | The question-answer pairs are constructed in the form:
44 | 
45 | ```
46 | What is the description of {entity_name}?
47 | The description of {entity_name} is {description}.
48 | ```
49 | 
50 | To generate KB embeddings, use `dataset_generation/generate_kb_embeddings.py`.
51 | The embeddings we currently support are [text-embedding-ada-002](https://openai.com/index/new-and-improved-embedding-model/) and [all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2).
52 | 
53 | 
54 | ## Training
55 | 
56 | To train the model, run the following (with the appropriate arguments):
57 | 
58 | ```
59 | python train.py --dataset synthetic_data --N 120000 --B 20 --total_steps 601 --encoder_spec OAI --use_oai_embd --key_embd_src key --use_data_aug
60 | ```
61 | 
62 | ## Contributing
63 | 
64 | This project welcomes contributions and suggestions. Most contributions require you to agree to a
65 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
66 | the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
67 | 
68 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide
69 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
70 | provided by the bot. You will only need to do this once across all repos using our CLA.
71 | 
72 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
73 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
74 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
75 | 
76 | ## Trademarks
77 | 
78 | This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft
79 | trademarks or logos is subject to and must follow
80 | [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
81 | Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.
82 | Any use of third-party trademarks or logos is subject to those third parties' policies.
83 | 
84 | ## FAQ
85 | 
86 | ### What is KBLaM?
87 | 
88 | KBLaM is a method for augmenting a transformer-based LLM with external knowledge. It consists of a base LLM and a set of adapters that we train to transform the knowledge base into special knowledge tokens that the LLM ingests. Because we only train adapters over the knowledge part, the base LLM is completely unmodified with regard to text input: given no knowledge base, the model produces exactly the same output as the base model for any given input.
89 | 
90 | ### What can KBLaM do?
91 | 
92 | In addition to the base LLM’s capabilities, KBLaM can attend over the knowledge base to answer questions in a grounded manner.
93 | 
94 | ### What is/are KBLaM’s intended use(s)?
95 | 
96 | The model is intended to be used for research.
97 | 
98 | ### How was KBLaM evaluated? What metrics are used to measure performance?
99 | 
100 | KBLaM was evaluated on the accuracy of retrieval from the knowledge base, its refusal rate (how often it correctly said that it didn’t have the requisite information to answer the question), and the precision and recall of its answers against the correct answers given the knowledge base.
101 | 
102 | ### What are the limitations of KBLaM? How can users minimize the impact of KBLaM’s limitations when using the system?
103 | 
104 | When used with knowledge bases that are very different from the knowledge base it was trained on, KBLaM will give incomplete answers, and the answers can be reworded from the original value in the knowledge base or, at times, entirely incorrect. As a result, KBLaM is not currently intended for use as a complete system in a production setting, but is a research project that we are sharing.
105 | 
106 | ### What operational factors and settings allow for effective and responsible use of KBLaM?
107 | 
108 | KBLaM with no knowledge base will perform exactly the same as the base model. With a knowledge base, for effective use, one should make sure that the training dataset and the use case have sufficiently similar knowledge bases.
109 | 
110 | ### How do I provide feedback on KBLaM?
111 | 
112 | Please add issues to this repository to provide feedback on KBLaM.
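
### How do I reproduce the data preparation pipeline?

The sketch below strings together the dataset scripts described above; it is illustrative rather than a tested end-to-end recipe. Output file names assume the defaults hard-coded in the scripts, `<AZURE_OPENAI_ENDPOINT>` is a placeholder, and the released `datasets/synthetic.json` is used as the KB (point the later steps at your own generated KB instead if you run step 1). Training and evaluation flags are documented in the Training section above and in `experiments/Makefile`.

```
# 1. (Optional) Generate your own synthetic KB and QA pairs (requires an Azure OpenAI endpoint)
python dataset_generation/gen_synthetic_data.py --endpoint_url <AZURE_OPENAI_ENDPOINT> --output_path dataset

# 2. Compute key and value embeddings for the KB
python dataset_generation/generate_kb_embeddings.py --model_name all-MiniLM-L6-v2 --dataset_path datasets/synthetic.json --output_path dataset

# 3. Split the KB and its embeddings into train and test sets
python dataset_generation/create_train_test_split.py --data_path datasets/synthetic.json \
    --embedding_keys_path dataset/synthetic_data_all-MiniLM-L6-v2_embd_key.npy \
    --embeddings_values_path dataset/synthetic_data_all-MiniLM-L6-v2_embd_value.npy \
    --split_index 120000 --output_path dataset/train_test_split

# 4. Train and evaluate: see the Training section above and experiments/Makefile for full argument lists
```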
113 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd). 40 | 41 | 42 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # TODO: The maintainer of this repo has not yet edited this file 2 | 3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project? 4 | 5 | - **No CSS support:** Fill out this template with information about how to file issues and get help. 6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps. 
7 | - **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide. 8 | 9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.* 10 | 11 | # Support 12 | 13 | ## How to file issues and get help 14 | 15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 17 | feature request as a new Issue. 18 | 19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE 20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER 21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**. 22 | 23 | ## Microsoft Support Policy 24 | 25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above. 26 | -------------------------------------------------------------------------------- /dataset_generation/create_train_test_split.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | 7 | 8 | def _create_train_test_names(orig_path: str) -> tuple[str, str]: 9 | train = f"train_{Path(orig_path).name}" 10 | test = f"test_{Path(orig_path).name}" 11 | return train, test 12 | 13 | 14 | def _write_json(data: list[dict], save_path: str) -> None: 15 | with open(save_path, "w") as f: 16 | json.dump(data, f, indent=2) 17 | 18 | 19 | def _save_array(arr: np.array, filename: str) -> None: 20 | np.save(filename, arr) 21 | 22 | 23 | def create_train_test_split( 24 | data_path: str, 25 | embedding_keys_path: str, 26 | embeddings_values_path: str, 27 | split_index: int, 28 | output_path: str, 29 | ) -> None: 30 | """Split data into training and test sets and save the results. 31 | 32 | This function loads data from the specified paths, creates a train-test split at the given index, 33 | and saves the resulting splits to the output path. 34 | 35 | Parameters 36 | ---------- 37 | data_path : str 38 | Path to the main data file to be split 39 | embedding_keys_path : str 40 | Path to the file containing embedding keys 41 | embeddings_values_path : str 42 | Path to the file containing embedding values 43 | split_index : int 44 | Index at which to split the data into train and test sets. 45 | Data before this index will be training, after will be test. 
46 | output_path : str 47 | Directory path where the split datasets will be saved 48 | 49 | Returns 50 | ------- 51 | None 52 | Function saves the split datasets to disk but does not return any values 53 | 54 | """ 55 | 56 | output_p = Path(output_path) 57 | output_p.mkdir(exist_ok=True, parents=True) 58 | 59 | train_dataset = json.load(open(data_path))[:split_index] 60 | 61 | train_key_embds = np.load(embedding_keys_path).astype("float32")[:split_index] 62 | train_value_embds = np.load(embeddings_values_path).astype("float32")[:split_index] 63 | 64 | test_dataset = json.load(open(data_path))[split_index:] 65 | 66 | test_key_embds = np.load(embedding_keys_path).astype("float32")[split_index:] 67 | test_value_embds = np.load(embeddings_values_path).astype("float32")[split_index:] 68 | 69 | train_dataset_name, test_dataset_name = _create_train_test_names(data_path) 70 | train_keys_name, test_keys_name = _create_train_test_names(embedding_keys_path) 71 | train_values_name, test_values_name = _create_train_test_names( 72 | embeddings_values_path 73 | ) 74 | 75 | _write_json(train_dataset, output_p / train_dataset_name) 76 | _write_json(test_dataset, output_p / test_dataset_name) 77 | _save_array(train_key_embds, output_p / train_keys_name) 78 | _save_array(train_value_embds, output_p / train_values_name) 79 | _save_array(test_key_embds, output_p / test_keys_name) 80 | _save_array(test_value_embds, output_p / test_values_name) 81 | 82 | 83 | def parser_args(): 84 | parser = argparse.ArgumentParser() 85 | parser.add_argument("--data_path", type=str) 86 | parser.add_argument("--embedding_keys_path", type=str) 87 | parser.add_argument("--embeddings_values_path", type=str) 88 | parser.add_argument("--output_path", type=str) 89 | parser.add_argument("--split_index", type=int) 90 | 91 | args = parser.parse_args() 92 | return args 93 | 94 | 95 | if __name__ == "__main__": 96 | args = parser_args() 97 | create_train_test_split( 98 | args.data_path, 99 | args.embedding_keys_path, 100 | args.embeddings_values_path, 101 | args.split_index, 102 | args.output_path, 103 | ) 104 | -------------------------------------------------------------------------------- /dataset_generation/gen_synthetic_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import re 5 | from itertools import product 6 | 7 | from tqdm import tqdm 8 | from transformers import AutoModelForCausalLM 9 | 10 | from kblam.gpt_session import GPT 11 | from kblam.utils.data_utils import DataPoint, Entity, save_entity 12 | 13 | 14 | def construct_prompts(entity: DataPoint) -> tuple[str, str, str]: 15 | """Given a data point, creates a question, answer and key string.""" 16 | Q = "What is the {} of {}?".format(entity.description_type, entity.name) 17 | A = f"The {entity.description_type} of {entity.name} is {entity.description}." 
18 | key_string = f"the {entity.description_type} of {entity.name}" 19 | return Q, A, key_string 20 | 21 | 22 | class SyntheticDataGenerator(GPT): 23 | def __init__( 24 | self, model: AutoModelForCausalLM, endpoint_url: str, **kwargs 25 | ) -> None: 26 | self.system_prompt = """You are a AI system that generates synthetic data examples in JSON format.""" 27 | 28 | self.entity_format_prompt = """ 29 | \nMake sure to generate a single data point in the following JSON format: 30 | { 31 | "name": "{name}", 32 | "description": "{description}", 33 | "objectives": "{objectives}", 34 | "purpose": "{purpose}" 35 | } 36 | """ 37 | 38 | self.prompt_2nd_phase = ( 39 | """ 40 | Now for each of the names generated, generate a short desciption, short objectives, and a purpose for the data point. 41 | Please ensure that the generated contents has **LOW** correlation with the name. 42 | Make the data point styles diverse using a mixture of formal and informal language. 43 | """ 44 | + self.entity_format_prompt 45 | + " Do **NOT** generate anything else." 46 | ) 47 | 48 | self.idea_sources = [ 49 | "software companies", 50 | "tech companies", 51 | "software tools", 52 | "greek letters", 53 | "product reviews", 54 | "product releases", 55 | "work-related concepts", 56 | "work-related documents", 57 | "document types", 58 | "financial terms", 59 | "legal terms", 60 | "medical terms", 61 | "fiction characters", 62 | "famous rock bands", 63 | "birds", 64 | "animals", 65 | "natural phenomena", 66 | "physical locations", 67 | "artist names", 68 | "classical music", 69 | "musical instruments", 70 | "music genres", 71 | "art styles", 72 | "ancient Roman concepts", 73 | "Hindu myths", 74 | "Cthulhu Mythos", 75 | "real-world company names", 76 | "mythological creatures", 77 | "planets and stars", 78 | "historical figures", 79 | "political figures", 80 | "literary genres", 81 | "botanical names", 82 | "famous landmarks", 83 | "scientific concepts", 84 | "space missions", 85 | "inventions", 86 | "philosophical terms", 87 | "chemical elements", 88 | "famous scientists", 89 | "famous mathematicians", 90 | "famous authors", 91 | "marine life", 92 | "mythological places", 93 | "famous battles", 94 | "sports teams", 95 | "sport events", 96 | "food and drinks", 97 | ] 98 | 99 | self.data_types = [ 100 | "person name", 101 | "idea", 102 | "team", 103 | "meeting", 104 | "event", 105 | "location", 106 | "document", 107 | "presentation", 108 | "meeting", 109 | "conference", 110 | "workshop", 111 | "database", 112 | "organization", 113 | "tech company", 114 | "car company", 115 | "entertainment company", 116 | "construction company", 117 | "retail company", 118 | "finance company", 119 | "healthcare company", 120 | "restaurant", 121 | "hotel", 122 | "museum", 123 | "university", 124 | "educational institution", 125 | "government agency", 126 | "hospital", 127 | "github repo", 128 | "project", 129 | "meeting room", 130 | "building", 131 | "product", 132 | "lab", 133 | "airline", 134 | "textbook", 135 | "tv show", 136 | "music album", 137 | "website", 138 | "personal blog", 139 | "gaming company", 140 | "game" "movie studio", 141 | "consulting firm", 142 | "biotech company", 143 | "app", 144 | "software tool", 145 | "bookstore", 146 | "coffee shop", 147 | "bar", 148 | "e-commerce site", 149 | "social media platform", 150 | "fitness brand", 151 | "fashion brand", 152 | "beauty brand", 153 | "food brand", 154 | "drink brand", 155 | "sports brand", 156 | "travel brand", 157 | "non-profit organization", 158 | "political party", 159 | ] 
160 | 
161 |         super().__init__(model, endpoint_url, **kwargs)
162 | 
163 |     def get_instructions(self):
164 |         return [
165 |             f"Please randomly generate a {name_type} name innovated by or associated with {idea_type}. "
166 |             "The generated name should be of diverse style and length. A valid name should consist of a single word (such as Alexandria or Microsoft) or multiple words (such as Microsoft Office or Theta-Phoenix Entertainment). "
167 |             for (name_type, idea_type) in product(self.idea_sources, self.data_types)
168 |         ]
169 | 
170 |     def generate_entity(self, instruction: str) -> Entity:
171 |         prompt = [
172 |             {"role": "system", "content": self.system_prompt},
173 |             {"role": "user", "content": instruction},
174 |         ]
175 |         gpt_output = self.api_call_chat(prompt)
176 | 
177 |         messages = [
178 |             {"role": "system", "content": self.system_prompt},
179 |             {"role": "user", "content": instruction},
180 |             {"role": "assistant", "content": gpt_output},
181 |             {"role": "user", "content": self.prompt_2nd_phase},
182 |         ]
183 |         gpt_output = self.api_call_chat(messages)
184 | 
185 | 
186 |         entity = Entity(**json.loads(gpt_output))
187 |         return entity
188 | 
189 |     def generate_related_data(self, entity: Entity) -> Entity:
190 |         instruction = f"Generate a person name related to the entity {entity.name} with description {entity.description}."
191 |         instruction += " The person needs to be associated with the entity in some way, e.g. they work in the company or they are a character in the book."
192 |         instruction += (
193 |             f" Make sure the entity is in the format of {self.entity_format_prompt}"
194 |         )
195 | 
196 |         prompt = [
197 |             {"role": "system", "content": self.system_prompt},
198 |             {"role": "user", "content": instruction},
199 |         ]
200 | 
201 |         gpt_output = self.api_call_chat(prompt)
202 |         entity = Entity(**json.loads(gpt_output))
203 | 
204 |         return entity
205 | 
206 |     def post_process_data(self, entity_list: list[Entity]) -> list[DataPoint]:
207 |         dataset = []
208 |         keywords = {"description", "objectives", "purpose"}
209 | 
210 |         for entity in entity_list:
211 |             for keyword in keywords:
212 |                 datapoint = DataPoint(
213 |                     name=entity.name,
214 |                     description_type=keyword.lower(),
215 |                     description=getattr(entity, keyword),
216 |                 )
217 |                 datapoint.Q, datapoint.A, datapoint.key_string = construct_prompts(
218 |                     datapoint
219 |                 )
220 |                 dataset.append(datapoint)
221 | 
222 |         return dataset
223 | 
224 |     def augmenta_data_with_synthetic_QA(
225 |         self, dataset: list[DataPoint]
226 |     ) -> list[DataPoint]:
227 |         self.system_prompt = """You are given a question and answer pair, please extend the question to be open-ended and generate a short answer.
228 |         For example, you could generate "What is the objective of xxx and what do you think of it?"
229 |         Make sure the answer is **only** based on information provided from the QA pair. In addition, please generate in the format of:
230 |         Q: ...
231 |         A: ...
232 | """ 233 | 234 | for data in dataset: 235 | try: 236 | prompt = ( 237 | "Generate an extended Q and an A for this pair: " 238 | + f"Q: {data.Q}\nA: {data.A}" 239 | ) 240 | gpt_output = self.generate_response(prompt) 241 | extended_q = re.findall(r"Q: (.*)", gpt_output)[0] 242 | extended_a = re.findall(r"A: (.*)", gpt_output)[0] 243 | data.extended_Q = extended_q 244 | data.extended_A = extended_a 245 | except Exception as e: 246 | print("Error augmenting Q&A.") 247 | print(e) 248 | continue 249 | return dataset 250 | 251 | def perturb_names(self, dataset: list[DataPoint]): 252 | for data in dataset: 253 | try: 254 | prompt = f"Perturb the names in the queries of the dataset (e.g. Margaret Thatcher -> Maggie Thatcher or Microsoft Research to MSR) for data point with name {data.name}." 255 | prompt += f"Return the question {data.Q} with the perturbed name. Make sure the perturbation is valid. Do NOT generate anything else." 256 | gpt_output = self.generate_response(prompt) 257 | data.Q = gpt_output 258 | 259 | except Exception as e: 260 | print("Error perturbing the names in the queries.") 261 | print(e) 262 | continue 263 | return dataset 264 | 265 | 266 | def parser_args(): 267 | parser = argparse.ArgumentParser() 268 | parser.add_argument("--model_name", type=str, default="gpt-4o") 269 | parser.add_argument("--endpoint_url", type=str, required=True) 270 | parser.add_argument("--output_path", type=str, default="dataset") 271 | parser.add_argument("--generate_related_people", type=bool, default=True) 272 | parser.add_argument( 273 | "--raw_output_file", type=str, default="synthetic_data_raw.json" 274 | ) 275 | parser.add_argument("--output_file", type=str, default="synthetic_data_QA.json") 276 | parser.add_argument( 277 | "--perturbed_output_file", type=str, default="perturbed_output_file" 278 | ) 279 | parser.add_argument( 280 | "--augmented_output_file", type=str, default="synthetic_data_QA_augmented.json" 281 | ) 282 | 283 | args = parser.parse_args() 284 | return args 285 | 286 | 287 | if __name__ == "__main__": 288 | args = parser_args() 289 | 290 | data_generator = SyntheticDataGenerator(args.model_name, args.endpoint_url) 291 | 292 | os.makedirs(args.output_path, exist_ok=True) 293 | 294 | raw_output_file = os.path.join(args.output_path, args.raw_output_file) 295 | 296 | if os.path.exists(raw_output_file): 297 | # skip entities creation if it's already generated 298 | with open(raw_output_file, "r") as file: 299 | entity_list = [Entity(**json.loads(line)) for line in file] 300 | 301 | else: 302 | entity_list = [] 303 | for seed in range(1): 304 | data_generator.set_seed(seed) 305 | for instruction in tqdm(data_generator.get_instructions()): 306 | try: 307 | entity = data_generator.generate_entity(instruction) 308 | except Exception as e: 309 | print("Error generating entity.") 310 | print(e) 311 | continue 312 | save_entity(entity, raw_output_file) 313 | entity_list.append(entity) 314 | 315 | if args.generate_related_people: 316 | try: 317 | response = data_generator.generate_related_data(entity) 318 | except Exception as e: 319 | print("Error generating entity.") 320 | print(e) 321 | continue 322 | save_entity(response, raw_output_file) 323 | entity_list.append(response) 324 | 325 | QA_output_file = os.path.join(args.output_path, args.output_file) 326 | 327 | if os.path.exists(QA_output_file): 328 | with open(QA_output_file, "r") as file: 329 | dataset = [DataPoint(**json.loads(line)) for line in file] 330 | else: 331 | dataset = data_generator.post_process_data(entity_list) 
332 |         for data in dataset:
333 |             save_entity(data, QA_output_file)
334 | 
335 |     perturbed_dataset = data_generator.perturb_names(dataset)
336 | 
337 |     for data in perturbed_dataset:
338 |         save_entity(data, os.path.join(args.output_path, args.perturbed_output_file))
339 | 
340 |     dataset = data_generator.augmenta_data_with_synthetic_QA(dataset)
341 |     for data in dataset:
342 |         save_entity(data, os.path.join(args.output_path, args.augmented_output_file))
343 | 
-------------------------------------------------------------------------------- /dataset_generation/generate_kb_embeddings.py: --------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | 
5 | import numpy as np
6 | from sentence_transformers import SentenceTransformer
7 | from tqdm import tqdm
8 | 
9 | from kblam.gpt_session import GPT
10 | from kblam.utils.data_utils import DataPoint
11 | 
12 | 
13 | def parser_args():
14 |     parser = argparse.ArgumentParser()
15 |     parser.add_argument(
16 |         "--model_name",
17 |         type=str,
18 |         default="text-embedding-3-large",
19 |         choices=["all-MiniLM-L6-v2", "text-embedding-3-large", "ada-embeddings"],
20 |     )
21 |     parser.add_argument("--dataset_name", type=str, default="synthetic_data")
22 |     parser.add_argument("--endpoint_url", type=str)
23 |     parser.add_argument(
24 |         "--dataset_path",
25 |         type=str,
26 |         required=False,
27 |         help="Path to the dataset in JSON format.",
28 |     )
29 |     parser.add_argument("--output_path", type=str, default="dataset")
30 | 
31 |     args = parser.parse_args()
32 |     return args
33 | 
34 | 
35 | def compute_embeddings(
36 |     encoder_model_spec: str, dataset: list[DataPoint], part: str, batch_size: int = 100
37 | ) -> np.ndarray:
38 |     """Compute embeddings for the given dataset in batches using the encoder model spec."""
39 |     embeddings = []
40 |     all_elements = []
41 |     for entity in dataset:
42 |         if part == "key_string":
43 |             all_elements.append(entity.key_string)
44 |         elif part == "description":
45 |             all_elements.append(entity.description)
46 |         else:
47 |             raise ValueError(f"Part {part} not supported.")
48 |     chunks = [
49 |         all_elements[i : i + batch_size]
50 |         for i in range(0, len(all_elements), batch_size)
51 |     ]
52 | 
53 |     model = SentenceTransformer(encoder_model_spec, device="cuda")
54 |     for chunk in tqdm(chunks):
55 |         embd = model.encode(chunk, convert_to_numpy=True)
56 |         embeddings.append(embd)
57 | 
58 |     embeddings = np.concatenate(embeddings, 0)
59 |     assert len(embeddings) == len(all_elements)
60 |     return embeddings
61 | 
62 | 
63 | if __name__ == "__main__":
64 |     args = parser_args()
65 |     with open(args.dataset_path, "r") as file:
66 |         loaded_dataset = json.loads(file.read())
67 | 
68 |     dataset = [DataPoint(**line) for line in loaded_dataset]
69 |     if args.model_name == "all-MiniLM-L6-v2":
70 |         key_embeds = compute_embeddings(args.model_name, dataset, "key_string")
71 |         value_embeds = compute_embeddings(args.model_name, dataset, "description")
72 |     elif args.model_name in ["ada-embeddings", "text-embedding-3-large"]:
73 |         gpt = GPT(args.model_name, args.endpoint_url)
74 | 
75 |         key_embeds = []
76 |         value_embeds = []
77 | 
78 |         for entity in tqdm(dataset):
79 |             key_embeds.append(gpt.generate_embedding(entity.key_string))
80 |             value_embeds.append(gpt.generate_embedding(entity.description))
81 |     else:
82 |         raise ValueError(f"Model {args.model_name} not supported.")
83 | 
84 |     os.makedirs(args.output_path, exist_ok=True)
85 | 
86 |     if args.model_name == "all-MiniLM-L6-v2":
87 |         save_name = "all-MiniLM-L6-v2"
88 |     elif
args.model_name == "ada-embeddings": 89 | save_name = "OAI" 90 | else: 91 | save_name = "BigOAI" 92 | 93 | np.save( 94 | f"{args.output_path}/{args.dataset_name}_{save_name}_embd_key.npy", 95 | np.array(key_embeds), 96 | ) 97 | np.save( 98 | f"{args.output_path}/{args.dataset_name}_{save_name}_embd_value.npy", 99 | np.array(value_embeds), 100 | ) 101 | -------------------------------------------------------------------------------- /datasets/.gitattributes: -------------------------------------------------------------------------------- 1 | synthetic.json filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /datasets/datasetcard_enron.md: -------------------------------------------------------------------------------- 1 | # Dataset Card for Enron extracted knowledgebase 2 | 3 | The Enron extracted knowledgebase consists of triples that were extracted using a small language model from the Enron database. 4 | 5 | ## Dataset Details 6 | 7 | ### Dataset Description 8 | 9 | The triples in this knowledge base were extracted from the Enron email dataset using a small language model trained for entity extraction, then were clustered for de-duplication of entities and converted into triples where the relations were description, objective, and purpose. 10 | 11 | - **Curated by:** Microsoft Research 12 | - **Language(s) (NLP):** English 13 | - **License:** MIT 14 | 15 | ## Uses 16 | 17 | The dataset is intended to be used for the training and evaluation of grounded LLMs. The dataset can also be used for needle-in-the-haystack retrieval tasks, by augmenting the questions with a number of noise triples. 18 | 19 | ### Direct Use 20 | 21 | Research model training and evaluation. 22 | 23 | ### Out-of-Scope Use 24 | 25 | The dataset may not be useful as a real-world knowledge base as the triples were extracted using an automated system. 26 | 27 | ## Dataset Structure 28 | 29 | A list of JSONs, each with the following properties: 30 | 31 | `name`: name of entity 32 | `description_type`: the name of the property 33 | `description`: the value of the property 34 | `Q`: A question based on the triple 35 | `A`: An answer based on the triple 36 | `key_string`: The key used in KBLaM (created with a template of "The {property name} of {entity name}") 37 | 38 | ## Dataset Creation 39 | 40 | ### Curation Rationale 41 | 42 | The data was created to allow for the evaluation of knowledge-base augmented LLMs on real-world data. 43 | 44 | ### Source Data 45 | 46 | Enron email dataset 47 | 48 | #### Data Collection and Processing 49 | 50 | The entities were extracted using a generative SLM fine-tuned on the task, and linked using the project Alexandria entity linker for disambiguation. 51 | 52 | #### Who are the source data producers? 53 | 54 | The Enron email dataset is provided by CMU, who sourced it from Enron. 55 | 56 | #### Personal and Sensitive Information 57 | 58 | No additional personal information is contained over the Enron email dataset. 59 | 60 | ## Bias, Risks, and Limitations 61 | 62 | This dataset reflects the biases from the Enron email dataset, and may also be limited by the capabilities of the extraction process. This dataset only contains one description, objective, and purpose for each entity, when many more were extracted. This means that its use as a complete knowledgebase is limited. 
63 | 64 | ### Recommendations 65 | 66 | Due to the limitations as a complete knowledge base, it is recommended that this dataset is used for the evaluation of knowledgebase-augmented models only. 67 | 68 | ## Dataset Card Contact 69 | 70 | t-isazawat@microsoft.com 71 | -------------------------------------------------------------------------------- /datasets/datasetcard_synthetic.md: -------------------------------------------------------------------------------- 1 | # Dataset Card for KBLaM synthetic grounded questions dataset 2 | 3 | The KBLaM synthetic grounded questions dataset consists of a synthetic, GPT-4 generated knowledgebase of triples and a number of factual questions on this dataset. 4 | 5 | ## Dataset Details 6 | 7 | ### Dataset Description 8 | 9 | - **Curated by:** Microsoft Research 10 | - **Language(s) (NLP):** English 11 | - **License:** MIT 12 | 13 | ## Uses 14 | 15 | The dataset is intended to be used for the training and evaluation of grounded LLMs. The dataset can also be used for needle-in-the-haystack retrieval tasks, by augmenting the questions with a number of noise triples. 16 | 17 | ### Direct Use 18 | 19 | Research model training and evaluation. 20 | 21 | ### Out-of-Scope Use 22 | 23 | The dataset will not work well as a real-world knowledgebase as the triples are entirely synthetic. 24 | 25 | ## Dataset Structure 26 | 27 | A list of JSONs, each with the following properties: 28 | 29 | `name`: name of entity 30 | `description_type`: the name of the property 31 | `description`: the value of the property 32 | `Q`: A question based on the triple 33 | `A`: An answer based on the triple 34 | `key_string`: The key used in KBLaM (created with a template of "The {property name} of {entity name}") 35 | 36 | ## Dataset Creation 37 | 38 | ### Curation Rationale 39 | 40 | The data was created using GPT to allow for training and evaluation of knowledge-base augmented LLMs. 41 | 42 | ### Source Data 43 | 44 | N/A - the data is entirely synthetic, produced by GPT-4 45 | 46 | #### Data Collection and Processing 47 | 48 | The data was created synthetically using GPT-4. 49 | 50 | #### Who are the source data producers? 51 | 52 | The data was created synthetically using GPT-4. 53 | 54 | #### Personal and Sensitive Information 55 | 56 | The data was created synthetically using GPT-4, so personal data is unlikely. 57 | 58 | ## Bias, Risks, and Limitations 59 | 60 | As the data was created by GPT-4, the dataset's distribution will be biased towards GPT-4's builtin biases, and this should be taken as a limitation when creating or evaluating a model. 61 | 62 | ### Recommendations 63 | 64 | Any models should be evaluated also on other, human-created, datasets to ensure good real-world performance. 
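
## Example Record

For illustration, a single record follows the structure described in the Dataset Structure section above. The entity below is hypothetical (not taken from the dataset); the `Q`, `A`, and `key_string` values follow the templates used in `dataset_generation/gen_synthetic_data.py`.

```
{
  "name": "Aurora Dynamics",
  "description_type": "purpose",
  "description": "to streamline internal document workflows",
  "Q": "What is the purpose of Aurora Dynamics?",
  "A": "The purpose of Aurora Dynamics is to streamline internal document workflows.",
  "key_string": "the purpose of Aurora Dynamics"
}
```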
65 | 66 | ## Dataset Card Contact 67 | 68 | t-isazawat@microsoft.com 69 | -------------------------------------------------------------------------------- /datasets/synthetic.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:f3a9eb6bdb735a14a9edcfc8100a976051db25cf793a00060b127e60a58bfcc9 3 | size 107374110 4 | -------------------------------------------------------------------------------- /experiments/Makefile: -------------------------------------------------------------------------------- 1 | .DEFAULT_GOAL := help # Sets default action to be help 2 | 3 | 4 | define PRINT_HELP_PYSCRIPT # start of Python section 5 | import re, sys 6 | 7 | output = [] 8 | # Loop through the lines in this file 9 | for line in sys.stdin: 10 | # if the line has a command and a comment start with 11 | # two pound signs, add it to the output 12 | match = re.match(r'^([a-zA-Z_-]+):.*?## (.*)$$', line) 13 | if match: 14 | target, help = match.groups() 15 | output.append("%-10s %s" % (target, help)) 16 | # Sort the output in alphanumeric order 17 | output.sort() 18 | # Print the help result 19 | print('\n'.join(output)) 20 | endef 21 | export PRINT_HELP_PYSCRIPT # End of python section 22 | 23 | LLAMA_BASE_DIR='/datadisk/tk/llama3_8b_ins' 24 | LLAMA_HF='/kblam_llama_unified/kblam_unified' 25 | PHI3_HF='/datadisk/data/gcrbackup/hf_models/phi3' 26 | DATASET_DIR='/datadisk/data/gcrbackup/oai_embd' 27 | TEST_DATASET_DIR='/datadisk/data/train_test_split' 28 | CKPT_SAVE_DIR='/home/msalvaris/data/experiments/kblam/exp_v0.3.2' 29 | PHI3_MODEL_CHKPT='/datadisk/data/gcrbackup/experiments/kblam/exp_v0.1/stage1_lr_0.0001KBTokenLayerFreq3UseExtendedQAMultiEntities2UseOutlier1NoDuplicateKBSizedynamicSepQueryHeadUseDataAugKeyFromkey_OAI_synthetic_data_phi3_step_20000' 30 | ENCODER_FOR_PHI3_CHKPT='/datadisk/data/gcrbackup/experiments/kblam/exp_v0.1/stage1_lr_0.0001KBTokenLayerFreq3UseExtendedQAMultiEntities2UseOutlier1NoDuplicateKBSizedynamicSepQueryHeadUseDataAugFineTuneQueryKeyFromkey_synthetic_data_OAI_step_20000' 31 | LLAMA_MODEL_CHKPT='/datadisk/tk/llama3_8b_ins' 32 | ENCODER_FOR_LLAMA_CHKPT='/datadisk/tk/encoder_ckpt_20000_OAI.pt' 33 | QUERY_HEAD_PATH='/datadisk/tk/learned_query_head_20000_OAI.pth' 34 | 35 | ATTN_SAVE_DIR='/datadisk/kblamatt2' 36 | LOG_SAVE_DIR='/datadisk/kblamatt2/acc_results' 37 | 38 | 39 | LR=1e-4 40 | TRAIN_KB_SIZE=0 # Randomly pick KB size during training time 41 | OUTLIER_RATIO=-1 # Ratio of no-answer question in the batch, -1 stands for no such samples 42 | KB_LAYER_FREQ=3 # How frequent to inject kb tokens into the layers 43 | MULTI_ENTITIES_NUM=2 # For questions that involve multiple entities, how many entities are involved. 
44 | 45 | 46 | help: 47 | @python -c "$$PRINT_HELP_PYSCRIPT" < $(MAKEFILE_LIST) 48 | 49 | create_train_test_split: 50 | python ../dataset_generation/create_train_test_split.py --data_path ${DATASET_DIR}/synthetic_data.json \ 51 | --embedding_keys_path ${DATASET_DIR}/synthetic_data_oai_embd_key.npy \ 52 | --embeddings_values_path ${DATASET_DIR}/synthetic_data_oai_embd_value.npy \ 53 | --output_path /datadisk/data/train_test_split \ 54 | --split_index 120000 55 | 56 | 57 | train: ## Train kb adapter 58 | python train.py --model_dir ${LLAMA_BASE_DIR} \ 59 | --dataset_dir ${DATASET_DIR} \ 60 | --model_save_dir ${CKPT_SAVE_DIR} \ 61 | --seed 1607 \ 62 | --dataset gpt_data \ 63 | --N 120000 \ 64 | --B 20 \ 65 | --steps 20001 \ 66 | --encoder_spec all-MiniLM-L6-v2 \ 67 | --use_cached_embd \ 68 | --key_embd_src key \ 69 | --use_data_aug \ 70 | --use_lr_decay \ 71 | --tune_llm_q \ 72 | --sep_query_head \ 73 | --lr ${LR} \ 74 | --kb_size ${TRAIN_KB_SIZE} \ 75 | --kb_token_layer_frequency ${KB_LAYER_FREQ} \ 76 | --multi_entities ${MULTI_ENTITIES_NUM} \ 77 | --use_extended_qa \ 78 | --outlier_ratio ${OUTLIER_RATIO} \ 79 | --gradient_accm_step 20 80 | 81 | train-oai: ## Train kb adapter 82 | python train.py --model_dir ${LLAMA_HF} \ 83 | --dataset_dir ${DATASET_DIR} \ 84 | --model_save_dir ${CKPT_SAVE_DIR} \ 85 | --seed 1607 \ 86 | --dataset gpt_data \ 87 | --N 120000 \ 88 | --B 2 \ 89 | --steps 30001 \ 90 | --encoder_spec OAI \ 91 | --use_cached_embd \ 92 | --key_embd_src key \ 93 | --use_data_aug \ 94 | --use_lr_decay \ 95 | --tune_llm_q \ 96 | --sep_query_head \ 97 | --lr ${LR} \ 98 | --no-duplicate_true_kb \ 99 | --kb_size ${TRAIN_KB_SIZE} \ 100 | --kb_token_layer_frequency ${KB_LAYER_FREQ} \ 101 | --multi_entities ${MULTI_ENTITIES_NUM} \ 102 | --use_extended_qa \ 103 | --outlier_ratio ${OUTLIER_RATIO} \ 104 | --gradient_accm_step 20 105 | 106 | 107 | 108 | 109 | train-phi3-oai: ## Train kb adapter 110 | python train.py --model_dir ${PHI3_HF} \ 111 | --dataset_dir ${DATASET_DIR} \ 112 | --model_save_dir ${CKPT_SAVE_DIR} \ 113 | --seed 1607 \ 114 | --train_dataset synthetic_data \ 115 | --N 120000 \ 116 | --B 32 \ 117 | --steps 30001 \ 118 | --encoder_spec OAI \ 119 | --use_cached_embd \ 120 | --key_embd_src key \ 121 | --use_data_aug \ 122 | --use_lr_decay \ 123 | --tune_llm_q \ 124 | --sep_query_head \ 125 | --lr ${LR} \ 126 | --no-duplicate_true_kb \ 127 | --kb_size ${TRAIN_KB_SIZE} \ 128 | --kb_token_layer_frequency ${KB_LAYER_FREQ} \ 129 | --multi_entities ${MULTI_ENTITIES_NUM} \ 130 | --use_extended_qa \ 131 | --outlier_ratio ${OUTLIER_RATIO} \ 132 | --gradient_accm_step 20 \ 133 | --llm_type 'phi3' \ 134 | --use_cuda 135 | 136 | 137 | #----------------- phi eval -------------------------------------- 138 | 139 | 140 | eval-acc-phi3-oai: ## Eval kb adapter 141 | python eval.py accuracy \ 142 | --seed 1607 \ 143 | --dataset_dir ${TEST_DATASET_DIR} \ 144 | --test_dataset test_synthetic_data.json \ 145 | --llm_base_dir ${PHI3_HF} \ 146 | --model_dir ${PHI3_MODEL_CHKPT} \ 147 | --encoder_dir ${ENCODER_FOR_PHI3_CHKPT} \ 148 | --save_dir ${ATTN_SAVE_DIR} \ 149 | --kb_layer_frequency ${KB_LAYER_FREQ} \ 150 | --kb_size 3200 \ 151 | --kb_scale_factor 100 \ 152 | --no-fancy_instruction \ 153 | --encoder_spec oai \ 154 | --llm_type "phi3" \ 155 | --attn_save_dir ${ATTN_SAVE_DIR} \ 156 | --log_save_dir ${LOG_SAVE_DIR} \ 157 | --precomputed_embed_keys_path ${TEST_DATASET_DIR}/test_synthetic_data_oai_embd_key.npy \ 158 | --precomputed_embed_values_path 
${TEST_DATASET_DIR}/test_synthetic_data_oai_embd_value.npy 159 | 160 | 161 | eval-acc-eval-phi3-oai: ## Eval kb adapter 162 | python eval.py acc_results \ 163 | --seed 1607 \ 164 | --dataset_dir ${TEST_DATASET_DIR} \ 165 | --test_dataset test_synthetic_data.json \ 166 | --llm_base_dir ${PHI3_HF} \ 167 | --model_dir ${PHI3_MODEL_CHKPT} \ 168 | --encoder_dir ${ENCODER_FOR_PHI3_CHKPT} \ 169 | --save_dir ${ATTN_SAVE_DIR} \ 170 | --kb_layer_frequency ${KB_LAYER_FREQ} \ 171 | --kb_size 3200 \ 172 | --kb_scale_factor 100 \ 173 | --no-fancy_instruction \ 174 | --encoder_spec oai \ 175 | --llm_type "phi3" \ 176 | --attn_save_dir ${ATTN_SAVE_DIR} \ 177 | --log_save_dir ${LOG_SAVE_DIR} \ 178 | --precomputed_embed_keys_path ${TEST_DATASET_DIR}/test_synthetic_data_oai_embd_key.npy \ 179 | --precomputed_embed_values_path ${TEST_DATASET_DIR}/test_synthetic_data_oai_embd_value.npy 180 | 181 | 182 | eval-gen-phi3-oai: ## Eval kb adapter 183 | python eval.py generation \ 184 | --seed 1607 \ 185 | --dataset_dir ${TEST_DATASET_DIR} \ 186 | --test_dataset test_synthetic_data.json \ 187 | --llm_base_dir ${PHI3_HF} \ 188 | --model_dir ${PHI3_MODEL_CHKPT} \ 189 | --encoder_dir ${ENCODER_FOR_PHI3_CHKPT} \ 190 | --save_dir ${ATTN_SAVE_DIR} \ 191 | --kb_layer_frequency ${KB_LAYER_FREQ} \ 192 | --kb_size 3200 \ 193 | --kb_scale_factor 100 \ 194 | --no-fancy_instruction \ 195 | --encoder_spec oai \ 196 | --llm_type "phi3" \ 197 | --precomputed_embed_keys_path ${TEST_DATASET_DIR}/test_synthetic_data_oai_embd_key.npy \ 198 | --precomputed_embed_values_path ${TEST_DATASET_DIR}/test_synthetic_data_oai_embd_value.npy 199 | 200 | 201 | eval-ref-phi3-oai: ## Eval kb adapter 202 | python eval.py refusal \ 203 | --seed 1607 \ 204 | --dataset_dir ${TEST_DATASET_DIR} \ 205 | --test_dataset test_synthetic_data.json \ 206 | --llm_base_dir ${PHI3_HF} \ 207 | --model_dir ${PHI3_MODEL_CHKPT} \ 208 | --encoder_dir ${ENCODER_FOR_PHI3_CHKPT} \ 209 | --save_dir ${ATTN_SAVE_DIR} \ 210 | --kb_layer_frequency ${KB_LAYER_FREQ} \ 211 | --kb_size 3200 \ 212 | --kb_scale_factor 100 \ 213 | --no-fancy_instruction \ 214 | --encoder_spec oai \ 215 | --llm_type "phi3" \ 216 | --precomputed_embed_keys_path ${TEST_DATASET_DIR}/test_synthetic_data_oai_embd_key.npy \ 217 | --precomputed_embed_values_path ${TEST_DATASET_DIR}/test_synthetic_data_oai_embd_value.npy 218 | 219 | 220 | 221 | eval-basic-phi3-oai: ## Eval kb adapter 222 | python eval.py standard \ 223 | --seed 1607 \ 224 | --dataset_dir ${TEST_DATASET_DIR} \ 225 | --test_dataset test_synthetic_data.json \ 226 | --llm_base_dir ${PHI3_HF} \ 227 | --model_dir ${PHI3_MODEL_CHKPT} \ 228 | --encoder_dir ${ENCODER_FOR_PHI3_CHKPT} \ 229 | --save_dir ${ATTN_SAVE_DIR} \ 230 | --kb_layer_frequency ${KB_LAYER_FREQ} \ 231 | --kb_size 3200 \ 232 | --kb_scale_factor 100 \ 233 | --no-fancy_instruction \ 234 | --encoder_spec oai \ 235 | --llm_type "phi3" \ 236 | --precomputed_embed_keys_path ${TEST_DATASET_DIR}/test_synthetic_data_oai_embd_key.npy \ 237 | --precomputed_embed_values_path ${TEST_DATASET_DIR}/test_synthetic_data_oai_embd_value.npy 238 | 239 | 240 | 241 | #----------------- llama eval -------------------------------------- 242 | 243 | eval-acc-llama-oai: ## Eval kb adapter 244 | python eval.py accuracy \ 245 | --seed 1607 \ 246 | --dataset_dir ${TEST_DATASET_DIR} \ 247 | --test_dataset test_synthetic_data.json \ 248 | --llm_base_dir ${LLAMA_BASE_DIR} \ 249 | --model_dir ${LLAMA_MODEL_CHKPT} \ 250 | --encoder_dir ${ENCODER_FOR_LLAMA_CHKPT} \ 251 | --save_dir ${ATTN_SAVE_DIR} \ 252 | 
--kb_layer_frequency ${KB_LAYER_FREQ} \ 253 | --kb_size 3200 \ 254 | --kb_scale_factor 100 \ 255 | --no-fancy_instruction \ 256 | --encoder_spec oai \ 257 | --llm_type "llama3" \ 258 | --attn_save_dir ${ATTN_SAVE_DIR} \ 259 | --log_save_dir ${LOG_SAVE_DIR} \ 260 | --query_head_path ${QUERY_HEAD_PATH} \ 261 | --precomputed_embed_keys_path ${TEST_DATASET_DIR}/test_synthetic_data_oai_embd_key.npy \ 262 | --precomputed_embed_values_path ${TEST_DATASET_DIR}/test_synthetic_data_oai_embd_value.npy 263 | 264 | 265 | eval-acc-eval-llama-oai: ## Eval kb adapter 266 | python eval.py acc_results \ 267 | --seed 1607 \ 268 | --dataset_dir ${TEST_DATASET_DIR} \ 269 | --test_dataset test_synthetic_data.json \ 270 | --llm_base_dir ${LLAMA_BASE_DIR} \ 271 | --model_dir ${LLAMA_MODEL_CHKPT} \ 272 | --encoder_dir ${ENCODER_FOR_LLAMA_CHKPT} \ 273 | --save_dir ${ATTN_SAVE_DIR} \ 274 | --kb_layer_frequency ${KB_LAYER_FREQ} \ 275 | --kb_size 100 \ 276 | --kb_scale_factor 100 \ 277 | --no-fancy_instruction \ 278 | --encoder_spec oai \ 279 | --llm_type "llama3" \ 280 | --attn_save_dir ${ATTN_SAVE_DIR} \ 281 | --log_save_dir ${LOG_SAVE_DIR} \ 282 | --query_head_path ${QUERY_HEAD_PATH} \ 283 | --precomputed_embed_keys_path ${TEST_DATASET_DIR}/test_synthetic_data_oai_embd_key.npy \ 284 | --precomputed_embed_values_path ${TEST_DATASET_DIR}/test_synthetic_data_oai_embd_value.npy 285 | 286 | 287 | eval-ref-llama-oai: ## Eval kb adapter 288 | python eval.py refusal \ 289 | --seed 1607 \ 290 | --dataset_dir ${TEST_DATASET_DIR} \ 291 | --test_dataset test_synthetic_data.json \ 292 | --llm_base_dir ${LLAMA_BASE_DIR} \ 293 | --model_dir ${LLAMA_MODEL_CHKPT} \ 294 | --encoder_dir ${ENCODER_FOR_LLAMA_CHKPT} \ 295 | --save_dir ${ATTN_SAVE_DIR} \ 296 | --kb_layer_frequency ${KB_LAYER_FREQ} \ 297 | --kb_size 100 \ 298 | --kb_scale_factor 100 \ 299 | --no-fancy_instruction \ 300 | --encoder_spec oai \ 301 | --llm_type "llama3" \ 302 | --query_head_path ${QUERY_HEAD_PATH} \ 303 | --precomputed_embed_keys_path ${TEST_DATASET_DIR}/test_synthetic_data_oai_embd_key.npy \ 304 | --precomputed_embed_values_path ${TEST_DATASET_DIR}/test_synthetic_data_oai_embd_value.npy 305 | 306 | 307 | eval-basic-llama-oai: ## Eval kb adapter 308 | python eval.py standard \ 309 | --seed 1607 \ 310 | --dataset_dir ${TEST_DATASET_DIR} \ 311 | --test_dataset test_synthetic_data.json \ 312 | --llm_base_dir ${LLAMA_BASE_DIR} \ 313 | --model_dir ${LLAMA_MODEL_CHKPT} \ 314 | --encoder_dir ${ENCODER_FOR_LLAMA_CHKPT} \ 315 | --save_dir ${ATTN_SAVE_DIR} \ 316 | --kb_layer_frequency ${KB_LAYER_FREQ} \ 317 | --exp_config_str "llama3_kb_scale_100" \ 318 | --kb_size 100 \ 319 | --kb_scale_factor 100 \ 320 | --no-fancy_instruction \ 321 | --encoder_spec oai \ 322 | --llm_type "llama3" \ 323 | --query_head_path ${QUERY_HEAD_PATH} \ 324 | --precomputed_embed_keys_path ${TEST_DATASET_DIR}/test_synthetic_data_oai_embd_key.npy \ 325 | --precomputed_embed_values_path ${TEST_DATASET_DIR}/test_synthetic_data_oai_embd_value.npy 326 | 327 | 328 | 329 | 330 | .PHONY: train train-oai 331 | -------------------------------------------------------------------------------- /experiments/eval.py: -------------------------------------------------------------------------------- 1 | """Script for evaluating KB models""" 2 | 3 | import argparse 4 | import json 5 | import os 6 | import re 7 | import shutil 8 | from pathlib import Path 9 | from typing import Any, Dict, List, Optional 10 | 11 | import evaluate 12 | import nltk 13 | import numpy as np 14 | import torch 15 | import transformers 
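# Overview: this script exposes five subcommands -- `generation`, `accuracy`, `acc_results`, `refusal`, and `standard` -- dispatched from main() at the bottom of the file. Each subcommand builds a KBRetriever over precomputed key/value embeddings (or recomputes them with KBEncoder) and evaluates a KBLaM model: generation quality via ROUGE/BERTScore, retrieval accuracy from the saved attention weights, or refusal behaviour on out-of-KB questions.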
16 | from tqdm import tqdm 17 | from transformers import AutoTokenizer, logging 18 | 19 | from kblam.kb_encoder import KBEncoder 20 | from kblam.models.kblam_config import KBLaMConfig 21 | from kblam.models.llama3_model import KblamLlamaForCausalLM 22 | from kblam.models.phi3_model import KBLaMPhi3ForCausalLM 23 | from kblam.utils.data_utils import aug_row, generate_multi_entity_qa 24 | from kblam.utils.eval_utils import ( 25 | instruction_prompts, 26 | instruction_prompts_multi_entities, 27 | zero_shot_prompt, 28 | zero_shot_prompt_multi_entities, 29 | _format_Q_llama, 30 | _format_Q_phi3, 31 | model_prune_format_mapping, 32 | answer_question, 33 | softmax, 34 | ) 35 | from kblam.utils.train_utils import get_kb_embd 36 | 37 | nltk.download("wordnet") 38 | logging.set_verbosity_warning() 39 | 40 | rouge = evaluate.load("rouge") 41 | bert_score = evaluate.load("bertscore") 42 | 43 | 44 | class KBRetriever: 45 | def __init__( 46 | self, 47 | encoder: KBEncoder, 48 | dataset: List[Dict], 49 | precomputed_embed_keys_path: Optional[str] = None, 50 | precomputed_embed_values_path: Optional[str] = None, 51 | ): 52 | self.encoder = encoder 53 | self.dataset = dataset 54 | if precomputed_embed_keys_path is not None: 55 | self.key_embds = np.load(precomputed_embed_keys_path).astype("float32") 56 | else: 57 | self.key_embds = None 58 | if precomputed_embed_values_path is not None: 59 | self.value_embds = np.load(precomputed_embed_values_path).astype("float32") 60 | else: 61 | self.value_embds = None 62 | 63 | if precomputed_embed_keys_path is not None: 64 | assert len(dataset) == len(self.key_embds) 65 | 66 | def _use_cached_embd(self): 67 | if self.key_embds is not None and self.value_embds is not None: 68 | return True 69 | else: 70 | return False 71 | 72 | def get_key_embeddings(self, batch_indices): 73 | if self._use_cached_embd(): 74 | return get_kb_embd( 75 | self.encoder, 76 | batch_indices, 77 | precomputed_embd=(self.key_embds, self.value_embds), 78 | ) 79 | else: 80 | return get_kb_embd(self.encoder, batch_indices, kb_dict=self.dataset) 81 | 82 | 83 | def perform_eval( 84 | model: KBLaMPhi3ForCausalLM | KblamLlamaForCausalLM, 85 | tokenizer: transformers.PreTrainedTokenizer, 86 | kb_retriever: KBRetriever, 87 | encoder_model_spec: str, 88 | kb_config: KBLaMConfig, 89 | eval_mode: str = "kb", 90 | kb_size: int = 250, 91 | seed: int = 1, 92 | topk_size: int = -1, 93 | multi_entites: int = -1, 94 | remove_sorry: bool = False, 95 | ): 96 | np.random.seed(seed) 97 | kb_idx = np.random.randint(0, len(kb_retriever.dataset), kb_size) 98 | test_kb = [kb_retriever.dataset[idx] for idx in kb_idx] 99 | kb_embedding = () 100 | key_str = [row["key_string"] for row in test_kb] 101 | value_str = [row["description"] for row in test_kb] 102 | prompt_strs = "" 103 | for k, v in zip(key_str, value_str): 104 | prompt_strs += f"{k} is {v}; " 105 | 106 | kb_embedding = kb_retriever.get_key_embeddings(kb_idx) 107 | 108 | model_outputs = [] 109 | answers = [] 110 | full_outputs = [] 111 | # answer_question 112 | # Regardless of KB size, test at most 400 questions, 113 | # otherwise evaluation will be too slow. 114 | subset_size = min( 115 | 400, len(test_kb) 116 | ) 117 | 118 | # subset_size = 50 119 | for row in tqdm(test_kb[:subset_size]): 120 | if multi_entites == -1: 121 | Q = row["Q"] 122 | answer = row["A"] 123 | else: 124 | kb_subset_idx = np.random.randint(0, len(test_kb), multi_entites) 125 | 
Q, A = generate_multi_entity_qa( 126 | [test_kb[i]["name"] for i in kb_subset_idx], 127 | [test_kb[i]["description_type"] for i in kb_subset_idx], 128 | [test_kb[i]["description"] for i in kb_subset_idx], 129 | ) 130 | answer = A 131 | 132 | if eval_mode == "kb": 133 | model_output = answer_question( 134 | tokenizer, 135 | model, 136 | Q, 137 | kb=kb_embedding, 138 | topk_size=topk_size, 139 | kb_config=kb_config, 140 | ).split(Q)[1] 141 | elif eval_mode == "icl": 142 | if multi_entites != -1: 143 | ins_prompt = instruction_prompts_multi_entities 144 | else: 145 | ins_prompt = instruction_prompts 146 | model_output = answer_question( 147 | tokenizer, 148 | model, 149 | ins_prompt + prompt_strs + Q, 150 | kb=None, 151 | kb_config=kb_config, 152 | ).split(Q)[1] 153 | elif eval_mode == "zeroshot": 154 | if multi_entites != -1: 155 | ins_prompt = zero_shot_prompt_multi_entities 156 | else: 157 | ins_prompt = zero_shot_prompt 158 | model_output = answer_question( 159 | tokenizer, model, ins_prompt + Q, kb=None, kb_config=kb_config 160 | ).split(Q)[1] 161 | # print(model_output) 162 | if remove_sorry: 163 | if "sorry" in model_output: 164 | continue 165 | full_outputs.append((model_output, answer)) 166 | if multi_entites == -1: 167 | pattern = r'The\s+\w+\s+of\s+[^"]+\s+is\s+(.+)' 168 | match = re.search(pattern, model_output) 169 | answers.append(row["description"]) 170 | if match: 171 | model_output = match.group(1) 172 | else: 173 | pattern = r"(?:is|are) (.*?)(?:\.|;)" 174 | matches = re.findall(pattern, model_output) 175 | model_output = "; ".join(matches) 176 | answers.append(";".join(re.findall(r"(?:is|are) (.*?);", answer))) 177 | model_outputs.append(model_output) 178 | 179 | print(f"KB size: {kb_size}, mode: {eval_mode}") 180 | rouge = evaluate.load("rouge") 181 | 182 | for pred, gt in zip(model_outputs, answers): 183 | print(f"PREDICTION: {pred}") 184 | print(f"GT: {gt}") 185 | rouge_scores = rouge.compute(predictions=model_outputs, references=answers) 186 | print(rouge_scores) 187 | 188 | results_dict = {k: float(v) for k, v in rouge_scores.items()} 189 | 190 | bertscore = bert_score.compute( 191 | predictions=model_outputs, 192 | references=answers, 193 | lang="en", 194 | model_type="microsoft/deberta-xlarge-mnli", 195 | ) 196 | # bert_scores = [] 197 | # bert_scores = {} 198 | for k, v in bertscore.items(): 199 | if isinstance(v, list): 200 | # bert_scores.append(np.mean(v)) 201 | results_dict[f"bert_score_{k}"] = float(np.mean(v)) 202 | print(k, np.mean(v)) 203 | results = "" 204 | for a, A in full_outputs: 205 | results += f"Model output: {a}\nTrue answer: {A}\n-------\n" 206 | if eval_mode == "kb": 207 | eval_mode = encoder_model_spec + eval_mode 208 | 209 | return results, results_dict 210 | 211 | 212 | def perform_eval_refusal( 213 | model: KBLaMPhi3ForCausalLM | KblamLlamaForCausalLM, 214 | tokenizer: transformers.PreTrainedTokenizer, 215 | kb_retriever: KBRetriever, 216 | kb_config: Optional[KBLaMConfig] = None, 217 | eval_mode: str = "kb", 218 | kb_size: int = 250, 219 | seed: int = 1, 220 | outlier_ratio: float = 0.2, 221 | topk_size: int = -1, 222 | question_size: int = 100, 223 | ): 224 | instruction_prompts = ( 225 | 'Please answer questions based on the given text with format: "The {property} of {name} is {description}",' 226 | ' if relevant information cannot be found in the text, please respond "I am sorry I cannot find relevant information in the KB".' 
227 | ) 228 | zero_shot_prompt = """ 229 | Please answer the question in a very compact manner with format: The {property} of {name} is {description} 230 | """ 231 | 232 | np.random.seed(seed) 233 | kb_idx = np.random.randint(0, len(kb_retriever.dataset), kb_size) 234 | test_kb = [kb_retriever.dataset[idx] for idx in kb_idx] 235 | kb_embedding = () 236 | key_str = [row["key_string"] for row in test_kb] 237 | value_str = [row["description"] for row in test_kb] 238 | prompt_strs = "" 239 | for k, v in zip(key_str, value_str): 240 | prompt_strs += f"{k} is {v}; " 241 | 242 | kb_embedding = kb_retriever.get_key_embeddings(kb_idx) 243 | 244 | model_outputs = [] 245 | answers = [] 246 | # answer_question 247 | outlier_idx = np.arange(len(kb_retriever.dataset)) 248 | outlier_idx = outlier_idx[~np.isin(outlier_idx, kb_idx)] 249 | np.random.shuffle(outlier_idx) 250 | question_size = min(kb_size, question_size) 251 | outlier_idx = outlier_idx[: int(question_size * outlier_ratio)] 252 | test_kb = test_kb[: int(question_size * (1 - outlier_ratio))] + [ 253 | kb_retriever.dataset[idx] for idx in outlier_idx 254 | ] 255 | change_point = int(question_size * (1 - outlier_ratio)) 256 | for i, row in tqdm(enumerate(test_kb)): 257 | Q = row["Q"] 258 | if eval_mode == "kb": 259 | model_output = answer_question( 260 | tokenizer, 261 | model, 262 | Q, 263 | kb=kb_embedding, 264 | topk_size=topk_size, 265 | kb_config=kb_config, 266 | ).split(Q)[1] 267 | 268 | elif eval_mode == "icl": 269 | model_output = answer_question( 270 | tokenizer, 271 | model, 272 | instruction_prompts + prompt_strs + Q, 273 | kb=None, 274 | kb_config=kb_config, 275 | ).split(Q)[1] 276 | elif eval_mode == "zeroshot": 277 | model_output = answer_question( 278 | tokenizer, 279 | model, 280 | zero_shot_prompt + Q, 281 | kb=None, 282 | kb_config=kb_config, 283 | ).split(Q)[1] 284 | model_outputs.append(model_output) 285 | if i < change_point: 286 | answers.append(row["description"]) 287 | else: 288 | answers.append("Cannot find relevant information in the KB") 289 | true_label = [0] * change_point + [1] * int(question_size * outlier_ratio) 290 | prediction = [int("sorry" in model_output) for model_output in model_outputs] 291 | print(f"KB size: {kb_size}, mode: {eval_mode}, outlier ratio: {outlier_ratio}") 292 | results = "" 293 | for a, A in zip(model_outputs, answers): 294 | results += f"Model output: {a}\nTrue answer: {A}\n-------\n" 295 | return results, np.array([prediction, true_label]) 296 | 297 | 298 | parser = argparse.ArgumentParser(description="Evaluation script") 299 | 300 | # Add arguments that will be shared across all subcommands 301 | parent_parser = argparse.ArgumentParser(add_help=False) 302 | 303 | parent_parser.add_argument( 304 | "--dataset_dir", type=str, help="Directory containing the dataset" 305 | ) 306 | parent_parser.add_argument( 307 | "--encoder_dir", type=str, help="Directory containing the encoder model" 308 | ) 309 | parent_parser.add_argument( 310 | "--encoder_spec", 311 | type=str, 312 | default="OAI", 313 | help="Specification for the encoder model", 314 | ) 315 | parent_parser.add_argument( 316 | "--fancy_instruction", 317 | action=argparse.BooleanOptionalAction, 318 | default=False, 319 | help="Whether to use fancy instructions", 320 | ) 321 | parent_parser.add_argument( 322 | "--kb_layer_frequency", 323 | type=int, 324 | default=3, 325 | help="Frequency of knowledge base layers", 326 | ) 327 | parent_parser.add_argument( 328 | "--kb_scale_factor", 329 | type=int, 330 | default=None, 331 | help="Scaling 
factor for knowledge base", 332 | ) 333 | parent_parser.add_argument( 334 | "--kb_size", type=int, default=200, help="Size of the knowledge base" 335 | ) 336 | parent_parser.add_argument( 337 | "--llm_base_dir", 338 | type=str, 339 | help="llm to load, can be HF location or local directory", 340 | ) 341 | parent_parser.add_argument( 342 | "--llm_type", 343 | type=str, 344 | default="phi3", 345 | choices=["llama3", "phi3"], 346 | help="Type of language model to use", 347 | ) 348 | parent_parser.add_argument( 349 | "--model_dir", type=str, help="Directory containing the model" 350 | ) 351 | parent_parser.add_argument("--save_dir", type=str, help="Directory to save outputs") 352 | parent_parser.add_argument("--seed", type=int, help="Random seed for reproducibility") 353 | parent_parser.add_argument( 354 | "--test_dataset", type=str, help="Source of test KB (assumes KV pair format)" 355 | ) 356 | parent_parser.add_argument( 357 | "--precomputed_embed_keys_path", type=str, help="Path to precomputed key embeddings" 358 | ) 359 | parent_parser.add_argument( 360 | "--precomputed_embed_values_path", 361 | type=str, 362 | help="Path to precomputed value embeddings", 363 | ) 364 | parent_parser.add_argument( 365 | "--query_head_path", type=str, default="", help="Path to load KB head from" 366 | ) 367 | 368 | # Create subparsers 369 | subparsers = parser.add_subparsers(dest="command", required=True) 370 | 371 | # Create the parser for the generation command 372 | gen_parser = subparsers.add_parser( 373 | "generation", parents=[parent_parser], help="Evaluate generation" 374 | ) 375 | gen_parser.add_argument( 376 | "--eval_mode", 377 | type=str, 378 | choices=["kb", "icl", "zeroshot"], 379 | default="kb", 380 | help="Evaluation mode: knowledge base, in-context learning, or zero-shot", 381 | ) 382 | gen_parser.add_argument( 383 | "--exp_config_name", 384 | type=str, 385 | default="generation_results", 386 | help="Name of the experiment configuration", 387 | ) 388 | gen_parser.add_argument( 389 | "--kb_token_layer_frequency", 390 | type=int, 391 | default=None, 392 | help="Frequency of knowledge base token layers", 393 | ) 394 | gen_parser.add_argument( 395 | "--multi_entites", 396 | type=int, 397 | default=-1, 398 | help="Number of entities to process (-1 for unlimited)", 399 | ) 400 | gen_parser.add_argument( 401 | "--no_outlier", 402 | action=argparse.BooleanOptionalAction, 403 | default=False, 404 | help="Use checkpoints trained without outliers", 405 | ) 406 | gen_parser.add_argument( 407 | "--remove_sorry", 408 | action=argparse.BooleanOptionalAction, 409 | default=False, 410 | help='Filter out "sorry" answers from the output', 411 | ) 412 | gen_parser.add_argument( 413 | "--topk_size", type=int, default=-1, help="Size of top-k selection (-1 for all)" 414 | ) 415 | 416 | 417 | # Create the parser for the accuracy command 418 | acc_parser = subparsers.add_parser( 419 | "accuracy", parents=[parent_parser], help="Evaluate accuracy" 420 | ) 421 | 422 | acc_parser.add_argument( 423 | "--attn_save_dir", type=str, default="", help="Directory to save attention masks" 424 | ) 425 | acc_parser.add_argument( 426 | "--exp_config_name", 427 | type=str, 428 | default="accuracy_results", 429 | help="Name of the experiment configuration", 430 | ) 431 | acc_parser.add_argument( 432 | "--fancy_question", 433 | action=argparse.BooleanOptionalAction, 434 | default=False, 435 | help="Enable fancy question format", 436 | ) 437 | acc_parser.add_argument( 438 | "--log_save_dir", type=str, help="Directory to save accuracy 
results" 439 | ) 440 | acc_parser.add_argument( 441 | "--test_batch_size", type=int, default=50, help="Batch size for testing" 442 | ) 443 | acc_parser.add_argument( 444 | "--use_shift_match", 445 | action=argparse.BooleanOptionalAction, 446 | default=False, 447 | help="Enable shift matching", 448 | ) 449 | 450 | # Create the parser for the accuracy eval 451 | acc_results_parser = subparsers.add_parser( 452 | "acc_results", parents=[acc_parser], help="run accuracy eval", add_help=False 453 | ) 454 | 455 | 456 | # Create the parser for the refusal command 457 | ref_parser = subparsers.add_parser( 458 | "refusal", parents=[parent_parser], help="Evaluate refusal" 459 | ) 460 | ref_parser.add_argument( 461 | "--eval_mode", 462 | type=str, 463 | choices=["kb", "icl", "zeroshot"], 464 | default="kb", 465 | help="Evaluation mode: knowledge base, in-context learning, or zero-shot", 466 | ) 467 | ref_parser.add_argument( 468 | "--exp_config_name", 469 | type=str, 470 | default="refusal_results", 471 | help="Name of the experiment configuration", 472 | ) 473 | ref_parser.add_argument( 474 | "--kb_token_layer_frequency", 475 | type=int, 476 | default=None, 477 | help="Frequency of knowledge base token layers", 478 | ) 479 | ref_parser.add_argument( 480 | "--multi_entites", 481 | type=int, 482 | default=-1, 483 | help="Number of entities to process (-1 for unlimited)", 484 | ) 485 | ref_parser.add_argument( 486 | "--no_outlier", 487 | action=argparse.BooleanOptionalAction, 488 | default=False, 489 | help="Use checkpoints trained without outliers", 490 | ) 491 | ref_parser.add_argument( 492 | "--remove_sorry", 493 | action=argparse.BooleanOptionalAction, 494 | default=False, 495 | help='Filter out "sorry" answers from the output', 496 | ) 497 | ref_parser.add_argument( 498 | "--topk_size", type=int, default=-1, help="Size of top-k selection (-1 for all)" 499 | ) 500 | 501 | # Create the parser for the standard command 502 | basic_parser = subparsers.add_parser( 503 | "standard", parents=[parent_parser], help="Evaluate basic performance" 504 | ) 505 | basic_parser.add_argument( 506 | "--attn_summary_save_dir", 507 | type=str, 508 | default="", 509 | help="Directory to save attention masks", 510 | ) 511 | basic_parser.add_argument( 512 | "--eval_mode", 513 | type=str, 514 | choices=["kb", "icl", "zeroshot"], 515 | default="kb", 516 | help="Evaluation mode: knowledge base, in-context learning, or zero-shot", 517 | ) 518 | basic_parser.add_argument( 519 | "--exp_config_name", 520 | type=str, 521 | default="basic_results", 522 | help="Name of the experiment configuration", 523 | ) 524 | basic_parser.add_argument( 525 | "--exp_config_str", type=str, help="Experiment configuration string" 526 | ) 527 | basic_parser.add_argument( 528 | "--kb_token_layer_frequency", 529 | type=int, 530 | default=None, 531 | help="Frequency of knowledge base token layers", 532 | ) 533 | basic_parser.add_argument( 534 | "--no_outlier", 535 | action=argparse.BooleanOptionalAction, 536 | default=False, 537 | help="Use checkpoints trained without outliers", 538 | ) 539 | basic_parser.add_argument( 540 | "--sample_size", default=5, type=int, help="Number of samples to process" 541 | ) 542 | basic_parser.add_argument( 543 | "--subset_size", default=100, type=int, help="Size of the data subset to use" 544 | ) 545 | basic_parser.add_argument( 546 | "--topk_size", type=int, default=-1, help="Size of top-k selection (-1 for all)" 547 | ) 548 | 549 | 550 | def eval_generate(): 551 | """Evaluate generation using KB""" 552 | args = 
parser.parse_args() 553 | 554 | dataset_dir = args.dataset_dir 555 | encoder_model_spec = args.encoder_spec 556 | encoder_path = args.encoder_dir 557 | eval_mode = args.eval_mode 558 | exp_config = args.exp_config_name 559 | kb_layer_frequency = args.kb_layer_frequency 560 | kb_scale_factor = args.kb_scale_factor 561 | kb_size = args.kb_size 562 | llm_base_dir = args.llm_base_dir 563 | llm_type = args.llm_type 564 | model_path = args.model_dir 565 | seed = args.seed 566 | test_dataset = args.test_dataset 567 | query_head_path = args.query_head_path 568 | precomputed_embed_keys_path = args.precomputed_embed_keys_path 569 | precomputed_embed_values_path = args.precomputed_embed_values_path 570 | 571 | dataset = json.load(open(os.path.join(dataset_dir, test_dataset))) 572 | 573 | tokenizer, encoder, model, kb_config = _prepare_models( 574 | encoder_model_spec, 575 | encoder_path, 576 | llm_type, 577 | llm_base_dir, 578 | model_path, 579 | query_head_path, 580 | kb_layer_frequency, 581 | kb_scale_factor, 582 | ) 583 | 584 | kb_retriever = KBRetriever( 585 | encoder, 586 | dataset, 587 | precomputed_embed_keys_path=precomputed_embed_keys_path, 588 | precomputed_embed_values_path=precomputed_embed_values_path, 589 | ) 590 | 591 | gen_results, score_results = perform_eval( 592 | model, 593 | tokenizer, 594 | kb_retriever, 595 | encoder_model_spec, 596 | kb_config, 597 | eval_mode, 598 | seed=seed, 599 | kb_size=kb_size, 600 | topk_size=args.topk_size, 601 | multi_entites=args.multi_entites, 602 | ) 603 | mem_cost = torch.cuda.max_memory_reserved("cuda") 604 | score_results["mem_cost"] = mem_cost 605 | 606 | (Path(args.save_dir) / exp_config).mkdir(exist_ok=True, parents=True) 607 | write_to_json(score_results, Path(args.save_dir) / f"{exp_config}.json") 608 | print(score_results) 609 | text_file = open(os.path.join(args.save_dir, exp_config + ".txt"), "w") 610 | text_file.write(gen_results) 611 | 612 | 613 | def _prepare_models( 614 | encoder_spec, 615 | encoder_path, 616 | llm_type, 617 | llm_base_dir, 618 | model_path, 619 | query_head_path, 620 | kb_layer_frequency, 621 | kb_scale_factor, 622 | ): 623 | tokenizer = AutoTokenizer.from_pretrained( 624 | llm_base_dir, trust_remote_code=True, padding_side="left" 625 | ) 626 | tokenizer.pad_token = "^" 627 | 628 | if llm_type == "llama3": 629 | if query_head_path: 630 | model = KblamLlamaForCausalLM.from_pretrained( 631 | model_path, 632 | device_map="cuda", 633 | torch_dtype="auto", 634 | trust_remote_code=True, 635 | ) 636 | model.load_query_head(query_head_path) 637 | else: 638 | model = KblamLlamaForCausalLM.from_pretrained( 639 | model_path, 640 | device_map="cuda", 641 | torch_dtype="auto", 642 | trust_remote_code=True, 643 | ) 644 | else: 645 | model = KBLaMPhi3ForCausalLM.from_pretrained( 646 | model_path, 647 | device_map="cuda", 648 | torch_dtype="auto", 649 | trust_remote_code=True, 650 | ) 651 | model.generation_config.pad_token_id = tokenizer.pad_token_id 652 | model.generation_config.eos_token_id = tokenizer.eos_token_id 653 | model.eval() 654 | 655 | # config = model.config.to_dict() 656 | kb_config = KBLaMConfig( 657 | sep_query_head=True, 658 | kb_layer_frequency=kb_layer_frequency, 659 | kb_scale_factor=kb_scale_factor, 660 | ) 661 | # config.update(kb_config.to_dict()) 662 | # new_config = KBLaMConfig(**config) 663 | # model.config = new_config 664 | 665 | encoder = KBEncoder( 666 | encoder_name=encoder_spec.upper(), 667 | projector_type="linear", 668 | endpoint_url="", 669 | out_dim=model.config.hidden_size 670 | * 
(model.config.num_hidden_layers // kb_layer_frequency + 1), 671 | frozen_base_model=True, 672 | projector_kwargs={"mlp_depth": 1, "mlp_hidden_dim": 512}, 673 | device=torch.device("cuda"), 674 | ) 675 | 676 | encoder.load_state_dict(torch.load(encoder_path)) 677 | return tokenizer, encoder, model, kb_config 678 | 679 | 680 | def eval_accuracy( 681 | tokenizer, 682 | kb_retriever, 683 | model, 684 | dataset, 685 | exp_config, 686 | fancy_question, 687 | kb_config, 688 | kb_size, 689 | llm_type, 690 | test_batch_size, 691 | save_dir, 692 | attn_save_dir, 693 | ): 694 | """Evaluate accuracy using KB""" 695 | 696 | if kb_size == len(dataset): 697 | dataset_subset_idx = range(len(dataset)) 698 | elif kb_size > len(dataset): 699 | raise IndexError( 700 | f"The KB size {kb_size} is greater than the dataset size {len(dataset)}" 701 | ) 702 | else: 703 | dataset_subset_idx = np.random.choice(len(dataset), kb_size, replace=False) 704 | 705 | dataset_subset = [dataset[i] for i in dataset_subset_idx] 706 | 707 | kb_embedding_real = kb_retriever.get_key_embeddings(dataset_subset_idx) 708 | 709 | format_func_map = {"llama3": _format_Q_llama, "phi3": _format_Q_phi3} 710 | 711 | if not fancy_question: 712 | input_strs_gen = (dataset_subset[i]["Q"] for i in range(test_batch_size)) 713 | else: 714 | input_strs_gen = (aug_row(dataset_subset[i]) for i in range(test_batch_size)) 715 | input_strs = [format_func_map[llm_type](ex) for ex in input_strs_gen] 716 | 717 | tokenizer_output = tokenizer(input_strs, return_tensors="pt", padding=True).to( 718 | "cuda" 719 | ) 720 | input_ids, attention_masks = ( 721 | tokenizer_output["input_ids"], 722 | tokenizer_output["attention_mask"], 723 | ) 724 | 725 | with torch.autograd.no_grad(): 726 | outputs = model.generate( 727 | input_ids=input_ids, 728 | attention_mask=attention_masks, 729 | kb_kvs=kb_embedding_real, 730 | max_new_tokens=60, 731 | tokenizer=tokenizer, 732 | output_attentions=True, 733 | save_attention_weights=True, 734 | kb_config=kb_config, 735 | attention_save_loc=attn_save_dir, 736 | attention_file_base_name=exp_config, 737 | ) 738 | outputs = tokenizer.batch_decode(outputs.squeeze(), skip_special_tokens=False) 739 | 740 | save_path = Path(save_dir) 741 | save_path.mkdir(exist_ok=True, parents=True) 742 | 743 | with open(save_path / f"{exp_config}_acc.txt", "w+") as text_file: 744 | for output in outputs: 745 | output_string = output.strip("^") 746 | text_file.write(f"{str(output_string)}\n") 747 | 748 | accs = [] 749 | with torch.autograd.no_grad(): 750 | for idx in range(0, 32, kb_config.kb_layer_frequency): 751 | weight = np.load(os.path.join(attn_save_dir, f"{exp_config}_{idx}.npy")) 752 | weight = weight[..., :kb_size] 753 | label = np.arange(test_batch_size) 754 | weight = weight.reshape(test_batch_size, -1, kb_size) 755 | acc = (weight.sum(1).argmax(1) == label).mean() 756 | top_5_predictions = torch.topk(torch.from_numpy(weight.sum(1)), 5, dim=1)[1] 757 | top_5_acc = (top_5_predictions.numpy() == label[:, None]).any(1).mean() 758 | if idx == 15: 759 | print(f"ACC & TOP 5 ACC: {idx} {(acc, top_5_acc)}") 760 | print(f"min: {np.min(weight)} max: {np.max(weight)}") 761 | accs.append( 762 | { 763 | "idx": idx, 764 | "acc": float(acc), 765 | "top5acc": float(top_5_acc), 766 | } 767 | ) 768 | 769 | np.save( 770 | save_path / f"{exp_config}_acc.npy", 771 | np.array([(a["acc"], a["top5acc"]) for a in accs]), 772 | ) 773 | 774 | return accs 775 | 776 | 777 | def eval_accuracy_cli(): 778 | """Evaluate accuracy using KB""" 779 | args = parser.parse_args() 
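    # Example invocation for this subcommand (a sketch mirroring the eval-acc-phi3-oai target in
    # experiments/Makefile above; the ${...} values are placeholders for your own paths and
    # settings, not defaults of this script):
    #   python eval.py accuracy --seed 1607 \
    #     --dataset_dir ${TEST_DATASET_DIR} --test_dataset test_synthetic_data.json \
    #     --llm_base_dir ${PHI3_HF} --model_dir ${PHI3_MODEL_CHKPT} --encoder_dir ${ENCODER_FOR_PHI3_CHKPT} \
    #     --save_dir ${ATTN_SAVE_DIR} --attn_save_dir ${ATTN_SAVE_DIR} --log_save_dir ${LOG_SAVE_DIR} \
    #     --kb_layer_frequency ${KB_LAYER_FREQ} --kb_size 3200 --kb_scale_factor 100 \
    #     --no-fancy_instruction --encoder_spec oai --llm_type "phi3" \
    #     --precomputed_embed_keys_path ${TEST_DATASET_DIR}/test_synthetic_data_oai_embd_key.npy \
    #     --precomputed_embed_values_path ${TEST_DATASET_DIR}/test_synthetic_data_oai_embd_value.npy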
780 | 781 | dataset_dir = args.dataset_dir 782 | encoder_path = args.encoder_dir 783 | encoder_spec = args.encoder_spec 784 | exp_config = args.exp_config_name 785 | fancy_question = args.fancy_question 786 | kb_layer_frequency = args.kb_layer_frequency 787 | kb_scale_factor = args.kb_scale_factor 788 | kb_size = args.kb_size 789 | llm_base_dir = args.llm_base_dir 790 | llm_type = llm_type = args.llm_type 791 | model_path = args.model_dir 792 | test_batch_size = args.test_batch_size 793 | test_dataset = args.test_dataset 794 | precomputed_embed_keys_path = args.precomputed_embed_keys_path 795 | precomputed_embed_values_path = args.precomputed_embed_values_path 796 | 797 | query_head_path = args.query_head_path 798 | tokenizer, encoder, model, kb_config = _prepare_models( 799 | encoder_spec, 800 | encoder_path, 801 | llm_type, 802 | llm_base_dir, 803 | model_path, 804 | query_head_path, 805 | kb_layer_frequency, 806 | kb_scale_factor, 807 | ) 808 | dataset = json.load(open(os.path.join(dataset_dir, test_dataset))) 809 | 810 | kb_retriever = KBRetriever( 811 | encoder, 812 | dataset, 813 | precomputed_embed_keys_path=precomputed_embed_keys_path, 814 | precomputed_embed_values_path=precomputed_embed_values_path, 815 | ) 816 | 817 | eval_accuracy( 818 | tokenizer, 819 | kb_retriever, 820 | model, 821 | dataset, 822 | exp_config, 823 | fancy_question, 824 | kb_config, 825 | kb_size, 826 | llm_type, 827 | test_batch_size, 828 | args.log_save_dir, 829 | args.attn_save_dir, 830 | ) 831 | 832 | 833 | def write_to_json( 834 | data: Any, filepath: str, indent: int = 4, encoding: str = "utf-8" 835 | ) -> bool: 836 | """ 837 | Write a dictionary to a JSON file with error handling and formatting options. 838 | 839 | Args: 840 | data: Dictionary to write to JSON file 841 | filepath: Path where the JSON file should be saved 842 | indent: Number of spaces for indentation (default: 4) 843 | encoding: File encoding (default: 'utf-8') 844 | 845 | Raises: 846 | TypeError: If data is not a dictionary 847 | """ 848 | 849 | try: 850 | # Convert string path to Path object 851 | file_path = Path(filepath) 852 | 853 | # Write the JSON file 854 | with open(file_path, "w", encoding=encoding) as f: 855 | json.dump( 856 | data, 857 | f, 858 | indent=indent, 859 | sort_keys=True, # For consistent output 860 | default=str, # Handle non-serializable objects by converting to string 861 | ) 862 | 863 | except Exception as e: 864 | print(f"Error writing JSON file: {str(e)}") 865 | 866 | 867 | def run_accuracy_evalution(): 868 | args = parser.parse_args() 869 | 870 | dataset_dir = args.dataset_dir 871 | encoder_path = args.encoder_dir 872 | encoder_spec = args.encoder_spec 873 | exp_config = args.exp_config_name 874 | fancy_question = args.fancy_question 875 | kb_layer_frequency = args.kb_layer_frequency 876 | kb_scale_factor = args.kb_scale_factor 877 | llm_base_dir = args.llm_base_dir 878 | llm_type = llm_type = args.llm_type 879 | model_path = args.model_dir 880 | test_dataset = args.test_dataset 881 | 882 | query_head_path = args.query_head_path 883 | precomputed_embed_keys_path = args.precomputed_embed_keys_path 884 | precomputed_embed_values_path = args.precomputed_embed_values_path 885 | 886 | tokenizer, encoder, model, kb_config = _prepare_models( 887 | encoder_spec, 888 | encoder_path, 889 | llm_type, 890 | llm_base_dir, 891 | model_path, 892 | query_head_path, 893 | kb_layer_frequency, 894 | kb_scale_factor, 895 | ) 896 | 897 | dataset = json.load(open(os.path.join(dataset_dir, test_dataset))) 898 | kb_retriever = 
KBRetriever( 899 | encoder, 900 | dataset, 901 | precomputed_embed_keys_path=precomputed_embed_keys_path, 902 | precomputed_embed_values_path=precomputed_embed_values_path, 903 | ) 904 | 905 | xs = [50, 100, 200, 400, 800, 1600, 3200, 6400] 906 | accuracy_results = [] 907 | for x in xs: 908 | print(f"kb_size {x}") 909 | 910 | accs = eval_accuracy( 911 | tokenizer, 912 | kb_retriever, 913 | model, 914 | dataset, 915 | exp_config, 916 | fancy_question, 917 | kb_config, 918 | x, 919 | llm_type, 920 | min(x, 200), 921 | args.log_save_dir, 922 | args.attn_save_dir, 923 | ) 924 | shutil.rmtree(args.attn_save_dir) 925 | os.mkdir(args.attn_save_dir) 926 | accuracy_results.append({"kb_size": x, "accuracy_results": accs}) 927 | write_to_json( 928 | accuracy_results, os.path.join(args.log_save_dir, "accuracy_results.json") 929 | ) 930 | 931 | 932 | def eval_refusal(): 933 | """Evaluate refusal to answer questions for which the answer does not exist in the KB""" 934 | args = parser.parse_args() 935 | dataset_dir = args.dataset_dir 936 | encoder_model_spec = args.encoder_spec 937 | encoder_path = args.encoder_dir 938 | eval_mode = args.eval_mode 939 | exp_config = args.exp_config_name 940 | kb_layer_frequency = args.kb_layer_frequency 941 | kb_scale_factor = args.kb_scale_factor 942 | kb_size = args.kb_size 943 | llm_base_dir = args.llm_base_dir 944 | llm_type = args.llm_type 945 | model_path = args.model_dir 946 | seed = args.seed 947 | test_dataset = args.test_dataset 948 | precomputed_embed_keys_path = args.precomputed_embed_keys_path 949 | precomputed_embed_values_path = args.precomputed_embed_values_path 950 | query_head_path = args.query_head_path 951 | 952 | dataset = json.load(open(os.path.join(dataset_dir, test_dataset))) 953 | 954 | tokenizer, encoder, model, kb_config = _prepare_models( 955 | encoder_model_spec, 956 | encoder_path, 957 | llm_type, 958 | llm_base_dir, 959 | model_path, 960 | query_head_path, 961 | kb_layer_frequency, 962 | kb_scale_factor, 963 | ) 964 | 965 | kb_retriever = KBRetriever( 966 | encoder, 967 | dataset, 968 | precomputed_embed_keys_path=precomputed_embed_keys_path, 969 | precomputed_embed_values_path=precomputed_embed_values_path, 970 | ) 971 | 972 | gen_results, refusal_results = perform_eval_refusal( 973 | model, 974 | tokenizer, 975 | kb_retriever, 976 | eval_mode=eval_mode, 977 | seed=seed, 978 | kb_size=kb_size, 979 | topk_size=args.topk_size, 980 | kb_config=kb_config, 981 | ) 982 | 983 | np.save(os.path.join(args.save_dir, "OutLierTest" + exp_config), refusal_results) 984 | text_file = open( 985 | os.path.join(args.save_dir, "OutLierTest" + exp_config + ".txt"), "w" 986 | ) 987 | text_file.write(gen_results) 988 | 989 | 990 | def eval(): 991 | """Evaluate the KB model""" 992 | args = parser.parse_args() 993 | attn_summary_save_dir = args.attn_summary_save_dir 994 | dataset_dir = args.dataset_dir 995 | encoder_model_spec = args.encoder_spec 996 | encoder_path = args.encoder_dir 997 | exp_config_str = args.exp_config_str 998 | kb_layer_frequency = args.kb_layer_frequency 999 | kb_scale_factor = args.kb_scale_factor 1000 | kb_size = args.kb_size 1001 | llm_base_dir = args.llm_base_dir 1002 | llm_type = args.llm_type 1003 | model_path = args.model_dir 1004 | output_dir = args.save_dir 1005 | sample_size = args.sample_size 1006 | seed = args.seed 1007 | subset_size = args.subset_size 1008 | test_dataset = args.test_dataset 1009 | precomputed_embed_keys_path = args.precomputed_embed_keys_path 1010 | precomputed_embed_values_path = 
args.precomputed_embed_values_path 1011 | query_head_path = args.query_head_path 1012 | sep_query_head = True 1013 | actual_kb_token_layer_frequency = 3 1014 | 1015 | if kb_size == -1: 1016 | kb_size = None 1017 | 1018 | # validation_part_start_idx = 120000 if 'gpt' in test_dataset else 0 1019 | dataset = json.load(open(os.path.join(dataset_dir, test_dataset))) 1020 | 1021 | if sep_query_head: 1022 | print("Having seperate query head for KB!") 1023 | 1024 | torch.manual_seed(seed) 1025 | np.random.seed(seed) 1026 | 1027 | os.environ["ATTN_SAVE_DIR"] = output_dir 1028 | os.environ["EVAL_MODE"] = "1" 1029 | 1030 | tokenizer, encoder, model, kb_config = _prepare_models( 1031 | encoder_model_spec, 1032 | encoder_path, 1033 | llm_type, 1034 | llm_base_dir, 1035 | model_path, 1036 | query_head_path, 1037 | kb_layer_frequency, 1038 | kb_scale_factor, 1039 | ) 1040 | 1041 | for param in model.parameters(): 1042 | param.requires_grad = False 1043 | 1044 | # Set up the encoder 1045 | encoder = KBEncoder( 1046 | encoder_name=encoder_model_spec.upper(), 1047 | projector_type="linear", 1048 | endpoint_url="", 1049 | out_dim=model.config.hidden_size # type: ignore 1050 | * (model.config.num_hidden_layers // actual_kb_token_layer_frequency + 1), # type: ignore 1051 | frozen_base_model=True, 1052 | device=torch.device("cuda"), 1053 | ) 1054 | encoder.load_state_dict(torch.load(encoder_path)) 1055 | 1056 | kb_retriever = KBRetriever( 1057 | encoder, 1058 | dataset, 1059 | precomputed_embed_keys_path=precomputed_embed_keys_path, 1060 | precomputed_embed_values_path=precomputed_embed_values_path, 1061 | ) 1062 | no_kb_predictions = [] 1063 | predictions = [] 1064 | answer = [] 1065 | 1066 | for _ in range(sample_size): 1067 | print("******") 1068 | dataset_subset_idx = np.random.choice(len(dataset), subset_size, replace=False) 1069 | dataset_subset = [dataset[i] for i in dataset_subset_idx] 1070 | encoder.eval() 1071 | with torch.autograd.no_grad(): 1072 | kb_embedding_real = kb_retriever.get_key_embeddings(dataset_subset_idx) 1073 | kb_embedding_key, kb_embedding_val = kb_embedding_real 1074 | kb_embedding_real = (kb_embedding_key, kb_embedding_val) 1075 | 1076 | format_func_map = {"llama3": _format_Q_llama, "phi3": _format_Q_phi3} 1077 | 1078 | input_strs = [ 1079 | format_func_map[llm_type](dataset_subset[i]["Q"]) 1080 | for i in range(subset_size) 1081 | ] 1082 | 1083 | tokenizer_output = tokenizer(input_strs, return_tensors="pt", padding=True).to( 1084 | "cuda" 1085 | ) 1086 | input_ids, attention_masks = ( 1087 | tokenizer_output["input_ids"], 1088 | tokenizer_output["attention_mask"], 1089 | ) 1090 | kb_embedding_real = (kb_embedding_real[0], kb_embedding_real[1]) 1091 | 1092 | config_str = f"{exp_config_str}__kb_{subset_size}__seed_{seed}" 1093 | with torch.autograd.no_grad(): 1094 | outputs_no_kb = model.generate( 1095 | input_ids=input_ids, 1096 | attention_mask=attention_masks, 1097 | kb_kvs=None, 1098 | max_new_tokens=40, 1099 | tokenizer=tokenizer, 1100 | output_attentions=False, 1101 | kb_config=kb_config, 1102 | ) 1103 | 1104 | outputs_true_kb = model.generate( 1105 | input_ids=input_ids, 1106 | attention_mask=attention_masks, 1107 | kb_kvs=kb_embedding_real, 1108 | max_new_tokens=40, 1109 | tokenizer=tokenizer, 1110 | output_attentions=True, 1111 | save_attention_weights=True, 1112 | attention_save_loc=output_dir, 1113 | attention_file_base_name=config_str, 1114 | kb_config=kb_config, 1115 | ) 1116 | print("decoding") 1117 | outputs_no_kb = tokenizer.batch_decode(outputs_no_kb, 
skip_special_tokens=False) 1118 | 1119 | outputs_true_kb = tokenizer.batch_decode( 1120 | outputs_true_kb, skip_special_tokens=False 1121 | ) 1122 | print("KB:") 1123 | for i in range(subset_size): 1124 | print( 1125 | "{} : {}".format( 1126 | dataset_subset[i]["name"], dataset_subset[i]["description"] 1127 | ) 1128 | ) 1129 | 1130 | for m in model_prune_format_mapping: 1131 | if isinstance(model, m): 1132 | prune_str = model_prune_format_mapping[m] 1133 | 1134 | print("------------------") 1135 | for i in range(subset_size): 1136 | print("True KB", prune_str(outputs_true_kb[i])) 1137 | print("True answer: ", dataset_subset[i]["A"]) 1138 | no_kb_predictions.append( 1139 | prune_str(outputs_no_kb[i]).split(dataset_subset[i]["Q"])[1] 1140 | ) 1141 | predictions.append( 1142 | prune_str(outputs_true_kb[i]).split(dataset_subset[i]["Q"])[1] 1143 | ) 1144 | answer.append(dataset_subset[i]["A"]) 1145 | print("--------------------") 1146 | print("******") 1147 | 1148 | rogue_score = rouge.compute(predictions=predictions, references=answer) 1149 | np.savez( 1150 | os.path.join(attn_summary_save_dir, f"{config_str}_rouge.npy"), **rogue_score 1151 | ) 1152 | 1153 | rogue_score_no_kb = rouge.compute(predictions=no_kb_predictions, references=answer) 1154 | np.savez( 1155 | os.path.join(attn_summary_save_dir, f"{config_str}_rouge_no_kb.npy"), 1156 | **rogue_score_no_kb, 1157 | ) 1158 | 1159 | # Start inspecting attention masks 1160 | ranges = [(0, 6), (6, 12), (12, 18), (18, 24), (24, 30), (30, 32)] 1161 | 1162 | save_dir = output_dir 1163 | Path(args.save_dir).mkdir(exist_ok=True, parents=True) 1164 | 1165 | accs, confidences = [], [] 1166 | for left, right in ranges: 1167 | weights = [] 1168 | kb_size = subset_size 1169 | for idx in range(32)[left:right]: 1170 | if idx % 3 == 0: 1171 | weight = np.load(os.path.join(save_dir, f"{config_str}_{idx}.npy")) 1172 | weights.append(weight[..., :kb_size].reshape(kb_size, -1, kb_size)) 1173 | print(len(weights)) 1174 | weights = np.stack(weights) 1175 | weights = weights.transpose(1, 0, 2, 3).reshape(kb_size, -1, kb_size) 1176 | acc = (weights.sum(1).argmax(1) == np.arange(kb_size)).mean() 1177 | top_5_predictions = torch.topk(torch.from_numpy(weights.sum(1)), 5, dim=1)[1] 1178 | top_5_acc = ( 1179 | (top_5_predictions == torch.arange(kb_size)[:, None]).any(1).float().mean() 1180 | ) 1181 | accs.append((acc, top_5_acc)) 1182 | confidence = softmax(weights.mean(1), -1).max() 1183 | confidences.append(confidence) 1184 | np.save( 1185 | os.path.join(attn_summary_save_dir, f"{config_str}_acc.npy"), np.array(accs) 1186 | ) 1187 | np.save( 1188 | os.path.join(attn_summary_save_dir, f"{config_str}_conf.npy"), 1189 | np.array(confidences), 1190 | ) 1191 | 1192 | 1193 | def main(): 1194 | args = parser.parse_args() 1195 | print(args) 1196 | if args.command == "generation": 1197 | eval_generate() 1198 | elif args.command == "accuracy": 1199 | eval_accuracy_cli() 1200 | elif args.command == "acc_results": 1201 | run_accuracy_evalution() 1202 | elif args.command == "refusal": 1203 | eval_refusal() 1204 | elif args.command == "standard": 1205 | eval() 1206 | else: 1207 | raise ValueError(f"command {args.command} not recognised") 1208 | 1209 | 1210 | if __name__ == "__main__": 1211 | main() 1212 | -------------------------------------------------------------------------------- /experiments/output_scorer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from dataclasses import dataclass 4 | 5 | import 
numpy as np 6 | 7 | from kblam.gpt_session import GPT 8 | 9 | 10 | @dataclass 11 | class EvalExample: 12 | text: str 13 | true_answer: str 14 | score: float 15 | 16 | 17 | def save_example(example: EvalExample, output_file: str) -> None: 18 | try: 19 | with open(output_file, "a+") as f: 20 | json.dump(example.__dict__, f) 21 | f.write("\n") 22 | except Exception as e: 23 | print("Error saving example.") 24 | print(e) 25 | 26 | 27 | class Evaluator(GPT): 28 | def __init__(self, model, endpoint_url, **kwargs) -> None: 29 | self.system_msg = """You are an AI system that evaluates the quality of generated text. 30 | You will be given a text and a ground truth answer, your goal is to return a score between 0 and 1.""" 31 | self.prompt = """ Given a text and a ground truth answer, evaluate the quality of the text. 32 | Return a score of 1 if the text is exactly the same as the ground truth answer, 33 | Return a score of 0 if the text is completely wrong, 34 | Return a score between 0 and 1 if the text is partially correct. A more correct text should have a higher score. 35 | Do NOT generate anything else. 36 | Example: 37 | 38 | Model output: "The sky is blue." 39 | True answer: "The sky is blue." 40 | Score: 1 41 | 42 | Example 2: 43 | Model output: "The color of Alexandria is blue." 44 | True answer: "The color of Alexandria is green." 45 | Score: 0 46 | 47 | Example 3: 48 | Model output: "The purpose of Alexandria is to extract knowledge." 49 | True answer: "The purpose of Alexandria is to discover and organize knowledge into a structured form." 50 | Score: 0.9 51 | 52 | **Important**: Only generate a number. 53 | """ 54 | self.prompt += ( 55 | "\n Score the following text: \n model prediction: {0}, \n true answer: {1}" 56 | ) 57 | self.seed = 42 58 | super().__init__(model, endpoint_url, **kwargs) 59 | 60 | def evaluate_output(self, prompt: str, text: str, true_answer: str) -> EvalExample: 61 | prompt = self.prompt.format(text, true_answer) 62 | score = self.generate_response(prompt) 63 | example = EvalExample(text, true_answer, float(score)) 64 | return example 65 | 66 | def evaluate_output_batch(self, examples: list[str]) -> list[EvalExample]: 67 | eval_examples = [] 68 | for example in examples: 69 | try: 70 | text = ( 71 | example.split("True answer:")[0] 72 | .replace("Model output:", "") 73 | .strip() 74 | ) 75 | true_answer = example.split("True answer:")[1].strip() 76 | eval_example = self.evaluate_output(self.prompt, text, true_answer) 77 | eval_examples.append(eval_example) 78 | except Exception as e: 79 | print("Error evaluating example.") 80 | print(e) 81 | return eval_examples 82 | 83 | 84 | def parser_args(): 85 | parser = argparse.ArgumentParser(description="GPT Session") 86 | parser.add_argument("--model", type=str, default="GPT4", help="The model to use.") 87 | parser.add_argument("--endpoint_url", type=str, help="The endpoint url.") 88 | parser.add_argument( 89 | "--predictions_file", 90 | type=str, 91 | default="llama.txt", 92 | help="The input file with examples.", 93 | ) 94 | parser.add_argument( 95 | "--output_file", 96 | type=str, 97 | default="eval_examples1.json", 98 | help="The output file to save the examples.", 99 | ) 100 | return parser.parse_args() 101 | 102 | 103 | if __name__ == "__main__": 104 | args = parser_args() 105 | with open(args.predictions_file, "r") as f: 106 | examples = f.read() 107 | examples = examples.split("-------") 108 | 109 | eval = Evaluator(args.model, args.endpoint_url) 110 | eval_examples = eval.evaluate_output_batch(examples) 111 | for example in 
eval_examples: 112 | save_example(example, args.output_file) 113 | 114 | mean_score = np.mean([example.score for example in eval_examples]) 115 | print(f"Mean score: {mean_score}") 116 | -------------------------------------------------------------------------------- /experiments/output_scorer_open_ended.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import re 4 | from dataclasses import dataclass 5 | 6 | import numpy as np 7 | 8 | from kblam.gpt_session import GPT 9 | 10 | 11 | @dataclass 12 | class EvalExample: 13 | evidence: str 14 | question: str 15 | response: str 16 | score: float 17 | reason: str 18 | 19 | 20 | def save_example(example: EvalExample, output_file: str) -> None: 21 | try: 22 | with open(output_file, "a+") as f: 23 | json.dump(example.__dict__, f) 24 | f.write("\n") 25 | except Exception as e: 26 | print("Error saving example.") 27 | print(e) 28 | 29 | 30 | class Evaluator(GPT): 31 | def __init__(self, model, endpoint_url, seed, **kwargs) -> None: 32 | self.system_msg = """You are an AI system that evaluates the quality of generated response. Your goals is to return a score between 0 and 5 33 | indicating how accurate and useful the response is. An accrate and useful response should get a high score of 5.""" 34 | self.prompt_open_ended = """ 35 | A model is given a question about some information and evidence. 36 | The question is composed of two parts, a part that involves repeating information in the evidence and a part that potentially involves open-ended thinking. 37 | Then the model generates a response. 38 | Evaluate the response based on how grounded it is given the evidence and how reasonable it is. 39 | Return an integer score and step by step explanation of how you arrived at the score. 40 | Score of 5 means the response is accurate, relevant and reasonable (in that it meets common sense). 41 | If the responce addresses the question and uses the evidence in a relevant way, it should get a high score of 5. 42 | Score of 0 means the response is inaccurate and irrelevant or model is hallucinating. 43 | Score between 0 and 5 means the response is partially correct and relevant. 44 | 45 | Example 1: 46 | Evidence: "The purpose of Alexandria is to extract knowledge." 47 | Question: "Describe the purpose of Alexandria and how it can benefit users." 48 | Model output: "The purpose of Alexandria is to extract knowledge, it can benefit users by providing a structured way to organize knowledge." 49 | Score: 5 50 | Reason: The model's response is accurate and relevant to the question and evidence, the open-ended part is reasonable. 51 | 52 | Example 2: 53 | Evidence: "The purpose of Alexandria is to extract knowledge." 54 | Question: "Describe the purpose of Alexandria and what can it extract." 55 | Model output: "The purpose of Alexandria is to extract knowledge, it can extract knowledge knowledge." 56 | Score: 5 57 | Reason: The model's response is accurate and relevant to the question and evidence. 58 | 59 | Example 3: 60 | Evidence: "GreatTool is an app that helps users to be more productive." 61 | Question: "Describe GreatTool and how it may affect the community." 62 | Model output: "GreatTool is an app that helps users to be more productive. It may affect the community by helping users to sleep better." 63 | Score: 3 64 | Reason: The model's response is accurate and relevant to the question and evidence but it is not very reasonable. 
65 | 66 | 67 | Example 4: 68 | Evidence: "GreatTool is an app that helps users to be more productive." 69 | Question: "Describe GreatTool and how it may affect the community." 70 | Model output: "GreatTool is an app that helps users to be more productive. It may affect the community by helping users to organize their tasks and manage their time better improving their productivity." 71 | Score: 5 72 | Reason: The model's response is accurate and relevant to the question and evidence and the open ended part is sensible and reasonable. 73 | 74 | Example 5: 75 | Evidence: "GreatTool is an app that helps users to be more productive." 76 | Question: "Tell me the description of GreatTool and what can it help users to achieve." 77 | Model output: "GreatTool is an app that helps users to be more productive. It can help users to organize their tasks and manage their time better improving their productivity." 78 | Score: 5 79 | Reason: The model's response is accurate and relevant to the question and evidence. 80 | 81 | Example 6: 82 | Evidence: "GreatTool is an app that helps users to be more productive." 83 | Question: "Describe GreatTool and how it may affect the community." 84 | Model output: "GreatTool is great tool with many feature" 85 | Score: 0 86 | Reason: The model's response is not accurate and doesn't answer the question. 87 | 88 | Example 7: 89 | Evidence: "GreatTool is an app that helps users to be more productive." 90 | Question: "Describe GreatTool and how it may affect the community." 91 | Model output: "GreatTool is an app that helps users to be more productive, it improves community income level." 92 | Score: 3 93 | Reason: The model's response is accurate but is not very reasonable. 94 | """ 95 | self.prompt_open_ended += "\n Score the following responce: \n evidence: {0}, question: {1} and \n model response: {2}" 96 | 97 | self.seed = seed 98 | super().__init__(model, endpoint_url, **kwargs) 99 | 100 | def evaluate_open_ended( 101 | self, prompt, evidence: str, question: str, response: str 102 | ) -> str: 103 | prompt = prompt.format(evidence, question, response) 104 | return self.generate_response(prompt) 105 | 106 | def evaluate_output_batch(self, examples: list[str]) -> list[str]: 107 | score_pattern = r"Score: (.+)" 108 | reason_pattern = r"Reason: (.+)" 109 | 110 | eval_examples = [] 111 | for example in examples: 112 | try: 113 | evidence_start = example.find("Evidence:") 114 | question_start = example.find("Question:") 115 | model_output_start = example.find("Model output:") 116 | 117 | # Extract the parts based on the indices 118 | evidence = example[evidence_start:question_start].strip() 119 | question = example[question_start:model_output_start].strip() 120 | model_output = example[model_output_start:].strip() 121 | 122 | eval_example = self.evaluate_open_ended( 123 | self.prompt_open_ended, evidence, question, model_output 124 | ) 125 | score = float(re.search(score_pattern, eval_example).group(1).strip()) 126 | reason = re.search(reason_pattern, eval_example).group(1).strip() 127 | eval_example = EvalExample( 128 | evidence, question, model_output, score, reason 129 | ) 130 | eval_examples.append(eval_example) 131 | 132 | save_example(eval_example, args.output_file) 133 | 134 | except Exception as e: 135 | print("Error evaluating example.") 136 | print(e) 137 | return eval_examples 138 | 139 | 140 | def parser_args(): 141 | parser = argparse.ArgumentParser(description="GPT Session") 142 | parser.add_argument("--model", type=str, default="GPT4", help="The model to 
use.") 143 | parser.add_argument("--endpoint_url", type=str, help="The endpoint url.") 144 | parser.add_argument( 145 | "--predictions_file", 146 | type=str, 147 | help="The file containing the model predictions.", 148 | ) 149 | parser.add_argument( 150 | "--output_file", 151 | type=str, 152 | default="eval_examples_open_ended_icl.json", 153 | help="The output file to save the examples.", 154 | ) 155 | parser.add_argument("--seed", type=int, default=42) 156 | return parser.parse_args() 157 | 158 | 159 | if __name__ == "__main__": 160 | args = parser_args() 161 | with open(args.predictions_file, "r") as f: 162 | examples = f.read() 163 | examples = examples.split("-------") 164 | 165 | eval = Evaluator(args.model, args.endpoint_url, args.seed) 166 | eval_examples = eval.evaluate_output_batch(examples) 167 | mean_score = np.mean([example.score for example in eval_examples]) 168 | print(f"Mean score: {mean_score}") 169 | -------------------------------------------------------------------------------- /experiments/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | import os 5 | import pathlib 6 | import re 7 | from functools import partial 8 | from itertools import chain 9 | from typing import Callable, Dict, List, Optional 10 | 11 | import wandb 12 | import numpy as np 13 | import torch 14 | from torch.nn.parallel import DistributedDataParallel 15 | import transformers 16 | from rich.console import Console 17 | from rich.logging import RichHandler 18 | from rich.progress import ( 19 | BarColumn, 20 | Progress, 21 | SpinnerColumn, 22 | TaskProgressColumn, 23 | TextColumn, 24 | TimeRemainingColumn, 25 | ) 26 | from rich.theme import Theme 27 | from torch.nn import CrossEntropyLoss 28 | from transformers import AutoTokenizer 29 | from accelerate import Accelerator 30 | 31 | from kblam.kb_encoder import KBEncoder 32 | from kblam.models.kblam_config import KBLaMConfig 33 | from kblam.models.llama3_model import KblamLlamaForCausalLM 34 | from kblam.models.phi3_model import KBLaMPhi3ForCausalLM 35 | from kblam.utils.data_utils import ( 36 | augment_row, 37 | generate_multi_entity_qa, 38 | get_i_dont_know_ans, 39 | ) 40 | from kblam.utils.train_utils import ( 41 | context_set_size_scheduler, 42 | get_kb_embd, 43 | setup_scheduler_and_optimizer, 44 | ) 45 | 46 | LOGFORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" 47 | LOGFORMAT_RICH = "%(message)s" 48 | 49 | # setup logging 50 | # Create a custom theme for Rich 51 | custom_theme = Theme( 52 | { 53 | "info": "cyan", 54 | "warning": "yellow", 55 | "error": "bold red", 56 | "critical": "bold white on red", 57 | } 58 | ) 59 | 60 | # Create a Rich console with the custom theme 61 | console = Console(theme=custom_theme) 62 | 63 | # Configure the root logger to WARNING 64 | logging.basicConfig( 65 | level=logging.WARNING, # Set the root logger to WARNING 66 | format=LOGFORMAT_RICH, 67 | datefmt="[%X]", 68 | handlers=[RichHandler(console=console, rich_tracebacks=True)], 69 | ) 70 | 71 | # fmt: off 72 | parser = argparse.ArgumentParser() 73 | parser.add_argument("--seed", type=int, default=1) 74 | parser.add_argument("--train_dataset",type=str,default="gpt_data") 75 | parser.add_argument("--N", type=int, default=120000, help="Size of training set, select the first N samples for training") 76 | parser.add_argument("--B", type=int, default=10, help="Batch size") 77 | parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate") 78 | 
parser.add_argument("--sep_query_head", action=argparse.BooleanOptionalAction, help="Train a separate query head") 79 | parser.add_argument("--use_oai_embd", action="store_true", help="Use OpenAI embedding") 80 | parser.add_argument("--use_cached_embd", action="store_true", help="Choose to use pre-computed KV embeddings") 81 | parser.add_argument("--total_steps", type=int, default=20000, help="Total steps") 82 | parser.add_argument("--encoder_spec", type=str, default="OAI") 83 | parser.add_argument("--key_embd_src", type=str, default="key", choices=["key", "answer", "questions", None], help="Source of key embedding") 84 | parser.add_argument("--use_data_aug", action="store_true", help="Randomly pick templates for the question") 85 | parser.add_argument("--use_lr_decay", action="store_true") 86 | parser.add_argument("--dataset_dir", type=str, default="synthetic_data") 87 | parser.add_argument("--model_dir_to_resume", type=str, default=None, help="Checkpoint directory to resume training") 88 | parser.add_argument("--hf_model_spec", type=str, default="meta-llama/Llama-3.2-1B-Instruct", choices=["meta-llama/Meta-Llama-3-8B", "microsoft/Phi-3-mini-4k-instruct", "meta-llama/Llama-3.2-1B-Instruct"]) 89 | parser.add_argument("--hf_token", type=str,default=None,help="Huggingface token") 90 | parser.add_argument("--model_save_dir", type=str, default="output", help="Place to save the checkpoints") 91 | parser.add_argument("--kb_size", type=int, default=None, help="The size of the KB set") 92 | parser.add_argument("--dynamic_kb_size", nargs=2, type=int, default=None, help="Set a dynamic range for the KB set size: specify min and max") 93 | parser.add_argument("--duplicate_true_kb", action=argparse.BooleanOptionalAction, default=True, help="Duplicate true entity's KB token") 94 | parser.add_argument("--length_invariance", action=argparse.BooleanOptionalAction, default=False, help="Scale the raw attention score") 95 | parser.add_argument("--outlier_num", type=int, default=1, help="Introduce questions without correct KB entities") 96 | parser.add_argument("--multi_entities", type=int, default=None, help="Introduce questions involving multiple entities") 97 | parser.add_argument("--use_extended_qa", action="store_true", help="Introduce QA with extended open-ended parts") 98 | parser.add_argument("--kb_token_layer_frequency", type=int, default=3, help="Frequency of knowledge base token layers") 99 | parser.add_argument("--gradient_accm_step", type=int, default=20, help="Number of gradient accumulation steps") 100 | parser.add_argument("--verbose", action="store_true", help="Set logging to debug") 101 | parser.add_argument("--log_to_file", action="store_true", help="Log to file as well as stdout") 102 | parser.add_argument("--llm_type",type=str,default="llama3",choices=["llama3", "phi3"]) 103 | parser.add_argument("--max_seq_len",type=int,default=None) 104 | # fmt: on 105 | 106 | 107 | def create_custom_progress_bar( 108 | console: Console = None, # type: ignore 109 | color: str = "cyan", 110 | show_time: bool = True, 111 | show_spinner: bool = True, 112 | spinner_style: str = "dots", 113 | disable=False, 114 | ) -> Progress: 115 | """ 116 | Create a custom progress bar using Rich, optionally including loss reporting. 
117 | 118 | :param description: Description of the task 119 | :param total: Total number of steps 120 | :param console: Rich Console object (if None, a new one will be created) 121 | :param color: Color of the progress bar 122 | :param show_time: Whether to show the time remaining 123 | :param show_spinner: Whether to show a spinner 124 | :param spinner_style: Style of the spinner (e.g., "dots", "dots12", "line", "arrow") 125 | :param show_loss: Whether to show loss information 126 | :return: A Rich Progress object and task ID 127 | """ 128 | if console is None: 129 | console = Console() 130 | columns = [] 131 | 132 | if show_spinner: 133 | columns.append(SpinnerColumn(spinner_name=spinner_style, style=color)) 134 | 135 | columns.extend( 136 | [ 137 | TextColumn("[bold blue]{task.description}", justify="right"), 138 | BarColumn(bar_width=None, style=color, complete_style=f"bold {color}"), 139 | TaskProgressColumn(), 140 | TextColumn("[bold yellow]Loss: {task.fields[loss]:.4f}", justify="right"), 141 | ] 142 | ) 143 | 144 | if show_time: 145 | columns.append(TimeRemainingColumn()) 146 | 147 | progress = Progress(*columns, console=console, expand=True, disable=disable) 148 | return progress 149 | 150 | 151 | def _format_QA_llama(Q: str, A: str): 152 | return ( 153 | "<|start_header_id|>user<|end_header_id|> " 154 | + Q 155 | + "<|eot_id|>" 156 | + "<|start_header_id|>assistant<|end_header_id|>" 157 | + A 158 | + "<|eot_id|>" 159 | ) 160 | 161 | 162 | def _format_QA_phi3(Q: str, A: str): 163 | return "<|user|>\n" + Q + "<|end|>\n" + "<|assistant|>\n" + A + "<|end|>\n" 164 | 165 | 166 | def _create_labels_for_llama(input_ids: torch.Tensor, input_strs: List[str], tokenizer): 167 | # Not sure this is correct. This method simply masks the <|start_header_id|>user<|end_header_id|> then leaves the rest in the labels 168 | # Possibly what they want is to mask out the query. To do that swap the index from the tokenizer below from 1 to 2 169 | answer_indices = torch.argmax( 170 | (input_ids == tokenizer("<|start_header_id|>assistant<|end_header_id|>")["input_ids"][1]).long(), 171 | -1, 172 | ) 173 | answer_mask = torch.ones_like(input_ids) 174 | for b in range(len(input_strs)): 175 | answer_mask[b, : (answer_indices[b].item() + 2)] = 0 176 | labels = input_ids * answer_mask + (1 - answer_mask) * (-100) 177 | return labels 178 | 179 | 180 | def _create_labels_for_phi3(input_ids: torch.Tensor, input_strs: List[str], tokenizer): 181 | # We just want to mask out the starting token. 
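    # Note: positions labelled -100 are ignored by CrossEntropyLoss (its default
    # ignore_index), so masking the prompt span with -100 means only the assistant's
    # answer tokens contribute to the loss; e.g. labels = [-100, -100, -100, 523, 1093, 2]
    # trains on the last three positions only.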
182 | # The tokenized values are left padded so we want to know where our Q/A pairs start 183 | # Not 100% this is correct 184 | answer_indices = torch.argmax( 185 | (input_ids == tokenizer("<|user|>")["input_ids"][0]).long(), 186 | -1, 187 | ) 188 | answer_mask = torch.ones_like(input_ids) 189 | for b in range(len(input_strs)): 190 | answer_mask[b, : (answer_indices[b].item() + 1)] = 0 191 | labels = input_ids * answer_mask + (1 - answer_mask) * (-100) 192 | return labels 193 | 194 | 195 | def get_batch( 196 | qa_format_func: Callable[[str, str], str], 197 | label_func: Callable[[torch.Tensor, List, Callable], torch.Tensor], 198 | dataset: List[Dict], 199 | tokenizer, 200 | device: torch.device, 201 | B: int = 20, 202 | random_sample=True, 203 | use_data_aug=False, 204 | include_outlier=False, 205 | multi_entities=None, 206 | use_extended_qa=False, 207 | ): 208 | """ 209 | dataset: List of dictionary, denoting the KB, used to extract QA pairs 210 | model: The LLM, used to provide the embedding 211 | kb_embedding: KB embedding (differentiable) 212 | B: Batchsize 213 | include_outlier : Create a batch of question without answer in the KB. 214 | multi_entities : Create a batch of question that involves more than one entities. 215 | """ 216 | labels = [] 217 | if multi_entities is not None: 218 | assert not include_outlier 219 | 220 | if random_sample: 221 | if multi_entities is not None: 222 | batch_indices = np.random.choice(len(dataset), (B, multi_entities), replace=False) 223 | else: 224 | batch_indices = np.random.choice(len(dataset), B, replace=False) 225 | else: 226 | batch_indices = np.arange(B) 227 | 228 | def get_question_and_answer(idx: int) -> tuple[str, str]: 229 | if use_extended_qa: 230 | Q, A = dataset[idx]["extended_Q"], dataset[idx]["extended_A"] 231 | 232 | elif multi_entities is not None: 233 | Q, A = generate_multi_entity_qa( 234 | [dataset[i]["name"] for i in idx], 235 | [dataset[i]["description_type"] for i in idx], 236 | [dataset[i]["description"] for i in idx], 237 | ) 238 | else: 239 | Q = augment_row(dataset[idx]) if use_data_aug else dataset[idx]["Q"] 240 | A = get_i_dont_know_ans() if include_outlier else dataset[idx]["A"] 241 | return Q, A 242 | 243 | with torch.autograd.no_grad(): 244 | input_strs = [] 245 | real_batch_indices = [] 246 | for idx in batch_indices: 247 | Q, A = get_question_and_answer(idx) 248 | if Q is not None and A is not None: 249 | input_strs.append(qa_format_func(Q, A)) 250 | real_batch_indices.append(idx) 251 | else: 252 | print("Q or Answer is none") 253 | batch_indices = real_batch_indices 254 | tokenizer_output = tokenizer(input_strs, return_tensors="pt", padding=True).to(device) 255 | input_ids, attention_masks = ( 256 | tokenizer_output["input_ids"], 257 | tokenizer_output["attention_mask"], 258 | ) 259 | 260 | labels = label_func(input_ids, input_strs, tokenizer) 261 | if include_outlier: 262 | # Generate a new set of indices, such that the KB does not contain the entity where the question comes from 263 | batch_indices = np.random.choice(len(dataset), B, replace=False) 264 | return input_ids, attention_masks, labels, batch_indices 265 | 266 | 267 | def get_prefix_str(args): 268 | use_data_aug = args.use_data_aug 269 | sep_query_head = args.sep_query_head 270 | kb_size = args.kb_size 271 | dynamic_kb_size = args.dynamic_kb_size 272 | 273 | if dynamic_kb_size is not None: 274 | kb_size = "dynamic" # Random size 275 | 276 | duplicate_true_kb = args.duplicate_true_kb 277 | length_invariance = args.length_invariance 278 | outlier_ratio = 
args.outlier_num
279 |     use_outlier = outlier_ratio != -1
280 |     multi_entities = args.multi_entities
281 |     use_extended_qa = args.use_extended_qa
282 |     kb_token_layer_frequency = args.kb_token_layer_frequency
283 |     lr = args.lr
284 |
285 |     prefix_string = f"stage1_lr_{lr}"
286 |     if kb_token_layer_frequency is not None:
287 |         prefix_string += f"KBTokenLayerFreq{kb_token_layer_frequency}"
288 |     if use_extended_qa:
289 |         prefix_string += "UseExtendedQA"
290 |     if multi_entities is not None:
291 |         prefix_string += f"MultiEntities{multi_entities}"
292 |     if use_outlier:
293 |         prefix_string += f"UseOutlier{outlier_ratio}"
294 |     if length_invariance:
295 |         prefix_string += "LengthInvariant"
296 |     if not duplicate_true_kb:
297 |         prefix_string += "NoDuplicate"
298 |     if kb_size is not None:
299 |         prefix_string += f"KBSize{kb_size}"
300 |     if sep_query_head:
301 |         prefix_string += "SepQueryHead"
302 |     if use_data_aug:
303 |         prefix_string += "UseDataAug"
304 |     return prefix_string
305 |
306 |
307 | def _load_cached_embeddings(encoder_model_spec: str, dataset_dir: str, dataset_name: str, key_embd_src: str):
308 |     if encoder_model_spec == "OAI":
309 |         encoder_model_spec_str = "oai"
310 |     else:
311 |         encoder_model_spec_str = encoder_model_spec
312 |     key_embds = np.load(
313 |         os.path.join(
314 |             dataset_dir,
315 |             f"{dataset_name}_{encoder_model_spec_str}_embd_{key_embd_src}.npy",
316 |         )
317 |     ).astype("float32")
318 |     if key_embd_src == "answer":
319 |         # If we are using the answer string as the key, we also use it as the value string
320 |         value_embds = np.load(
321 |             os.path.join(
322 |                 dataset_dir,
323 |                 f"{dataset_name}_{encoder_model_spec_str}_embd_answer.npy",
324 |             )
325 |         ).astype("float32")
326 |     else:
327 |         value_embds = np.load(
328 |             os.path.join(
329 |                 dataset_dir,
330 |                 f"{dataset_name}_{encoder_model_spec_str}_embd_value.npy",
331 |             )
332 |         ).astype("float32")
333 |     return key_embds, value_embds
334 |
335 |
336 | def get_step_config(
337 |     current_accum_step: int,
338 |     total_accum_step: int,
339 |     use_data_aug: bool,
340 |     outlier_num: int,
341 |     multi_entities: int | None,
342 |     use_extended_qa: bool,
343 | ):
344 |     """
345 |     Our instruction-tuning dataset is composed of different types of instructions.
346 |     Strategies:
347 |     Outlier QA takes the last `outlier_num` accumulation steps;
348 |     Multiple-entity QA (if included) takes 1/3 of the remaining accumulation steps;
349 |     Extended QA (if included) takes 1/3 of the remaining accumulation steps;
350 |     Standard QA takes the rest.
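    Illustrative example (not from the original docstring): as implemented below, with
    total_accum_step=20 and outlier_num=1, accumulation steps with index >= 18 (i.e. 18 and 19)
    produce outlier ("I don't know") batches; of the remaining steps, those with index % 3 == 0
    use multi-entity QA (when enabled), those with index % 3 == 1 use extended QA (when enabled),
    and all other steps use standard QA.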
351 | """ 352 | config = {} 353 | config["use_data_aug"] = use_data_aug 354 | config["include_outlier"] = False 355 | config["multi_entities"] = None 356 | config["use_extended_qa"] = False 357 | include_outlier = current_accum_step >= total_accum_step - 1 - outlier_num 358 | # Decide to include outlier and has reached the time 359 | if include_outlier: 360 | config["include_outlier"] = True 361 | return config 362 | if current_accum_step % 3 == 0: 363 | # multi_entities could be None, 364 | # in which case we just use standard QA 365 | config["multi_entities"] = multi_entities 366 | return config 367 | if current_accum_step % 3 == 1: 368 | config["use_extended_qa"] = use_extended_qa 369 | return config 370 | return config 371 | 372 | 373 | def _get_parameter_count(encoder): 374 | param_count = 0.0 375 | for p in encoder.parameters(): 376 | if p.requires_grad: 377 | param_count += p.numel() 378 | return param_count 379 | 380 | 381 | def _get_phi3_query_head_parameters( 382 | model: KblamLlamaForCausalLM | KBLaMPhi3ForCausalLM, 383 | sep_query_head: bool, 384 | kb_token_layer_frequency: int, 385 | ): 386 | llm_q_params = [] 387 | for name, param in model.named_parameters(): 388 | if sep_query_head: 389 | # For phi3 390 | if "qkv_proj.weight" in name: 391 | layer_id = int(re.search(r"\d+", name)[0]) # type: ignore 392 | if layer_id % kb_token_layer_frequency == 0: 393 | old_weight = param.detach() 394 | if "q_proj_new.weight" in name: 395 | layer_id = int(re.search(r"\d+", name)[0]) # type: ignore 396 | if layer_id % kb_token_layer_frequency == 0: 397 | param.copy_(old_weight[: model.config.hidden_size, :]) # type: ignore 398 | param.requires_grad = True 399 | llm_q_params.append(param) 400 | else: 401 | if "q_proj.weight" in name: 402 | layer_id = int(re.search(r"\d+", name)[0]) # type: ignore 403 | if layer_id % kb_token_layer_frequency == 0: 404 | param.requires_grad = True 405 | llm_q_params.append(param) 406 | return llm_q_params 407 | 408 | 409 | def _get_llama3_query_head_parameters( 410 | model: KblamLlamaForCausalLM | KBLaMPhi3ForCausalLM, 411 | sep_query_head: bool, 412 | kb_token_layer_frequency: int, 413 | ): 414 | llm_q_params = [] 415 | for name, param in model.named_parameters(): 416 | if sep_query_head: # TODO: this is different for each model type 417 | # For llama3 418 | if "q_proj.weight" in name: 419 | layer_id = int(re.search(r"\d+", name)[0]) # type: ignore 420 | if layer_id % kb_token_layer_frequency == 0: 421 | old_weight = param.detach() 422 | if "q_proj_new.weight" in name: 423 | layer_id = int(re.search(r"\d+", name)[0]) # type: ignore 424 | if layer_id % kb_token_layer_frequency == 0: 425 | param.copy_(old_weight) # type: ignore 426 | param.requires_grad = True 427 | llm_q_params.append(param) 428 | else: 429 | if "q_proj.weight" in name: 430 | layer_id = int(re.search(r"\d+", name)[0]) # type: ignore 431 | if layer_id % kb_token_layer_frequency == 0: 432 | param.requires_grad = True 433 | llm_q_params.append(param) 434 | return llm_q_params 435 | 436 | 437 | class KBRetriever: 438 | def __init__( 439 | self, 440 | encoder: KBEncoder, 441 | dataset: List[Dict], 442 | key_embds: Optional[np.ndarray], 443 | value_embds: Optional[np.ndarray], 444 | ): 445 | self.encoder = encoder 446 | self.key_embds = key_embds 447 | self.value_embds = value_embds 448 | self.dataset = dataset 449 | 450 | def _use_cached_embd(self): 451 | if self.key_embds is not None and self.value_embds is not None: 452 | return True 453 | else: 454 | return False 455 | 456 | def 
get_key_embeddings(self, batch_indices, batch_size, step, kb_size): 457 | if self._use_cached_embd(): 458 | train_set_key, train_set_val = get_kb_embd( 459 | self.encoder, 460 | batch_indices, 461 | precomputed_embd=(self.key_embds, self.value_embds), 462 | ) 463 | else: 464 | train_set_key, train_set_val = get_kb_embd(self.encoder, batch_indices, kb_dict=self.dataset) 465 | 466 | if len(train_set_key.shape) == 2: 467 | # Add comment on why we need this line 468 | train_set_key = train_set_key.unsqueeze(0).transpose(0, 1) 469 | train_set_val = train_set_val.unsqueeze(0).transpose(0, 1) 470 | 471 | context_set_size = context_set_size_scheduler(step, kb_size) 472 | context_set_index = np.random.choice(len(self.dataset), context_set_size, replace=False) # type: ignore 473 | if self._use_cached_embd(): 474 | context_set_key, context_set_val = get_kb_embd( 475 | self.encoder, 476 | context_set_index, 477 | precomputed_embd=(self.key_embds, self.value_embds), 478 | ) 479 | else: 480 | context_set_key, context_set_val = get_kb_embd(self.encoder, context_set_index, kb_dict=self.dataset) 481 | context_set_key = context_set_key.unsqueeze(0).expand(batch_size, *context_set_key.shape) 482 | context_set_val = context_set_val.unsqueeze(0).expand(batch_size, *context_set_val.shape) 483 | # context_set_val = torch.randn_like(context_set_val) 484 | # Idea: Try torch.randn here context_set_tokens?? 485 | true_kb_copy = 1 486 | kb_embedding = ( 487 | torch.concat([*([train_set_key] * true_kb_copy), context_set_key], 1), 488 | torch.concat([*([train_set_val] * true_kb_copy), context_set_val], 1), 489 | ) 490 | return kb_embedding 491 | 492 | 493 | class Trainer: 494 | def __init__( 495 | self, 496 | llm_model: KBLaMPhi3ForCausalLM | KblamLlamaForCausalLM, 497 | kbretriever: KBRetriever, 498 | tokenizer: transformers.PreTrainedTokenizer, 499 | kb_token_layer_frequency: int, 500 | num_steps: int, 501 | lr: float, 502 | device: torch.device | None, 503 | use_lr_decay: bool, 504 | kb_size: int | List[int], 505 | llm_savename: str, 506 | output_dir: str, 507 | sep_query_head: bool = False, 508 | max_seq_len: int | None = None, 509 | ): 510 | self.accelerator = Accelerator() 511 | self.logger = logging.getLogger("training") 512 | self.tokenizer = tokenizer 513 | self.sep_query_head = sep_query_head 514 | self.kb_token_layer_frequency = kb_token_layer_frequency 515 | self.num_steps = num_steps 516 | self.lr = lr 517 | self.max_seq_len = max_seq_len 518 | 519 | self.model = llm_model 520 | self.model.gradient_checkpointing_enable() 521 | 522 | self.device = device if device is not None else self.accelerator.device 523 | self.kbretriever = kbretriever 524 | self.kb_size = kb_size 525 | self.use_lr_decay = use_lr_decay 526 | self.llm_savename = llm_savename 527 | self.output_path = pathlib.Path(output_dir) 528 | 529 | if isinstance(llm_model, KBLaMPhi3ForCausalLM): # Phi3 530 | self._get_batch = partial(get_batch, _format_QA_phi3, _create_labels_for_phi3) 531 | self._get_params = _get_phi3_query_head_parameters 532 | elif isinstance(llm_model, KblamLlamaForCausalLM): # llama 533 | self._get_batch = partial(get_batch, _format_QA_llama, _create_labels_for_llama) 534 | self._get_params = _get_llama3_query_head_parameters 535 | else: 536 | raise ValueError(f"{llm_model} not recognised") 537 | 538 | self.scheduler, self.optim = self.setup_scheduler_and_optim() 539 | 540 | self.model, self.optim, self._get_batch, self.kbretriever.encoder = self.accelerator.prepare( 541 | self.model, self.optim, self._get_batch, 
self.kbretriever.encoder 542 | ) 543 | 544 | def setup_scheduler_and_optim(self): 545 | if self.sep_query_head: 546 | self.logger.info("Query head being fine tuned!") 547 | llm_q_params = self._get_params(self.model, self.sep_query_head, self.kb_token_layer_frequency) 548 | scheduler, optim = setup_scheduler_and_optimizer( 549 | chain(self.kbretriever.encoder.parameters(), llm_q_params), 550 | self.lr, 551 | self.num_steps, 552 | ) 553 | self.logger.info("Optimizer recreated") 554 | else: 555 | scheduler, optim = setup_scheduler_and_optimizer( 556 | self.kbretriever.encoder.parameters(), self.lr, self.num_steps 557 | ) 558 | self.logger.info("Optimizer recreated") 559 | return scheduler, optim 560 | 561 | def train( 562 | self, 563 | training_set: List[Dict], 564 | batch_size, 565 | grad_accum_steps: int, 566 | outlier_num: int, 567 | use_data_aug: bool = False, 568 | multi_entities: bool = False, 569 | use_extended_qa: bool = False, 570 | save_period: int = 2000, 571 | resumed_step: int = 0, 572 | kb_config: KBLaMConfig = None, 573 | ): 574 | train_losses = [] 575 | start_step = resumed_step 576 | 577 | loss_fct = CrossEntropyLoss(reduction="none") 578 | 579 | # Calculate accumulation steps per GPU 580 | num_processes = self.accelerator.num_processes 581 | accum_steps_per_gpu = max(1, grad_accum_steps // num_processes) 582 | effective_batch_size = batch_size * grad_accum_steps 583 | 584 | if self.accelerator.is_main_process: 585 | self.logger.info(f"Training with {num_processes} GPUs") 586 | self.logger.info(f"Total accumulation steps: {grad_accum_steps}, Steps per GPU: {accum_steps_per_gpu}") 587 | self.logger.info(f"Batch size: {batch_size}") 588 | self.logger.info(f"Effective batch size: {effective_batch_size}") 589 | 590 | with create_custom_progress_bar(console=console, disable=not self.accelerator.is_main_process) as pbar: 591 | task = pbar.add_task("Training", total=self.num_steps, loss=100) 592 | for step in range(start_step, self.num_steps, 1): 593 | self.optim.zero_grad() 594 | losses = [] 595 | 596 | # Calculate which accumulation steps this GPU should process 597 | process_rank = self.accelerator.process_index 598 | start_accum_step = process_rank * accum_steps_per_gpu 599 | end_accum_step = min(start_accum_step + accum_steps_per_gpu, grad_accum_steps) 600 | 601 | # Accumulate gradients 602 | for a_step in range(start_accum_step, end_accum_step): 603 | step_config = get_step_config( 604 | a_step, 605 | grad_accum_steps, 606 | use_data_aug, 607 | outlier_num, 608 | multi_entities, 609 | use_extended_qa, 610 | ) 611 | input_ids, attention_masks, labels, batch_indices = self._get_batch( 612 | training_set, 613 | self.tokenizer, 614 | self.device, 615 | B=batch_size, 616 | random_sample=True, 617 | **step_config, 618 | ) 619 | 620 | if a_step == 0 and step % 10 == 0: 621 | self.logger.info(f"INPUT IDs SHAPE: {input_ids.shape}") 622 | 623 | if self.max_seq_len is not None: 624 | input_ids = input_ids[:, : self.max_seq_len] 625 | attention_masks = attention_masks[:, : self.max_seq_len] 626 | labels = labels[:, : self.max_seq_len] 627 | if a_step == 0 and step % 10 == 0: 628 | self.logger.info(f"TRUNCATED INPUT IDs SHAPE: {input_ids.shape}") 629 | 630 | kb_embedding = self.kbretriever.get_key_embeddings( 631 | batch_indices, len(input_ids), step, self.kb_size 632 | ) 633 | out = self.model( 634 | input_ids=input_ids, 635 | attention_mask=attention_masks, 636 | kb_kvs=kb_embedding, 637 | output_attentions=True, 638 | kb_config=kb_config, 639 | ) 640 | logits = out["logits"] 641 | 642 
| # display ground truth and model prediction to quickly check model 643 | if a_step == 0 and step % 10 == 0: 644 | batch_index = 0 # Which example in the batch to select 645 | max_logits = logits.argmax(axis=2) 646 | decoded_pred = self.tokenizer.decode(max_logits[batch_index, :-1]) 647 | sel_labels = labels[batch_index, :] 648 | sel_labels = sel_labels[sel_labels >= 0] # Remove padding token -100 649 | decoded_gt = self.tokenizer.decode(sel_labels) 650 | self.logger.info(f"KB SHAPE: {kb_embedding[0].shape}") 651 | self.logger.info(f"GT: {decoded_gt}") 652 | self.logger.info(f"PRED: {decoded_pred}") 653 | wandb.log({"kbsize": kb_embedding[0].shape[1]}) 654 | 655 | shift_logits = logits[..., :-1, :].contiguous() 656 | shift_labels = labels[..., 1:].contiguous() 657 | weights = (shift_labels > 0).sum(-1, keepdim=True).expand(-1, shift_labels.shape[1]).contiguous() 658 | # Flatten the tokens 659 | model_config = ( 660 | self.model.config 661 | if not isinstance(self.model, DistributedDataParallel) 662 | else self.model.module.config 663 | ) 664 | shift_logits = shift_logits.view(-1, model_config.vocab_size) 665 | shift_labels = shift_labels.view(-1) 666 | weights = weights.view(-1) 667 | 668 | shift_labels = shift_labels.to(shift_logits.device) 669 | 670 | loss = ( 671 | loss_fct(shift_logits, shift_labels) * weights.max() / weights 672 | ).mean() # Make sure each sample is equally weighted 673 | 674 | self.accelerator.backward(loss) 675 | losses.append(loss.item()) 676 | 677 | self.optim.step() 678 | if self.use_lr_decay: 679 | self.scheduler.step() 680 | 681 | # Gather and average losses from all GPUs for reporting 682 | if losses: # Only if this GPU processed any batches 683 | local_loss = torch.tensor(np.mean(losses), device=self.device) 684 | else: 685 | local_loss = torch.tensor(0.0, device=self.device) 686 | 687 | # Gather losses from all processes 688 | all_losses = self.accelerator.gather(local_loss) 689 | valid_losses = all_losses[all_losses > 0] # Filter out zeros from GPUs that didn't process batches 690 | avg_loss = valid_losses.mean().item() if len(valid_losses) > 0 else 0.0 691 | 692 | # Only log from main process 693 | if self.accelerator.is_main_process: 694 | self.logger.info(f"step: {step}, loss: {avg_loss}") 695 | wandb.log({'train_loss': np.mean(losses)}) 696 | train_losses.append(avg_loss) 697 | pbar.update(task, advance=1, loss=avg_loss) 698 | 699 | if (step % save_period) == 0 and (step != start_step): 700 | try: 701 | # Log memory usage before synchronization 702 | self.logger.info( 703 | f"Is main process: {self.accelerator.is_main_process}, GPU memory before save: {torch.cuda.memory_allocated()/1e9:.2f}GB / {torch.cuda.get_device_properties(0).total_memory/1e9:.2f}GB" 704 | ) 705 | 706 | # Try to free up memory 707 | torch.cuda.empty_cache() 708 | 709 | # Synchronize before saving 710 | self.accelerator.wait_for_everyone() 711 | 712 | if self.accelerator.is_main_process: 713 | 714 | self.logger.info("Saving checkpoint...") 715 | self.logger.info("Making dirs...") 716 | # Save model - using proper directory creation 717 | model_ckpt_name = self.output_path / f"{self.llm_savename}_step_{step}" 718 | model_ckpt_name.mkdir(parents=True, exist_ok=True) 719 | 720 | # Also create encoder directory 721 | encoder_dir = self.output_path / f"{self.llm_savename}_step_{step}_encoder" 722 | encoder_dir.mkdir(parents=True, exist_ok=True) 723 | 724 | self.logger.info("Saving model...") 725 | # Unwrap and save model 726 | unwrapped_model = self.accelerator.unwrap_model(self.model) 
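                        # Unwrapping strips the Accelerate/DistributedDataParallel wrapper so that
                        # save_pretrained writes a plain Hugging Face checkpoint that can later be
                        # reloaded with from_pretrained.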
727 | unwrapped_model.save_pretrained( 728 | model_ckpt_name, 729 | is_main_process=self.accelerator.is_main_process, 730 | save_function=self.accelerator.save, 731 | ) 732 | 733 | self.logger.info("Saving encoder...") 734 | # Save encoder and config from main process 735 | encoder_ckpt_name = encoder_dir / "encoder.pt" 736 | torch.save(self.kbretriever.encoder.state_dict(), encoder_ckpt_name) 737 | 738 | self.logger.info("Saving config...") 739 | # Explicitly save config as JSON 740 | config_path = model_ckpt_name / "kb_config_explicit.json" 741 | with open(config_path, 'w') as f: 742 | f.write(kb_config.to_json_string()) 743 | 744 | except Exception as e: 745 | self.logger.error(f"Error saving checkpoint: {e}") 746 | self.logger.error(f"Error details: {str(e)}") 747 | raise e 748 | 749 | 750 | def main(): 751 | os.environ["NCCL_TIMEOUT"] = "1200000" 752 | logger = logging.getLogger("training") 753 | 754 | args = parser.parse_args() 755 | if torch.cuda.is_available(): 756 | device = torch.device("cuda") 757 | 758 | if args.verbose: 759 | logger.setLevel(logging.DEBUG) 760 | else: 761 | logger.setLevel(logging.INFO) 762 | 763 | print(vars(args)) 764 | dataset_name = args.train_dataset 765 | seed = args.seed 766 | N = args.N 767 | B = args.B 768 | 769 | total_steps = args.total_steps 770 | encoder_spec = args.encoder_spec 771 | key_embd_src = args.key_embd_src 772 | use_data_aug = args.use_data_aug 773 | use_lr_decay = args.use_lr_decay 774 | use_cached_embd = args.use_cached_embd 775 | dataset_dir = args.dataset_dir 776 | model_dir_to_resume = args.model_dir_to_resume 777 | model_save_dir = args.model_save_dir 778 | sep_query_head = args.sep_query_head 779 | kb_size = args.kb_size 780 | dynamic_kb_size = args.dynamic_kb_size 781 | max_seq_len = args.max_seq_len 782 | 783 | if kb_size is not None and dynamic_kb_size is not None: 784 | raise ValueError("Can't specify kb_size and dynamic_kb_size. 
Use only one") 785 | 786 | kb_size = kb_size if kb_size is not None else dynamic_kb_size 787 | 788 | gradient_accm_step = args.gradient_accm_step 789 | 790 | length_invariance = args.length_invariance 791 | outlier_num = args.outlier_num 792 | multi_entities = args.multi_entities 793 | use_extended_qa = args.use_extended_qa 794 | kb_token_layer_frequency = args.kb_token_layer_frequency 795 | llm_type = args.llm_type 796 | hf_model_spec = args.hf_model_spec 797 | hf_token = args.hf_token 798 | 799 | torch.manual_seed(seed) 800 | np.random.seed(seed) 801 | 802 | pathlib.Path(model_save_dir).mkdir(parents=True, exist_ok=True) 803 | 804 | if Accelerator().is_main_process: 805 | wandb.init( 806 | # set the wandb project where this run will be logged 807 | project="kb-llm", 808 | # track hyperparameters and run metadata 809 | config={ 810 | "learning_rate": args.lr, 811 | 'sep_query_head': sep_query_head, 812 | 'kb_size': kb_size, 813 | 'length_invariance': length_invariance, 814 | 'dataset': dataset_name, 815 | 'outlier_num': outlier_num, 816 | 'multi_entities': multi_entities, 817 | 'use_extended_qa': use_extended_qa, 818 | 'kb_token_layer_frequency': kb_token_layer_frequency, 819 | 'gradient_accm_step': gradient_accm_step, 820 | "encoder_spec": encoder_spec, 821 | "max_seq_len": max_seq_len, 822 | }, 823 | ) 824 | 825 | # Try to free up memory 826 | torch.cuda.empty_cache() 827 | 828 | if args.log_to_file: 829 | formatter = logging.Formatter(LOGFORMAT) 830 | f_handler = logging.FileHandler(model_save_dir / "log.txt") 831 | f_handler.setFormatter(formatter) 832 | logger.addHandler(f_handler) 833 | 834 | logger.info(f"Running on {device}") 835 | 836 | logger.info("🚨 Started training 🚨") 837 | logger.info(f"💽 Saving to {model_save_dir}💽") 838 | if sep_query_head: 839 | os.environ["SEP_QUERY_HEAD"] = "TRUE" 840 | logger.info("Having seperate query head for KB!") 841 | 842 | if length_invariance: 843 | os.environ["LENGTH_INVARIANCE"] = "TRUE" 844 | logger.info("Having seperate query head for KB!") 845 | 846 | os.environ["SCALE_FACTOR"] = "" 847 | 848 | if use_cached_embd: 849 | # We load the pre-computed version stored on the disk rather 850 | # than computing them on the fly to make things faster 851 | logger.info(f"Using pre-computed {encoder_spec} embedding") 852 | key_embds, value_embds = _load_cached_embeddings(encoder_spec, dataset_dir, dataset_name, key_embd_src) 853 | 854 | prefix_string = get_prefix_str(args) 855 | logger.info(f"Experiment prefix {get_prefix_str(args)}") 856 | 857 | if use_extended_qa: 858 | dataset = json.load(open(os.path.join(dataset_dir, f"{dataset_name}_augmented.json"))) 859 | else: 860 | dataset = json.load(open(os.path.join(dataset_dir, f"{dataset_name}.json"))) 861 | 862 | training_set = dataset[:N] 863 | 864 | # Set up the LLM 865 | llm_model_spec = model_dir_to_resume if model_dir_to_resume else hf_model_spec 866 | 867 | resumed_step = 0 if not model_dir_to_resume else int(model_dir_to_resume.split("_")[-1]) 868 | 869 | if llm_model_spec is None: 870 | raise ValueError("Either supply model_dir_to_resume or hf_model_spec") 871 | 872 | if hf_token is None and args.llm_type == "llama3": 873 | raise ValueError("Please supply HuggingFace token(hf_token) when loading model Llama weights from HuggingFace") 874 | 875 | # Tokenizer comes from the base model 876 | tokenizer = AutoTokenizer.from_pretrained( 877 | hf_model_spec, 878 | trust_remote_code=True, 879 | token=hf_token if hf_token is args.llm_type == "llama3" else None, 880 | ) 881 | tokenizer.pad_token = 
882 |
883 |     if args.llm_type == "llama3":
884 |         model = KblamLlamaForCausalLM.from_pretrained(
885 |             llm_model_spec,
886 |             device_map=device,
887 |             torch_dtype=torch.bfloat16,
888 |             trust_remote_code=True,
889 |             token=hf_token,
890 |         )
891 |     elif args.llm_type == "phi3":
892 |         model = KBLaMPhi3ForCausalLM.from_pretrained(
893 |             llm_model_spec,
894 |             device_map=device,
895 |             torch_dtype="auto",
896 |             trust_remote_code=True,
897 |         )
898 |     else:
899 |         raise ValueError(f"LLM type {args.llm_type} not recognised")
900 |
901 |     logger.info(model.config)  # type: ignore
902 |
903 |     model.eval()  # type: ignore
904 |     # Freeze the base model; only the encoder (and optionally the query heads) are trained
905 |     for _, param in model.named_parameters():  # type: ignore
906 |         param.requires_grad = False
907 |
908 |     # Set up the encoder
909 |     encoder = KBEncoder(
910 |         encoder_name=encoder_spec,
911 |         projector_type="linear",
912 |         endpoint_url="",
913 |         out_dim=model.config.hidden_size  # type: ignore
914 |         * (model.config.num_hidden_layers // kb_token_layer_frequency + 1),  # type: ignore
915 |         frozen_base_model=True,
916 |         device=device,
917 |     )
918 |
919 |     if model_dir_to_resume:
920 |         encoder.load_state_dict(torch.load(os.path.join(model_dir_to_resume, "encoder.pt")))
921 |         kb_config = KBLaMConfig.from_pretrained(os.path.join(model_dir_to_resume, "kb_config.json"))
922 |     else:
923 |         kb_config = KBLaMConfig(
924 |             sep_query_head=sep_query_head,
925 |             kb_layer_frequency=kb_token_layer_frequency,
926 |         )
927 |
928 |     encoder.train()
929 |
930 |     kbretriever = KBRetriever(
931 |         encoder,
932 |         training_set,
933 |         key_embds=key_embds,  # type: ignore
934 |         value_embds=value_embds,  # type: ignore
935 |     )
936 |
937 |     logger.info("Model ready 🚀")
938 |
939 |     # Get the training started
940 |     llm_ckpt_name = f"{prefix_string}KeyFrom{key_embd_src}_{encoder_spec}_{dataset_name}_{llm_type}"
941 |
942 |     trainer = Trainer(
943 |         model,  # type: ignore
944 |         kbretriever,
945 |         tokenizer,
946 |         kb_token_layer_frequency,
947 |         total_steps,
948 |         args.lr,
949 |         device,
950 |         use_lr_decay,
951 |         kb_size,  # type: ignore
952 |         llm_ckpt_name,
953 |         model_save_dir,
954 |         sep_query_head=sep_query_head,
955 |         max_seq_len=max_seq_len,
956 |     )
957 |
958 |     logger.info(f"Number of trainable parameters: {_get_parameter_count(encoder):,}")
959 |
960 |     trainer.train(
961 |         training_set,
962 |         B,
963 |         gradient_accm_step,
964 |         outlier_num,
965 |         use_data_aug=use_data_aug,
966 |         multi_entities=multi_entities,
967 |         use_extended_qa=use_extended_qa,
968 |         save_period=3000,
969 |         resumed_step=resumed_step,
970 |         kb_config=kb_config,
971 |     )
972 |
973 |
974 | if __name__ == "__main__":
975 |     main()
976 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
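# Example local setup (an assumed sketch, not documented in this file): the project is
# built with pdm-backend, so something like
#   pdm install -G dev -G experiment
# should install the package together with the dependency groups declared below.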
3 | 4 | [build-system] 5 | requires = ["pdm-backend"] 6 | build-backend = "pdm.backend" 7 | 8 | [project] 9 | name = "kblam" 10 | version = "0.0.1" 11 | description= "Knowledge Base augmented Language Model" 12 | 13 | readme = "README.md" 14 | requires-python = ">=3.10" 15 | license = {text = "MIT"} 16 | dependencies = [ 17 | "accelerate", 18 | "datasets==3.2.0", 19 | "transformers==4.46.0", 20 | "torch", 21 | "sentence_transformers", 22 | "azure-ai-ml", 23 | "azure-core", 24 | "azure-identity", 25 | "azure-mgmt-authorization", 26 | "openai", 27 | "pytest", 28 | "rich", 29 | "evaluate", 30 | "nltk", 31 | "rouge-score", 32 | "bert_score" 33 | ] 34 | 35 | [dependency-groups] 36 | dev = [ 37 | "mypy>=1.5.1", 38 | "pre-commit", 39 | "pylint>=3.0.0", 40 | "pytest>=7.4.2", 41 | "codespell", 42 | "coverage", 43 | "ipykernel", 44 | "jinja2", 45 | "ruff", 46 | "pyright", 47 | ] 48 | 49 | experiment = [ 50 | "lm-eval==0.4.5", 51 | "wandb", 52 | ] 53 | 54 | all = ["kblam[dev,experiment]"] 55 | 56 | [tool.setuptools.packages.find] 57 | where = ["src"] 58 | exclude = ["*_test.py", "tests/*"] 59 | 60 | [tool.setuptools.package-data] 61 | "*" = ["py.typed"] 62 | 63 | [tool.black] 64 | line-length = 120 65 | skip-string-normalization = true 66 | 67 | [tool.isort] 68 | py_version = 310 69 | profile = "black" 70 | line_length = 120 71 | 72 | [tool.mypy] 73 | python_version = "3.10" 74 | disallow_untyped_defs = true 75 | 76 | ignore_missing_imports = true 77 | 78 | [tool.pylint.main] 79 | max-line-length = 120 80 | suggestion-mode = true 81 | py-version = "3.10" 82 | 83 | [tool.pylint.messages_control] 84 | # Disable the message, report, category or checker with the given id(s). 85 | disable = [ 86 | "logging-fstring-interpolation" 87 | ] 88 | 89 | [tool.pylint.basic] 90 | docstring-min-length = 10 91 | 92 | [tool.pdm.build] 93 | includes = ["src"] 94 | -------------------------------------------------------------------------------- /src/kblam/gpt_session.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | from pathlib import Path 5 | 6 | from azure.identity import ( 7 | AuthenticationRecord, 8 | DeviceCodeCredential, 9 | TokenCachePersistenceOptions, 10 | get_bearer_token_provider, 11 | ) 12 | from openai import AzureOpenAI 13 | 14 | valid_models = ["gpt-4o", "ada-embeddings", "text-embedding-3-large"] 15 | 16 | 17 | class GPT: 18 | def __init__( 19 | self, 20 | model_name: str, 21 | endpoint_url: str, 22 | api_version: str = "2024-02-15-preview", 23 | system_msg: str = "You are an AI assistant.", 24 | max_retries: int = 12, 25 | temperature: int = 1.0, 26 | max_tokens: int = 4096, 27 | top_p: float = 0.95, 28 | frequency_penalty: int = 0, 29 | presence_penalty: int = 0, 30 | seed: int = None, 31 | ): 32 | if model_name not in valid_models: 33 | raise ValueError( 34 | f"Invalid model: {model_name}. 
Valid models are: {valid_models}" 35 | ) 36 | 37 | token_provider = get_bearer_token_provider( 38 | self._get_credential(), "https://cognitiveservices.azure.com/.default" 39 | ) 40 | 41 | self.OA_client = AzureOpenAI( 42 | azure_endpoint=endpoint_url, 43 | api_version=api_version, 44 | azure_ad_token_provider=token_provider, 45 | ) 46 | 47 | self.max_retries = max_retries 48 | self.system_msg = system_msg 49 | self.model_name = model_name 50 | self.temperature = temperature 51 | self.max_tokens = max_tokens 52 | self.top_p = top_p 53 | self.frequency_penalty = frequency_penalty 54 | self.presence_penalty = presence_penalty 55 | self.seed = seed 56 | 57 | def set_seed(self, seed: int): 58 | self.seed = seed 59 | 60 | def _get_credential(self, lib_name: str = "azure_openai") -> DeviceCodeCredential: 61 | """Retrieves a credential to be used for authentication in Azure""" 62 | if sys.platform.startswith("win"): 63 | auth_record_root_path = Path(os.environ["LOCALAPPDATA"]) 64 | else: 65 | auth_record_root_path = Path.home() 66 | 67 | auth_record_path = auth_record_root_path / lib_name / "auth_record.json" 68 | cache_options = TokenCachePersistenceOptions( 69 | name=f"{lib_name}.cache", allow_unencrypted_storage=True 70 | ) 71 | 72 | if auth_record_path.exists(): 73 | with open(auth_record_path, "r") as f: 74 | record_json = f.read() 75 | deserialized_record = AuthenticationRecord.deserialize(record_json) 76 | credential = DeviceCodeCredential( 77 | authentication_record=deserialized_record, 78 | cache_persistence_options=cache_options, 79 | ) 80 | else: 81 | auth_record_path.parent.mkdir(parents=True, exist_ok=True) 82 | credential = DeviceCodeCredential(cache_persistence_options=cache_options) 83 | record_json = credential.authenticate().serialize() 84 | with open(auth_record_path, "w") as f: 85 | f.write(record_json) 86 | 87 | return credential 88 | 89 | def api_call_chat(self, messages: list[dict]) -> str | None: 90 | for _ in range(self.max_retries): 91 | completion = self.OA_client.chat.completions.create( 92 | model=self.model_name, 93 | messages=messages, 94 | temperature=self.temperature, 95 | max_tokens=self.max_tokens, 96 | top_p=self.top_p, 97 | frequency_penalty=self.frequency_penalty, 98 | presence_penalty=self.presence_penalty, 99 | seed=self.seed if self.seed else None, 100 | ) 101 | if completion: 102 | return completion.choices[0].message.content 103 | return None 104 | 105 | def _api_call_embedding(self, text: str) -> list[float] | None: 106 | for _ in range(self.max_retries): 107 | embedding = self.OA_client.embeddings.create( 108 | input=text, model=self.model_name 109 | ) 110 | if embedding: 111 | return embedding.data[0].embedding 112 | return None 113 | 114 | def generate_response(self, prompt: str) -> str | None: 115 | """ 116 | Generate a response for the given prompt. 117 | This setup can be used for GPT4 models but not for embedding genneration. 118 | """ 119 | messages = [ 120 | { 121 | "role": "system", 122 | "content": self.system_msg, 123 | }, 124 | { 125 | "role": "user", 126 | "content": prompt, 127 | }, 128 | ] 129 | 130 | response = self.api_call_chat(messages) 131 | return response 132 | 133 | def generate_embedding(self, text: str) -> list[float] | None: 134 | """ 135 | Generate an embedding for the given text. 136 | This setup can be used for Ada embeddings but not for text generation. 
137 | """ 138 | embedding = self._api_call_embedding(text) 139 | return embedding 140 | 141 | 142 | def parser_args(): 143 | parser = argparse.ArgumentParser(description="GPT Session") 144 | parser.add_argument( 145 | "--model_name", 146 | type=str, 147 | default="ada-embeddings", 148 | help="Model name to use for embedding generation", 149 | ) 150 | parser.add_argument( 151 | "--prompt", 152 | type=str, 153 | default="Embedding text", 154 | help="Prompt for text generation", 155 | ) 156 | parser.add_argument( 157 | "--endpoint_url", 158 | type=str, 159 | help="Endpoint URL for the model", 160 | ) 161 | 162 | return parser.parse_args() 163 | 164 | 165 | if __name__ == "__main__": 166 | args = parser_args() 167 | gpt = GPT(args.model_name, args.endpoint_url) 168 | response = gpt.generate_embedding(args.prompt) 169 | 170 | assert response is not None 171 | -------------------------------------------------------------------------------- /src/kblam/kb_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from transformers import FeatureExtractionMixin 4 | from sentence_transformers import SentenceTransformer 5 | from .gpt_session import GPT 6 | from typing import Union 7 | 8 | 9 | class IdentityMap(nn.Module): 10 | def __init__(self): 11 | super().__init__() 12 | 13 | def forward(self, x, *args, **kwargs): 14 | return x 15 | 16 | 17 | def get_projector( 18 | projector_type: str, in_dim: int, out_dim: int, projector_kwargs: dict 19 | ) -> nn.Module: 20 | assert isinstance(projector_kwargs, dict) 21 | if projector_type == "identity": 22 | return IdentityMap() 23 | elif projector_type == "linear": 24 | return nn.Linear(in_dim, out_dim) 25 | elif projector_type == "mlp": 26 | mlp_depth, mlp_hidden_dim = ( 27 | projector_kwargs["mlp_depth"], 28 | projector_kwargs["mlp_hidden_dim"], 29 | ) 30 | modules = [nn.Linear(in_dim, mlp_hidden_dim)] 31 | for _ in range(mlp_depth): 32 | modules.append(nn.Linear(mlp_hidden_dim, mlp_hidden_dim)) 33 | modules.append(nn.GELU()) 34 | modules.append(nn.Linear(mlp_hidden_dim, out_dim)) 35 | return nn.Sequential(*modules) 36 | else: 37 | raise NotImplementedError(f"Projector type {projector_type} not found") 38 | 39 | 40 | # TODO(t-isazawat): Add support for batching here 41 | class KBEncoder(nn.Module, FeatureExtractionMixin): 42 | kb_special_token = { 43 | "": 0, 44 | "": 1, 45 | "": 2, 46 | "": 3, 47 | "": 4, 48 | "": 5, 49 | } 50 | 51 | def __init__( 52 | self, 53 | encoder_name: str, 54 | projector_type: str, 55 | out_dim: int, 56 | endpoint_url: str, 57 | projector_kwargs: dict = {}, 58 | frozen_base_model: bool = True, 59 | device: Union[str, torch.device] = "cuda", 60 | get_oai_embd_online: bool = False, 61 | ): 62 | super().__init__() 63 | # Define the KB encoder backbone 64 | self.encoder_spec = encoder_name 65 | 66 | if encoder_name in ["OAI", "BigOAI"]: 67 | big = "Big" in encoder_name 68 | if get_oai_embd_online: 69 | if big: 70 | self.gs = GPT("text-embedding-3-large", endpoint_url) 71 | else: 72 | self.gs = GPT("ada-embeddings", endpoint_url) 73 | 74 | self.base_model_encode = lambda s: torch.tensor( 75 | self.gs.generate_embedding(s) 76 | ).to(self.device) 77 | else: 78 | self.base_model_encode = None 79 | self.in_dim = 3072 if big else 1536 80 | else: 81 | self.base_model = SentenceTransformer(encoder_name) 82 | self.base_model_encode = lambda s: self.base_model.encode( 83 | s, convert_to_numpy=False 84 | ) 85 | self.frozen_base_model = frozen_base_model 86 | if 
frozen_base_model: 87 | self.base_model.eval() 88 | for param in self.base_model.parameters(): 89 | param.requires_grad = False 90 | else: 91 | self.base_model.train() 92 | self.in_dim = self.base_model.get_sentence_embedding_dimension() 93 | self.out_dim = out_dim 94 | self.projector_k = get_projector( 95 | projector_type, self.in_dim, self.out_dim, projector_kwargs 96 | ) 97 | self.projector_v = get_projector( 98 | projector_type, self.in_dim, self.out_dim, projector_kwargs 99 | ) 100 | self.key_layernorm = nn.LayerNorm( 101 | self.out_dim, elementwise_affine=False, bias=False 102 | ) 103 | self.embedding = nn.Embedding(len(self.kb_special_token), out_dim) 104 | self.device = device 105 | self.to(self.device) 106 | 107 | def freeze_v(self): 108 | for param in self.projector_v.parameters(): 109 | param.requires_grad = False 110 | 111 | def encode_key(self, S=None, base_emb=None): 112 | """ 113 | Convert the keys to embedding using the backbone model + adapter 114 | """ 115 | if S: 116 | base_embedding = self.base_model_encode(S) 117 | elif base_emb is not None: 118 | base_embedding = torch.from_numpy(base_emb).to(self.device) 119 | return self.key_layernorm(self.projector_k(base_embedding)).bfloat16() 120 | 121 | def encode_val(self, S=None, base_emb=None): 122 | """ 123 | Convert the values to embedding using the backbone model + adapter 124 | """ 125 | if S: 126 | base_embedding = self.base_model_encode(S) 127 | elif base_emb is not None: 128 | base_embedding = torch.from_numpy(base_emb).to(self.device) 129 | return self.projector_v(base_embedding).bfloat16() 130 | 131 | def encode_key_value(self, key, value): 132 | key_embd = self.encode_key(S=key) 133 | value_embd = self.encode_val(S=value) 134 | return key_embd, value_embd 135 | 136 | def encode_key_value_embeddings(self, key_embd, value_embd): 137 | key_embd = self.encode_key(base_emb=key_embd) 138 | value_embd = self.encode_val(base_emb=value_embd) 139 | return key_embd, value_embd 140 | 141 | def encode_base_embeddings( 142 | self, kb: tuple[torch.Tensor, torch.Tensor] 143 | ) -> tuple[torch.Tensor, torch.Tensor]: 144 | """ 145 | Encode the knowledge base into embeddings. 
Assumes that the input KB is given as a tuple of two torch tensors: keys and values 146 | """ 147 | key_embds, value_embds = [], [] 148 | for key, value in zip(kb[0], kb[1]): 149 | key_embd, value_embd = self.encode_key_value_embeddings(key, value) 150 | key_embds.append(key_embd) 151 | value_embds.append(value_embd) 152 | return torch.stack(key_embds), torch.stack(value_embds) 153 | 154 | def encode(self, kb: list[tuple]) -> tuple[torch.Tensor, torch.Tensor]: 155 | """ 156 | Encode the knowledge base into embeddings 157 | """ 158 | key_embds, value_embds = [], [] 159 | for key, value in kb: 160 | key_embd, value_embd = self.encode_key_value(key, value) 161 | key_embds.append(key_embd) 162 | value_embds.append(value_embd) 163 | return torch.stack(key_embds), torch.stack(value_embds) 164 | 165 | def get_special_token_embd(self, token_type): 166 | """ 167 | Get the embedding for the special token, 168 | take in a string, returns a tensor 169 | """ 170 | idx = torch.tensor(self.kb_special_token[token_type]).to( 171 | self.embedding.weight.device 172 | ) 173 | return self.embedding(idx).bfloat16() 174 | -------------------------------------------------------------------------------- /src/kblam/models/kblam_config.py: -------------------------------------------------------------------------------- 1 | from transformers import PretrainedConfig 2 | 3 | 4 | class KBLaMConfig(PretrainedConfig): 5 | def __init__( 6 | self, 7 | base_model_name_or_path: str = "", 8 | kb_layer_frequency: int = 3, 9 | kb_scale_factor: int | None = None, 10 | top_k_kb: int = 100, 11 | dynamic_sparsify: bool = False, 12 | sep_query_head: bool = False, 13 | attn_implementation: str = "eager", 14 | **kwargs, 15 | ): 16 | self.base_model_name_or_path = base_model_name_or_path 17 | self.kb_layer_frequency = kb_layer_frequency 18 | self.kb_scale_factor = kb_scale_factor 19 | self.top_k_kb = top_k_kb 20 | self.dynamic_sparsify = dynamic_sparsify 21 | self.sep_query_head = sep_query_head 22 | self.attn_implementation = attn_implementation 23 | super().__init__(**kwargs) 24 | -------------------------------------------------------------------------------- /src/kblam/models/kblam_processor.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | from transformers.processing_utils import ProcessorMixin 3 | from transformers import AutoTokenizer 4 | from transformers.tokenization_utils_base import PreTokenizedInput, TextInput 5 | from transformers import BatchFeature 6 | 7 | from kblam.kb_encoder import KBEncoder 8 | import torch 9 | 10 | from dataclasses import dataclass 11 | 12 | 13 | @dataclass 14 | class EncoderArgs: 15 | encoder_name: str 16 | hidden_size: int 17 | num_hidden_layers: int 18 | kb_layer_frequency: int 19 | encoder_dir: str 20 | projector_type: str 21 | endpoint_url: str 22 | 23 | 24 | class KBLaMProcessor(ProcessorMixin): 25 | feature_extractor_class = "AutoFeatureExtractor" 26 | tokenizer_class = "AutoTokenizer" 27 | 28 | def __init__(self, tokenizer: AutoTokenizer, args: EncoderArgs, **kwargs): 29 | self.kb_encoder = self.load_encoder(args) 30 | self.tokenizer = tokenizer 31 | self.tokenizer.pad_token = self.tokenizer.eos_token 32 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 33 | super().__init__(self.kb_encoder, self.tokenizer) 34 | 35 | def load_encoder(self, args: EncoderArgs): 36 | encoder = KBEncoder( 37 | encoder_name=args.encoder_name, 38 | projector_type=args.projector_type, 39 | endpoint_url=args.endpoint_url, 40 | 
out_dim=args.hidden_size 41 | * (args.num_hidden_layers // args.kb_layer_frequency + 1), 42 | frozen_base_model=True, 43 | projector_kwargs={"mlp_depth": 1, "mlp_hidden_dim": 512}, 44 | get_oai_embd_online=False, 45 | ) 46 | 47 | encoder.load_state_dict(torch.load(args.encoder_dir)) 48 | return encoder 49 | 50 | def __call__( 51 | self, 52 | knowledge_base: list[tuple[torch.Tensor]] | list[tuple[str]] = None, 53 | text: Union[ 54 | TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput] 55 | ] = None, 56 | ) -> BatchFeature: 57 | # Process the knowledge base if needed 58 | if ( 59 | knowledge_base 60 | and isinstance(knowledge_base, list) 61 | and isinstance(knowledge_base[0][0], torch.Tensor) 62 | ): 63 | knowledge_base = self.kb_encoder.encode_base_embeddings(knowledge_base) 64 | elif ( 65 | knowledge_base 66 | and isinstance(knowledge_base, list) 67 | and isinstance(knowledge_base[0][0], str) 68 | ): 69 | knowledge_base = self.kb_encoder.encode(knowledge_base) 70 | 71 | # Process the text 72 | input_str = ( 73 | "<|start_header_id|>user<|end_header_id|> " 74 | + text 75 | + "<|eot_id|>" 76 | + "<|start_header_id|>assistant<|end_header_id|>" 77 | ) 78 | 79 | text_inputs = self.tokenizer(input_str, return_tensors="pt", padding=True).to( 80 | self.device 81 | ) 82 | return BatchFeature(data={**text_inputs, "kb_kvs": knowledge_base}) 83 | 84 | def batch_decode(self, *args, **kwargs): 85 | """ 86 | This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please 87 | refer to the docstring of this method for more information. 88 | """ 89 | return self.tokenizer.batch_decode(*args, **kwargs) 90 | 91 | def decode(self, *args, **kwargs): 92 | """ 93 | This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to 94 | the docstring of this method for more information. 
95 | """ 96 | return self.tokenizer.decode(*args, **kwargs) 97 | -------------------------------------------------------------------------------- /src/kblam/utils/convert.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import json 3 | 4 | 5 | def flatten_results(data): 6 | rows = [] 7 | for exp in data: 8 | kb_size = exp["kb_size"] 9 | for result in exp["accuracy_results"]: 10 | rows.append( 11 | { 12 | "kb_size": kb_size, 13 | "idx": result["idx"], 14 | "acc": result["acc"], 15 | "top5acc": result["top5acc"], 16 | } 17 | ) 18 | return rows 19 | 20 | 21 | # Load JSON from file 22 | with open("/datadisk/kblam_attention_acc_results/accuracy_results.json", "r") as f: 23 | results = json.load(f) 24 | 25 | # Convert to DataFrame and save as CSV 26 | df = pd.DataFrame(flatten_results(results)) 27 | df = df.sort_values(["kb_size", "idx"]).reset_index(drop=True) 28 | 29 | # Save to CSV 30 | df.to_csv("accuracy_results.csv", index=False) 31 | -------------------------------------------------------------------------------- /src/kblam/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | from dataclasses import dataclass 3 | 4 | import numpy as np 5 | 6 | 7 | @dataclass 8 | class Entity: 9 | name: str 10 | description: str 11 | objectives: str 12 | purpose: str 13 | 14 | 15 | @dataclass 16 | class DataPoint: 17 | name: str 18 | description_type: str 19 | description: str 20 | Q: str = None 21 | A: str = None 22 | key_string: str = None 23 | extended_Q: str = None 24 | extended_A: str = None 25 | 26 | 27 | def save_entity(pair: Entity | DataPoint, output_file: str) -> None: 28 | """Save a JSON entity to a file.""" 29 | try: 30 | with open(output_file, "a+") as f: 31 | json.dump(pair.__dict__, f) 32 | f.write("\n") 33 | except Exception as e: 34 | print("Error saving entity.") 35 | print(e) 36 | 37 | 38 | def load_entities(inout_file: str) -> list[Entity | DataPoint]: 39 | """Load entities from a file.""" 40 | entities = [] 41 | try: 42 | with open(inout_file, "r") as f: 43 | for line in f: 44 | entity = json.loads(line) 45 | entities.append(entity) 46 | except Exception as e: 47 | print("Error loading entities.") 48 | print(e) 49 | return entities 50 | 51 | 52 | def get_i_dont_know_ans(): 53 | return "I am sorry I cannot find relevant information in the KB." 
54 |
55 |
56 | def augment_row(row: dict[str, str]) -> str:
57 |     """Augment an entity with a question from pre-defined templates."""
58 |     templates = [
59 |         "What {} does {} have?",
60 |         "What is the {} of {}?",
61 |         "Tell me about the {} of {}.",
62 |         "Can you let me know the {} of {}?",
63 |         "Can you inform me about the {} of {}?",
64 |         "Describe the {} of {}.",
65 |         "What details can you share about the {} of {}?",
66 |         "What kind of {} does {} have?",
67 |         "Provide details on the {} of {}.",
68 |         "What features does the {} of {} include?",
69 |         "Can you elaborate on the {} of {}?",
70 |         "How would you describe the {} of {}?",
71 |         "What can you tell me about the {} characteristics of {}?",
72 |         "Can you explain the {} of {}?",
73 |         "What insights can you provide about the {} of {}?",
74 |         "What should I know about the {} of {}?",
75 |     ]
76 |     dtype = row["description_type"]
77 |     name = row["name"]
78 |     tid = np.random.randint(0, len(templates))
79 |     return templates[tid].format(dtype, name)
80 |
81 |
82 | def generate_multi_entity_qa(
83 |     names: list[str], properties: list[str], answers: list[str]
84 | ) -> tuple[str, str]:
85 |     """Generate a question-answer pair for multiple entities."""
86 |     templates = [
87 |         "What is {}?",
88 |         "Tell me {}.",
89 |         "Can you let me know {}?",
90 |         "Can you inform me {}?",
91 |         "Describe {}.",
92 |         "Explain {}.",
93 |         "Could you describe the {}?",
94 |         "What can you tell me about {}?",
95 |         "Could you provide information on {}?",
96 |         "Please enlighten me about {}.",
97 |         "Can you clarify {} for me?",
98 |         "Could you give me a detailed description of {}?",
99 |         "I need more information on {}.",
100 |     ]
101 |     template_idx = np.random.randint(0, len(templates))
102 |     question_body = ""
103 |     for name, property in zip(names[:-1], properties[:-1]):
104 |         question_body += f"the {property} of {name},"
105 |     question_body += f" and the {properties[-1]} of {names[-1]}"
106 |     answer_str = ""
107 |     for answer, name, property in zip(answers, names, properties):
108 |         answer_str += f"The {property} of {name} is {answer}; "
109 |     return templates[template_idx].format(question_body), answer_str.strip()
110 |
--------------------------------------------------------------------------------
/src/kblam/utils/eval_utils.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import numpy as np
4 | import torch
5 | import transformers
6 |
7 | from kblam.models.kblam_config import KBLaMConfig
8 | from kblam.models.llama3_model import KblamLlamaForCausalLM
9 | from kblam.models.phi3_model import KBLaMPhi3ForCausalLM
10 |
11 | instruction_prompts = """
12 | Please answer questions based on the given text with format: "The {property} of {name} is {description}"
13 | """
14 |
15 | instruction_prompts_multi_entities = """
16 | Please answer questions based on the given text with format: "The {property} of {name1} is {description}; The {property} of {name2} is {description}; ..."
17 | """
18 |
19 | zero_shot_prompt = """
20 | Please answer the question in a very compact manner with format: The {property} of {name} is {description}
21 | """
22 |
23 | zero_shot_prompt_multi_entities = """
24 | Please answer the question in a very compact manner with format: "The {property} of {name1} is {description}; The {property} of {name2} is {description}; ...
25 | """ 26 | 27 | 28 | def _prune_for_llama(S: str) -> str: 29 | S = S.replace("<|eot_id|>", "") 30 | S = S.replace("<|start_header_id|>assistant<|end_header_id|>", "") 31 | S = S.replace("<|start_header_id|>user<|end_header_id|>", "") 32 | S = S.replace("<|end_of_text|>", "") 33 | return S 34 | 35 | 36 | def _prune_for_phi3(S: str) -> str: 37 | S = S.replace("<|end|>", "") 38 | S = S.replace("<|assistant|>", "") 39 | S = S.replace("<|user|>", "") 40 | return S 41 | 42 | 43 | def softmax(x: np.array, axis: int) -> np.array: 44 | """Compute softmax values for each sets of scores in x.""" 45 | e_x = np.exp(x - np.max(x)) 46 | return e_x / e_x.sum(axis=axis) 47 | 48 | 49 | def _format_Q_llama(Q: str): 50 | return ( 51 | "<|start_header_id|>user<|end_header_id|> " 52 | + Q 53 | + "<|eot_id|>" 54 | + "<|start_header_id|>assistant<|end_header_id|>" 55 | ) 56 | 57 | 58 | def _format_Q_phi3(Q: str): 59 | return "<|user|>\n" + Q + "<|end|>\n" + "<|assistant|>\n" 60 | 61 | 62 | model_question_format_mapping = { 63 | KblamLlamaForCausalLM: _format_Q_llama, 64 | KBLaMPhi3ForCausalLM: _format_Q_phi3, 65 | } 66 | model_prune_format_mapping = { 67 | KblamLlamaForCausalLM: _prune_for_llama, 68 | KBLaMPhi3ForCausalLM: _prune_for_phi3, 69 | } 70 | 71 | 72 | def answer_question( 73 | tokenizer: transformers.PreTrainedTokenizer, 74 | model: KBLaMPhi3ForCausalLM | KblamLlamaForCausalLM, 75 | Q: str, 76 | kb=None, 77 | kb_config: Optional[KBLaMConfig] = None, 78 | ): 79 | for m in model_question_format_mapping: 80 | if isinstance(model, m): 81 | input_str = model_question_format_mapping[m](Q) 82 | tokenizer_output = tokenizer(input_str, return_tensors="pt", padding=True).to( 83 | "cuda" 84 | ) 85 | input_ids, attention_masks = ( 86 | tokenizer_output["input_ids"], 87 | tokenizer_output["attention_mask"], 88 | ) 89 | 90 | with torch.autograd.no_grad(): 91 | outputs = model.generate( 92 | input_ids=input_ids, 93 | attention_mask=attention_masks, 94 | kb_kvs=kb, 95 | max_new_tokens=150, 96 | tokenizer=tokenizer, 97 | output_attentions=True, 98 | kb_config=kb_config, 99 | ).squeeze() 100 | outputs = tokenizer.decode(outputs, skip_special_tokens=False) 101 | 102 | for m in model_prune_format_mapping: 103 | if isinstance(model, m): 104 | pruned_output = model_prune_format_mapping[m](outputs) 105 | return pruned_output 106 | -------------------------------------------------------------------------------- /src/kblam/utils/model_utils.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModelForCausalLM, AutoTokenizer 2 | 3 | from kblam.models.kblam_processor import EncoderArgs, KBLaMProcessor 4 | from kblam.models.llama3_model import KblamLlamaForCausalLM 5 | 6 | 7 | def load_model_and_processor( 8 | model_path: str, encoder_name: str, kb_layer_frequency: int, encoder_dir: str 9 | ) -> tuple[AutoModelForCausalLM, KBLaMProcessor]: 10 | model = KblamLlamaForCausalLM.from_pretrained(model_path).bfloat16() 11 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) 12 | 13 | args = EncoderArgs( 14 | encoder_name=encoder_name, 15 | hidden_size=model.config.hidden_size, 16 | num_hidden_layers=model.config.num_hidden_layers, 17 | kb_layer_frequency=kb_layer_frequency, 18 | encoder_dir=encoder_dir, 19 | ) 20 | 21 | processor = KBLaMProcessor(tokenizer, args) 22 | return model, processor 23 | -------------------------------------------------------------------------------- /src/kblam/utils/train_utils.py: 
--------------------------------------------------------------------------------
/src/kblam/utils/train_utils.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from typing import Any
3 | 
4 | import numpy as np
5 | import torch
6 | from torch.nn import CrossEntropyLoss
7 | from torch.nn.parallel import DistributedDataParallel
8 | from torch.optim.optimizer import ParamsT
9 | 
10 | def get_tensor_config(x: torch.Tensor) -> dict[str, Any]:
11 |     return {"dtype": x.dtype, "layout": x.layout, "device": x.device}
12 | 
13 | 
14 | def preprocess_embds(emb1: list, emb2: list, kb_mask_val: int = 1):
15 |     """
16 |     emb1: List of embeddings of the KB.
17 |     emb2: List of embeddings of the query.
18 |     kb_mask_val: Attention mask value for the KB (i.e. emb1) part.
19 | 
20 |     The function first left-pads each pair with zeros and then concatenates emb1 and emb2.
21 | 
22 |     Return:
23 |         The stacked embedding tensor, attention masks, position ids, and KB masks,
24 |         where the position ids of the padding and KB parts are set to -1.
25 |     """
26 |     assert isinstance(emb1, list)
27 |     assert isinstance(emb2, list)
28 |     assert len(emb1) == len(emb2)
29 |     max_length = max([e1.shape[0] + e2.shape[0] for e1, e2 in zip(emb1, emb2)])
30 |     joint_embs = []
31 |     attention_masks = []
32 |     position_ids = []
33 |     kb_masks = []
34 |     for e1, e2 in zip(emb1, emb2):
35 |         tensor_config = get_tensor_config(e1)
36 |         pad_size = max_length - e1.shape[0] - e2.shape[0]
37 |         padding = torch.zeros((pad_size, e1.shape[1]), **tensor_config)
38 |         joint_embs.append(torch.concat([padding, e1, e2]))
39 |         attention_mask = torch.cat(
40 |             [
41 |                 torch.zeros(pad_size, **tensor_config),
42 |                 torch.zeros(e1.shape[0], **tensor_config) + kb_mask_val,  # Attention mask for KB
43 |                 torch.ones(e2.shape[0], **tensor_config),  # Attention mask for the question
44 |             ]
45 |         )
46 | 
47 |         attention_masks.append(attention_mask)
48 |         position_id = torch.cat(
49 |             [
50 |                 torch.zeros(max_length - e2.shape[0], **tensor_config) - 1,
51 |                 torch.arange(1, e2.shape[0] + 1, **tensor_config) - 1,
52 |             ]
53 |         )
54 |         position_ids.append(position_id)
55 | 
56 |         kb_mask = torch.cat(
57 |             [
58 |                 torch.zeros(pad_size, **tensor_config),
59 |                 torch.ones(e1.shape[0], **tensor_config),
60 |                 torch.zeros(e2.shape[0], **tensor_config),
61 |             ]
62 |         )
63 |         kb_masks.append(kb_mask)
64 | 
65 |     return (
66 |         torch.stack(joint_embs),
67 |         torch.stack(attention_masks),
68 |         torch.stack(position_ids),
69 |         torch.stack(kb_masks),
70 |     )
71 | 
72 | 
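# Illustrative usage sketch (not part of the repository sources): one KB embedding of
# length 3 and one query embedding of length 2, hidden size 4, so no padding is needed.
import torch
from kblam.utils.train_utils import preprocess_embds

embs, attn, pos, kb_mask = preprocess_embds([torch.randn(3, 4)], [torch.randn(2, 4)])
# embs.shape == (1, 5, 4)
# attn[0]    == tensor([1., 1., 1., 1., 1.])    # KB part uses kb_mask_val (default 1)
# pos[0]     == tensor([-1., -1., -1., 0., 1.]) # KB/padding positions are -1
# kb_mask[0] == tensor([1., 1., 1., 0., 0.])
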
73 | def kb_to_embd(kb_encoder, kb_dict=None, precomputed_base_embd=None):
74 |     if isinstance(kb_encoder, DistributedDataParallel):
75 |         kb_encoder = kb_encoder.module
76 |     key_embds, value_embds = [], []
77 |     if precomputed_base_embd is not None:
78 |         for key_base_embd, value_base_embd in zip(*precomputed_base_embd):
79 |             key_embds.append(kb_encoder.encode_key(base_emb=key_base_embd))
80 |             value_embds.append(kb_encoder.encode_val(base_emb=value_base_embd))
81 |     else:
82 |         for entity in kb_dict:
83 |             key_embds.append(kb_encoder.encode_key(S=entity["key_string"]))
84 |             value_embds.append(kb_encoder.encode_val(S=entity["description"]))
85 |     return (torch.stack(key_embds), torch.stack(value_embds))
86 | 
87 | 
88 | def get_kb_embd(
89 |     kb_encoder: torch.nn.Module,
90 |     indices: np.ndarray,
91 |     kb_dict: list[dict] | None = None,
92 |     precomputed_embd: tuple[torch.Tensor, torch.Tensor] | None = None,
93 | ) -> tuple[torch.Tensor, torch.Tensor]:
94 |     if precomputed_embd:
95 |         key_embds, value_embds = precomputed_embd
96 |         train_set_key, train_set_val = kb_to_embd(
97 |             kb_encoder,
98 |             precomputed_base_embd=np.stack([key_embds[indices], value_embds[indices]]),
99 |         )
100 |     elif kb_dict:
101 |         if len(indices.shape) == 2:
102 |             # Sampling batch of multi entities
103 |             train_set_key, train_set_val = [], []
104 |             for indices_row in indices.T:
105 |                 _train_set_key, _train_set_val = kb_to_embd(kb_encoder, kb_dict=[kb_dict[i] for i in indices_row])
106 |                 train_set_key.append(_train_set_key)
107 |                 train_set_val.append(_train_set_val)
108 |             train_set_key = torch.stack(train_set_key, 1)
109 |             train_set_val = torch.stack(train_set_val, 1)
110 |         elif len(indices.shape) == 1:
111 |             train_set_key, train_set_val = kb_to_embd(kb_encoder, kb_dict=[kb_dict[i] for i in indices])
112 |     return train_set_key, train_set_val
113 | 
114 | 
115 | def weighted_nll(model, input_ids, attention_mask, labels, kb=None):
116 |     out = model(
117 |         input_ids=input_ids,
118 |         attention_mask=attention_mask,
119 |         kb_kv=kb,
120 |         output_attentions=True,
121 |     )
122 |     logits = out["logits"]
123 |     shift_logits = logits[..., :-1, :].contiguous()
124 |     shift_labels = labels[..., 1:].contiguous()
125 |     weights = (shift_labels > 0).sum(-1, keepdim=True).expand(-1, shift_labels.shape[1]).contiguous()
126 |     shift_logits = shift_logits.view(-1, model.config.vocab_size)
127 |     shift_labels = shift_labels.view(-1)
128 |     weights = weights.view(-1)
129 |     loss_fct = CrossEntropyLoss(reduction="none")
130 |     shift_labels = shift_labels.to(shift_logits.device)
131 |     loss = (loss_fct(shift_logits, shift_labels) * weights.max() / weights).mean()
132 |     return loss
133 | 
134 | 
135 | def compute_perplexity_gain(model, kb, input_ids, attention_mask, labels):
136 |     with torch.autograd.no_grad():
137 |         unconditioned_nll = weighted_nll(model, input_ids, attention_mask, labels, kb=None)
138 |         conditioned_nll = weighted_nll(model, input_ids, attention_mask, labels, kb)
139 |     return unconditioned_nll, conditioned_nll  # Loss should decrease
140 | 
141 | 
142 | def context_set_size_scheduler(curr_step: int, kb_size: list[int] | int | str | None) -> int:
143 |     """Determines the KB size for the current training step.
144 |     The KB size can be a fixed number, a list of numbers or a "dynamic" value.
145 |     If no KB size is provided, the KB size is dynamically increased every 100 steps."""
146 | 
147 |     dynamic_range = (10, 200)
148 |     if kb_size == "dynamic":
149 |         return np.random.randint(dynamic_range[0], dynamic_range[1])
150 | 
151 |     if isinstance(kb_size, list):
152 |         return np.random.randint(kb_size[0], kb_size[1])
153 | 
154 |     increase_kb_size_every = 100
155 |     if not kb_size:
156 |         round = curr_step // increase_kb_size_every
157 |         return 4 * (round + 1)
158 | 
159 |     return kb_size
160 | 
161 | 
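# Illustrative usage sketch (not part of the repository sources): the different scheduling modes.
from kblam.utils.train_utils import context_set_size_scheduler

context_set_size_scheduler(250, 100)        # fixed size              -> 100
context_set_size_scheduler(250, [10, 50])   # resampled every step    -> random int in [10, 50)
context_set_size_scheduler(250, "dynamic")  # resampled every step    -> random int in [10, 200)
context_set_size_scheduler(250, None)       # grows every 100 steps   -> 12 at step 250
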
162 | def get_prefix_str(args: argparse.Namespace) -> str:
163 |     kb_size = args.kb_size
164 |     if kb_size == -1:
165 |         kb_size = None  # Progressively increase size
166 |     elif kb_size == 0:
167 |         kb_size = "dynamic"  # Random size
168 | 
169 |     prefix_string = f"stage1_lr_{args.lr}"
170 |     if args.kb_token_layer_frequency is not None:
171 |         prefix_string += f"KBTokenLayerFreq{args.kb_token_layer_frequency}"
172 |     if args.use_extended_qa:
173 |         prefix_string += "UseExtendedQA"
174 |     if args.multi_entities is not None:
175 |         prefix_string += f"MultiEntities{args.multi_entities}"
176 |     if args.outlier_num > 0:
177 |         prefix_string += f"UseOutlier{args.outlier_num}"
178 |     if args.length_invariance:
179 |         prefix_string += "LengthInvariant"
180 |     if kb_size is not None:
181 |         prefix_string += f"KBSize{kb_size}"
182 |     if args.sep_query_head:
183 |         prefix_string += "SepQueryHead"
184 |     if args.use_data_aug:
185 |         prefix_string += "UseDataAug"
186 |     return prefix_string
187 | 
188 | 
189 | def setup_scheduler_and_optimizer(model_parameters: ParamsT, lr: float, max_iter: int) -> tuple:
190 |     optim = torch.optim.AdamW(model_parameters, lr=lr, weight_decay=0.0)  # type: ignore
191 | 
192 |     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, max_iter, eta_min=lr * 0.01)  # type: ignore
193 |     return scheduler, optim
194 | 
--------------------------------------------------------------------------------
/tests/sample_data.json:
--------------------------------------------------------------------------------
1 | [{"name": "interest-based negotiation workshop", "entity_type": "workshop", "description_type": "purpose", "description": "introduces and reinforces interest-based negotiation skills", "Q": "What is the purpose of interest-based negotiation workshop?", "A": "introduces and reinforces interest-based negotiation skills"}, {"name": "Meeting", "entity_type": "recurring meeting", "description_type": "purpose", "description": "Discuss recent IT issues", "Q": "What is the purpose of Meeting?", "A": "Discuss recent IT issues"}, {"name": "PDCI", "entity_type": "product", "description_type": "definition", "description": "TTC (Time Change Capacity) values for power distribution", "Q": "What is the definition of PDCI?", "A": "TTC (Time Change Capacity) values for power distribution"}, {"name": "OMAR-Online Market Simulation", "entity_type": "project", "description_type": "purpose", "description": "Conduct a market simulation to enhance features of the MDAS-Online application", "Q": "What is the purpose of OMAR-Online Market Simulation?", "A": "Conduct a market simulation to enhance features of the MDAS-Online application"}]
--------------------------------------------------------------------------------
/tests/test_dataset.json:
--------------------------------------------------------------------------------
1 | {"name": "Sigma Learning Hub", "description_type": "objectives", "description": "Our goal?
Build a rad community of culinary lovers, spark some tasty creativity, and get people swapping cooking tips and tricks.", "Q": "What is the objectives of Sigma Learning Hub?", "A": "The objectives of Sigma Learning Hub is Our goal? Build a rad community of culinary lovers, spark some tasty creativity, and get people swapping cooking tips and tricks..", "key_string": "the objectives of Sigma Learning Hub"} 2 | {"name": "DeltaPrime Technologies", "description_type": "objectives", "description": "To cook up new, green farming techniques and seriously boost crop yields.", "Q": "What is the objectives of DeltaPrime Technologies?", "A": "The objectives of DeltaPrime Technologies is To cook up new, green farming techniques and seriously boost crop yields..", "key_string": "the objectives of DeltaPrime Technologies"} 3 | -------------------------------------------------------------------------------- /tests/test_dataset_construction.py: -------------------------------------------------------------------------------- 1 | from kblam.utils.data_utils import load_entities 2 | 3 | 4 | def test_dataset_QA(): 5 | dataset_path = "tests/test_dataset.json" 6 | dataset = load_entities(dataset_path) 7 | 8 | assert len(dataset) == 2 9 | assert isinstance(dataset[0], dict) 10 | assert isinstance(dataset[1], dict) 11 | -------------------------------------------------------------------------------- /tests/test_kb_encoder.py: -------------------------------------------------------------------------------- 1 | import json 2 | from kblam.kb_encoder import KBEncoder 3 | 4 | 5 | def test_kb_encoder(): 6 | dataset = json.load(open("tests/sample_data.json")) 7 | all_keys = [row["name"] for row in dataset] 8 | all_values = [row["description"] for row in dataset] 9 | 10 | out_dim = 3072 11 | proj_kwargs = {"mlp_depth": 2, "mlp_hidden_dim": 512} 12 | 13 | kb_encoder = KBEncoder( 14 | "all-MiniLM-L6-v2", "mlp", out_dim, None, proj_kwargs, device="cpu" 15 | ) 16 | 17 | assert ( 18 | kb_encoder.encode_key(all_keys[0]).shape 19 | == kb_encoder.encode_val(all_values[0]).shape 20 | ) 21 | 22 | for k, v in kb_encoder.named_parameters(): 23 | if v.requires_grad: 24 | assert ("projector" in k) or ("embedding" in k) 25 | --------------------------------------------------------------------------------