├── .github └── workflows │ ├── headers.yml │ └── style.yml ├── .gitignore ├── .gitmodules ├── .licenserc.yaml ├── .pre-commit-config.yaml ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── gh-pages ├── README.md ├── _config.yaml └── workshop │ └── COLM_2025 │ ├── README.md │ └── ram2.jpeg ├── projects ├── README.md ├── cocomix │ ├── README.md │ ├── cocomix.png │ ├── conf │ │ ├── config.yaml │ │ ├── config_eval.yaml │ │ ├── ddp.yaml │ │ ├── fsdp_bf16.yaml │ │ └── setup │ │ │ ├── gpt2_1b_cocomix.yaml │ │ │ ├── gpt2_1b_ntp.yaml │ │ │ ├── gpt2_386m_cocomix.yaml │ │ │ ├── gpt2_386m_ntp.yaml │ │ │ ├── gpt2_69m_cocomix.yaml │ │ │ └── gpt2_69m_ntp.yaml │ ├── data │ │ ├── __init__.py │ │ ├── data.py │ │ └── openwebtext_preprocess │ │ │ ├── prepare.py │ │ │ └── readme.md │ ├── main.py │ ├── models │ │ ├── __init__.py │ │ ├── concept_extractor.py │ │ ├── modeling_gpt2_cocomix.py │ │ └── sparse_autoencoder │ │ │ ├── __init__.py │ │ │ ├── kernels.py │ │ │ ├── loss.py │ │ │ ├── model.py │ │ │ └── paths.py │ ├── requirements.txt │ ├── slurm_bash │ │ └── slurm_multi.sh │ ├── test.py │ ├── train │ │ ├── __init__.py │ │ ├── train_func │ │ │ ├── __init__.py │ │ │ ├── cocomix.py │ │ │ └── ntp.py │ │ └── trainer.py │ └── utils.py ├── cope │ ├── README.md │ ├── eval.py │ ├── figures │ │ ├── CoPE.png │ │ └── counting_task.png │ ├── requirements.txt │ ├── run.sh │ ├── scripts │ │ └── count_data_gen.py │ ├── src │ │ ├── cope │ │ │ └── context_position.py │ │ ├── data │ │ │ ├── __init__.py │ │ │ ├── constants.py │ │ │ ├── data_collator.py │ │ │ ├── simple.py │ │ │ └── tokenizer.py │ │ ├── main.py │ │ ├── models │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── relative_position.py │ │ │ ├── simple_transformer.py │ │ │ └── transformer.py │ │ ├── trainer.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── checkpoint.py │ │ │ ├── distributed.py │ │ │ ├── logger.py │ │ │ └── world.py │ └── train.py ├── length_instruct │ ├── README.md │ └── fig.png ├── mta │ ├── README.md │ ├── __init__.py │ ├── configs │ │ ├── 1B_baseline.yaml │ │ ├── 1B_mta.yaml │ │ ├── 1B_talking_heads.yaml │ │ ├── 300M_baseline.yaml │ │ ├── 300M_mta.yaml │ │ ├── 300M_talking_heads.yaml │ │ ├── 550M_baseline.yaml │ │ ├── 550M_mta.yaml │ │ ├── 550M_talking_heads.yaml │ │ ├── 830M_baseline.yaml │ │ ├── 830M_mta.yaml │ │ └── 830M_talking_heads.yaml │ ├── data.py │ ├── eval.py │ ├── figures │ │ └── attn_schema.png │ ├── generate.py │ ├── mta_transformer.py │ ├── single_json.py │ ├── tokenizer.py │ ├── train.py │ └── transformer.py ├── sd-ra-it │ ├── README.md │ ├── configs │ │ ├── dpo_70b.yaml │ │ ├── dpo_8b.yaml │ │ ├── sft_70b.yaml │ │ └── sft_8b.yaml │ └── scripts │ │ ├── create_self_demo_train_set.sh │ │ ├── data │ │ └── io_to_qas_format.py │ │ ├── eval.py │ │ ├── generate.py │ │ ├── get_demos.py │ │ ├── prompt_optimization.py │ │ ├── relevance.py │ │ └── reward_model_gemma.py ├── self_notes │ ├── README.md │ ├── fig_method.png │ ├── fig_rel.png │ └── toy_story │ │ ├── constants.py │ │ ├── graph.py │ │ ├── main.py │ │ ├── relation.py │ │ ├── rules.py │ │ └── world.py └── self_taught_evaluator │ ├── README.md │ ├── data │ ├── prompts │ │ ├── eval_plan.prompt │ │ └── worse_response.prompt │ └── training_data.yaml │ ├── figures │ └── self_taught_dpo.png │ ├── run_inference_wvllm.sh │ ├── run_rewardbench.sh │ ├── src │ ├── __init__.py │ ├── load_dpo_data_from_hf.py │ ├── prepare_dpo_data.py │ ├── prepare_sft_data.py │ ├── requirements.txt │ ├── run_model.py │ └── utils.py │ └── training_configs │ ├── dpo_training.yaml │ └── 
sft_training.yaml ├── ram ├── __init__.py ├── data.py └── data_utils.py └── setup.py /.github/workflows/headers.yml: -------------------------------------------------------------------------------- 1 | name: headers 2 | run-name: License headers check by ${{ github.actor }} 3 | 4 | on: 5 | # Trigger the workflow on push to main or any pull request 6 | push: 7 | branches: 8 | - main 9 | pull_request: 10 | 11 | jobs: 12 | header-checks: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - name: Checkout code 16 | uses: actions/checkout@v4 17 | - name: Check License Header 18 | uses: apache/skywalking-eyes/header@main 19 | -------------------------------------------------------------------------------- /.github/workflows/style.yml: -------------------------------------------------------------------------------- 1 | name: style 2 | run-name: Style test by ${{ github.actor }} 3 | 4 | on: 5 | # Trigger the workflow on push to main or any pull request 6 | push: 7 | branches: 8 | - main 9 | pull_request: 10 | 11 | jobs: 12 | style-tests: 13 | runs-on: ubuntu-latest 14 | strategy: 15 | matrix: 16 | python-version: ["3.8", "3.12"] 17 | steps: 18 | - uses: actions/checkout@v4 19 | - name: Set up Python 20 | uses: actions/setup-python@v4 21 | with: 22 | python-version: ${{ matrix.python-version }} 23 | - name: Black check 24 | run: | 25 | pip install black==24.8.0 26 | black --diff --check projects 27 | 28 | - name: isort check 29 | run: | 30 | pip install isort==5.13.2 31 | isort --profile black --check projects 32 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ### Python ### 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | pip-wheel-metadata/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | cover/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | .pybuilder/ 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | # For a library or package, you might want to ignore these files since the code is 89 | # intended to run in multiple environments; otherwise, check them in: 90 | # .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # poetry 100 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 101 | # This is especially recommended for binary packages to ensure reproducibility, and is more 102 | # commonly ignored for libraries. 103 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 104 | #poetry.lock 105 | 106 | # pdm 107 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 108 | #pdm.lock 109 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 110 | # in version control. 111 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 112 | .pdm.toml 113 | .pdm-python 114 | .pdm-build/ 115 | 116 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 117 | __pypackages__/ 118 | 119 | # Celery stuff 120 | celerybeat-schedule 121 | celerybeat.pid 122 | 123 | # SageMath parsed files 124 | *.sage.py 125 | 126 | # Environments 127 | .env 128 | .venv 129 | env/ 130 | venv/ 131 | ENV/ 132 | env.bak/ 133 | venv.bak/ 134 | 135 | # Spyder project settings 136 | .spyderproject 137 | .spyproject 138 | 139 | # Rope project settings 140 | .ropeproject 141 | 142 | # mkdocs documentation 143 | /site 144 | 145 | # mypy 146 | .mypy_cache/ 147 | .dmypy.json 148 | dmypy.json 149 | 150 | # Pyre type checker 151 | .pyre/ 152 | 153 | # pytype static type analyzer 154 | .pytype/ 155 | 156 | # Cython debug symbols 157 | cython_debug/ 158 | 159 | # PyCharm 160 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 161 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 162 | # and can be added to the global gitignore or merged into this file. For a more nuclear 163 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 164 | #.idea/ 165 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "public_repos/lingua"] 2 | path = public_repos/lingua 3 | url = git@github.com:facebookresearch/lingua.git 4 | -------------------------------------------------------------------------------- /.licenserc.yaml: -------------------------------------------------------------------------------- 1 | header: 2 | license: 3 | content: | 4 | Copyright (c) Meta Platforms, Inc. and affiliates. 5 | 6 | This source code is licensed under the MIT license found in the 7 | LICENSE file in the root directory of this source tree. 
8 | 9 | paths: 10 | - "**/*.py" 11 | - "**/*.sh" 12 | 13 | comment: never 14 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.5.0 4 | hooks: 5 | - id: trailing-whitespace 6 | - id: end-of-file-fixer 7 | - id: no-commit-to-branch 8 | args: ['--branch', 'main'] 9 | - id: check-added-large-files 10 | args: ['--maxkb=2000'] 11 | - id: check-merge-conflict 12 | - id: detect-aws-credentials 13 | args: ['--allow-missing-credentials'] 14 | - repo: https://github.com/psf/black 15 | rev: 24.8.0 16 | hooks: 17 | - id: black 18 | language_version: python3.10 19 | - repo: https://github.com/PyCQA/isort 20 | rev: 5.13.2 21 | hooks: 22 | - id: isort 23 | args: ['--profile', 'black'] 24 | language_version: python3.10 25 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 
50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to RAM 2 | 3 | While we are seeding this project with an initial set of popular tasks and a few 4 | models and examples, ongoing contributions from the research community are 5 | desired to increase the pool of tasks, models, and baselines. 6 | 7 | ## Deploy pages 8 | 1. Make necessary edits in `gh-pages` directory, commit and send PR. 9 | 2. Push `gh-pages` as a subtree to its branch: `git subtree push --prefix gh-pages origin gh-pages` 10 | 3. The page build should be started automatically. 11 | 12 | 13 | ## Pull Requests 14 | We actively welcome your pull requests. 15 | 16 | 1. Fork the repo and then clone the forked repository. (See this [github guide](https://guides.github.com/activities/forking/) on forking for more info). 17 | **If you have already cloned the repo directly and committed changes, follow the steps in the [section below](#moving-changes-youve-committed-to-a-fork)** 18 | 2. Create your branch from `main`. Set up your environment 19 | and run `pre-commit install` once. 20 | 3. Make your changes 21 | 4. If you've added code that should be tested, [add tests](http://parl.ai/docs/tutorial_tests.html). 22 | 5. If you've changed APIs, update the documentation. 23 | 6. Autoformat and lint your code (`bash autoformat.sh`) 24 | 7. (Optional) Ensure the test suite passes. Run `python -m pytest -m unit`. 25 | 8. If you've added a new dataset, you should also run 26 | `python -m pytest -m data`. Copy-paste the output into a comment in your PR. 27 | 9. If you haven't already, complete the Contributor License Agreement ("CLA"). 28 | 10. 
Link [CircleCI](https://circleci.com/vcs-authorize/) to your github account 29 | if you haven't done so previously (and make sure the CircleCI tests run 30 | successfully on the PR after you push your changes). 31 | 11. Push your changes! 32 | 12. Once the PR is accepted and CI is passing, we will merge the PR for you. 33 | 34 | ### Moving changes you've committed to a fork 35 | 1. Fork the repo 36 | 2. In your local repo, rename your origin remote to upstream 37 | ``` 38 | git remote rename origin upstream 39 | ``` 40 | 3. Point origin to the forked repo (instead of to the original repo) 41 | ``` 42 | git remote add origin git@github... 43 | ``` 44 | 4. Fetch from the new origin 45 | ``` 46 | git fetch origin 47 | ``` 48 | 5. Make your local branch track the remote branch (of the forked repo) 49 | ``` 50 | git branch --set-upstream-to origin/main main 51 | ``` 52 | 53 | ## Contributor License Agreement ("CLA") 54 | In order to accept your pull request, we need you to submit a CLA. You only need 55 | to do this once to work on any of Facebook's open source projects. 56 | 57 | Complete your CLA here: 58 | 59 | ## Issues 60 | We use GitHub issues for general feature discussion, Q&A and public bugs tracking. 61 | Please ensure your description is clear and has sufficient instructions to be able to 62 | reproduce the issue or understand the problem. 63 | 64 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 65 | disclosure of security bugs. In those cases, please go through the process 66 | outlined on that page and do not file a public issue. 67 | 68 | ## Coding Style 69 | We try to follow the PEP style guidelines and encourage you to as well. You 70 | should run the `lint_changed.sh` script before you submit. 71 | 72 | ## License 73 | By contributing to RAM, you agree that your contributions will be licensed 74 | under the LICENSE file in the root directory of this source tree. 75 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Meta Research 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 10 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RAM 2 | ## Introduction 3 | This repository focuses on developing advanced algorithms and methods for RAM (Reasoning, Alignment, Memory). 
4 | 5 | 17 | 18 | ## Projects 19 | Please go to [Projects](projects/README.md) for an up-to-date list of projects released by RAM. 20 | 21 | 22 | 33 | 34 | ## Contributing 35 | Please read CONTRIBUTING.md for details on our code of conduct, and the process for submitting pull requests. 36 | 37 | ## License 38 | This project is licensed under the MIT License - see the `LICENSE` file for details. The license applies to the released data as well. 39 | 40 | ## Contact 41 | RAM is currently maintained by Olga Golovneva, Ilia Kulikov, Janice Lan, Xian Li, Richard Pang, Sainbayar Sukhbaatar, Tianlu Wang, Jason Weston, Jing Xu, Jane Dwivedi-Yu, Ping Yu, Weizhe Yuan. 42 | For any queries, please reach out to [Jing Xu](https://github.com/jxmsML). 43 | -------------------------------------------------------------------------------- /gh-pages/README.md: -------------------------------------------------------------------------------- 1 | # RAM team @ Meta AI 2 | 3 | [Our projects](https://github.com/facebookresearch/RAM/tree/main/projects#readme) 4 | 5 | [RAM 2 Workshop at COLM 2025](https://facebookresearch.github.io/RAM/workshop/COLM_2025) -------------------------------------------------------------------------------- /gh-pages/_config.yaml: -------------------------------------------------------------------------------- 1 | name: RAM @ Meta AI 2 | title: null 3 | -------------------------------------------------------------------------------- /gh-pages/workshop/COLM_2025/README.md: -------------------------------------------------------------------------------- 1 | ![Workshop logo](ram2.jpeg) 2 | 3 | # RAM 2: Reasoning, Attention & Memory – 10 Years On 4 | 5 | Ten years ago... in Montreal 2015, the RAM workshop took place to bring together the burgeoning field covering the "interplay of reasoning, attention and memory", just before Transformers were invented – but when many of the components to get there had just been published and were in place. The workshop included many speakers who are still prominent in pushing these directions today: Yoshua Bengio, Kyunghyun Cho, Jürgen Schmidhuber, Sainbayar Sukhbaatar, Ilya Sutskever, and more. See the historical website for more details. 6 | 7 | Ten years later... we are hosting RAM 2 in the same location in Montreal, with a two-fold purpose. Firstly, as a retrospective and analysis of what has happened in the last 10 years. We are inviting presenters from the first workshop to this end, as well as to add their current perspectives. Hence secondly, and more importantly, we will bring together the field to discuss new trends and future directions for the next 10 years – which is further enabled by inviting new speakers, panelists and poster presenters discussing these fresh ideas. 8 | 9 | Why does this make sense? The RAM topic is as important as ever, and has gone on to dominate the field. 10 | These new directions include: 11 | 12 | - R: New reasoning methods, including both token-based methods and those that use continuous vectors, and how they combine with memory. 13 | 14 | - A: New attention methods that enable better reasoning and use of short and long-term memory. 15 | 16 | - M: Architectural changes to LLMs to improve memory and reasoning capabilities. 17 | 18 | Overall, we highlight that the workshop is most concerned with methods that aim to explore the interplay between these three aspects.
19 | 20 | ## Workshop Event 21 | 22 | * **Location**: Palais des Congrès, Montreal, Canada 23 | * **Date**: October 10, 2025 24 | 25 | ## Call for Papers 26 | 27 | We will host paper submissions on [open review](https://openreview.net/group?id=colmweb.org/COLM/2025/Workshop/RAM2). We invite researchers and practitioners to submit their work to the COLM 2025 Workshop on Reasoning, Attention & Memory 2 (RAM2@COLM25). 28 | 29 | * **Submission Deadline:** June 23, 2025 30 | * **Author Notification Deadline:** July 24, 2025 31 | * **Submission Details:** Submissions should follow the [general guide for COLM conference](https://colmweb.org/cfp.html). Papers can be up to 9 pages (not including references) and have to be anonymized. All submissions must be in PDF format, please use the [LaTeX style files provided by organizers](https://github.com/COLM-org/Template/archive/refs/tags/2025.zip). 32 | * **Dual Submission Policy:** We encourage submissions of novel research and permit submissions currently under review at other conferences or venues. 33 | 34 | 35 | ## Invited speakers 36 | + Yoshua Bengio, Univ. of Montreal 37 | + Kyunghyun Cho, NYU & Prescient Design 38 | + Yejin Choi, Stanford & NVIDIA 39 | + Azalia Mirhoseini, Stanford 40 | + Juergen Schmidhuber, KAUST 41 | + Sainbayar Sukhbaatar, Meta 42 | + Jason Wei, OpenAI 43 | 44 | 45 | ## Organizing Committee 46 | + Ilia Kulikov 47 | + Jason Weston 48 | + Jing XU 49 | + Olga Golovneva 50 | + Swarnadeep Saha 51 | + Marjan Ghazvininejad 52 | + Ping Yu 53 | 54 | 55 | ## Contact us 56 | 57 | Email: pc_ram2 [at] googlegroups.com 58 | -------------------------------------------------------------------------------- /gh-pages/workshop/COLM_2025/ram2.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/RAM/c567dfb73c0bc60c7ca1f114b72cd3bf7d9442bd/gh-pages/workshop/COLM_2025/ram2.jpeg -------------------------------------------------------------------------------- /projects/cocomix/README.md: -------------------------------------------------------------------------------- 1 | # CoCoMix 2 | 3 | Official PyTorch implementation of "LLM Pretraining with Continuous Concepts". 4 | 5 |
 6 | ![CoCoMix overview](cocomix.png) 7 |
 8 | 9 | ## Environment 10 | ``` 11 | conda create -n cocomix python=3.10 -y 12 | conda activate cocomix 13 | 14 | # we have developed/tested CoCoMix on torch 2.3.0+cuda12.1 15 | pip install torch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 --index-url https://download.pytorch.org/whl/cu121 16 | pip install -r requirements.txt 17 | ``` 18 | 19 | ## Code structure 20 | 21 | ``` 22 | Home 23 | |--conf 24 | |--setup 25 | |--gpt2_69m_ntp.yaml # config for gpt2 69m pretraining 20B tokens for next token prediction 26 | |--gpt2_69m_cocomix.yaml # config for gpt2 69m pretraining 20B tokens for cocomix 27 | |--... 28 | |--config.yaml # general config for training 29 | |--config_eval.yaml # general config for evaluation 30 | |--ddp.yaml # config for huggingface accelerate ddp 31 | |--fsdp_bf16.yaml # config for huggingface accelerate fsdp with bf16 32 | |--data 33 | |--data.py # dataset definition / loader 34 | |--model 35 | |--sparse_autoencoder 36 | ... # code for top-k sparse autoencoder 37 | |--__init__.py # Define model loading, concept extractor loading 38 | |--concept_extractor.py # GPT2-124M model with SAE 39 | |--modeling_gpt2_cocomix.py # CoCoMix for GPT2 40 | |--train 41 | |--train_func 42 | |--ntp.py # next token prediction 43 | |--cocomix.py # CoCoMix 44 | |--trainer.py # trainer function defined: optimizer, scheduler, evaluation 45 | |--main.py # main file, define model, define dataset, define trainer 46 | |--test.py # evaluation functions, we use EleutherAI lm-evaluation-harness 47 | |--utils.py # utility functions: loggers 48 | ``` 49 | 50 | ## Preparation and configurations 51 | 52 | **dataset**: 53 | - OpenWebText: run `./data/openwebtext_preprocess/prepare.py`. Readme file `./data/openwebtext_preprocess/readme.md` 54 | - Set `data_dir` in `./conf/config.yaml` (e.g., `./data/openwebtext_preprocess`) 55 | 56 | **WANDB**: To use Weights & Biases (wandb) logging 57 | - Create a wandb account and get your wandb key 58 | - Set `wandb_key` in `./conf/config.yaml` as your wandb key 59 | - `wandb_project` in `./conf/config.yaml` is the name of your wandb project 60 | - `wandb_entity` in `./conf/config.yaml` is your wandb entity name 61 | - Set `wandb_log` as false if you don't want to use wandb logging 62 | 63 | **Concept related**: 64 | - `insert_layer_index`: Which layer to predict concept labels and insert continuous concepts 65 | - `sae_layer_index`: Which layer to extract concepts (from the pretrained model) 66 | - `lam_concept`: concept prediction loss hyperparameter (default: 0.1) 67 | - `concept_dim`: number of concepts on the sparse autoencoder (SAE) latent: pretrained SAE uses 32768 (fixed) 68 | - `concept_num`: number of active concepts (i.e., TopK value of sparse activation) in TopK SAE: pretrained SAE uses 32 (fixed) 69 | 70 | All configurations for next token prediction and cocomix are provided in `./conf/setup/` 71 | 72 | ## Train code 73 | For all experiments, we have used multi-node training. We have provided a slurm job submit example file in `./slurm_bash`. 74 | - Note that the user needs to fill in the details in `./slurm_bash/slurm_multi.sh` to use the slurm file (e.g., account, env_name) 75 | - Currently assuming FSDP (to use DDP, change `--config_file` to `./conf/ddp.yaml`) 76 | 77 | We also provide a single-node training example code (without slurm).\ 78 | If OOM occurs, please increase the gradient accumulation step `grad_acc_steps` and reduce the micro batch size `update_batch_size`.
79 | ``` 80 | # train gpt2 69m on openwebtext with next token prediction 81 | sbatch ./slurm_bash/slurm_multi.sh setup=gpt2_69m_ntp 82 | 83 | # train gpt2 69m on openwebtext with cocomix 84 | sbatch ./slurm_bash/slurm_multi.sh setup=gpt2_69m_cocomix 85 | 86 | # train gpt2 69m on single node with FSDP 87 | accelerate launch --config_file ./conf/fsdp_bf16.yaml --num_processes=8 main.py setup=gpt2_69m_ntp 88 | 89 | # train gpt2 69m on single node with DDP 90 | accelerate launch --config_file ./conf/ddp.yaml --num_processes=8 main.py setup=gpt2_69m_ntp 91 | ``` 92 | 93 | ## Evaluation code 94 | Set `data_dir` in `./conf/config_eval.yaml` with the preprocessed openwebtext dataset path (e.g., `./data/openwebtext_preprocess`).\ 95 | We use [lm-eval-harness](https://github.com/EleutherAI/lm-evaluation-harness) for the evaluation (except for openwebtext validation perplexity). To evaluate on different datasets, please modify `eval_tasks` in `./conf/config_eval.yaml`.\ 96 | Note that `eval_single_ckpt` defines whether to evaluate a single checkpoint or to evaluate all saved checkpoints at a given frequency (e.g., if the user has saved a ckpt every 2000 training steps, setting it to False will evaluate all of those ckpts in one run). 97 | ``` 98 | # two options 99 | # eval_single_ckpt=True or False 100 | 101 | # if True, pass the path including the step (e.g., ./logs/.../step_xxx/), this will only evaluate a single ckpt 102 | # the eval_results.json will be saved in ./logs/.../step_xxx/ 103 | CUDA_VISIBLE_DEVICES=0 python test.py eval_single_ckpt=True load_path= 104 | 105 | # else, pass the path excluding the step (e.g., ./logs/.../), this will evaluate all ckpts with a frequency of eval_freq (e.g., step_2000, step_4000, ...) 106 | # the eval_results.json will be saved in ./logs/.../ 107 | CUDA_VISIBLE_DEVICES=0 python test.py load_path= eval_freq=2000 108 | ``` 109 | -------------------------------------------------------------------------------- /projects/cocomix/cocomix.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/RAM/c567dfb73c0bc60c7ca1f114b72cd3bf7d9442bd/projects/cocomix/cocomix.png -------------------------------------------------------------------------------- /projects/cocomix/conf/config.yaml: -------------------------------------------------------------------------------- 1 | wandb_log: true 2 | wandb_entity: null 3 | wandb_project: null 4 | wandb_key: null 5 | 6 | defaults: 7 | - _self_ 8 | - setup: 'gpt2_69m' 9 | 10 | hydra: 11 | run: 12 | dir: . 13 | 14 | mode: 'ntp' 15 | seed: 22 16 | rank: 0 17 | suffix: null 18 | 19 | # model 20 | base_model: 'openai-community/gpt2' 21 | pretrained_model: 'openai-community/gpt2' 22 | dataset: openwebtext 23 | data_dir: './data/openwebtext_preprocess' # set your data path 24 | n_embd: null 25 | n_layer: null 26 | n_head: null 27 | vocab_size: null 28 | 29 | load_path: null 30 | port: 9819 31 | distributed: False 32 | world_size: 1 33 | use_torch_compile: True 34 | compile_dynamo_cache_size_limit: 256 35 | 36 | # optimization 37 | lr: 6e-4 38 | lr_schedule: 'cosine_with_min_lr' # 'cosine' 'constant_with_warmup' 'constant', 39 | beta1: 0.9 40 | beta2: 0.95 41 | grad_clip_thresh: 1.
42 | warmup_steps: 2000 43 | min_lr: 6e-5 44 | eps: 1e-8 45 | mixed_precision: null 46 | weight_decay: 0.1 47 | train_steps: 600000 # 600k steps 48 | n_epochs: 0 49 | num_workers: 2 50 | 51 | # total batch size = 1024 (context length) * 256 (update_batch_size) * 2 (grad_acc_steps) = 524,288 (~0.5M) 52 | # total number of tokens = train_steps * total batch size = 600k * 0.5M = 300B tokens 53 | update_batch_size: 256 # micro batch size is update_batch_size // num_gpus 54 | grad_acc_steps: 2 55 | block_size: 1024 # context length 56 | dropout: 0.0 57 | bias: False 58 | 59 | log_path: null 60 | use_accelerator: True 61 | 62 | # saving/evaluation/logging frequency 63 | save_step_freq: 10000 64 | eval_step_freq: 1000 65 | log_step_freq: 50 66 | global_step: 0 67 | val_datasets: ['openwebtext'] # measuring ppl 68 | batch_size_eval: 256 69 | eval_limit: 1000 70 | 71 | topK_attri: 4 # TopK for concept label 72 | concept_num: 32 # TopK for SAE activation 73 | concept_dim: 32768 # SAE concept dimension 74 | 75 | # sae 76 | sae_location: 'resid_post_mlp' 77 | insert_layer_index: null # CoCoMix model's layer that predict and insert the concept 78 | sae_layer_index: null # SAE layer that is used for concept extraction 79 | lam_concept: 0.1 80 | -------------------------------------------------------------------------------- /projects/cocomix/conf/config_eval.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | 4 | hydra: 5 | run: 6 | dir: . 7 | 8 | rank: 0 9 | seed: 42 10 | base_model: 'openai-community/gpt2' 11 | data_dir: './data/openwebtext_preprocess' # set your data path 12 | load_path: null 13 | batch_size: 64 14 | eval_freq: 2000 15 | eval_single_ckpt: False 16 | eval_tasks: ['lambada_openai','wikitext','hellaswag','piqa','social_iqa','arc_easy','winogrande'] 17 | save_result: True 18 | -------------------------------------------------------------------------------- /projects/cocomix/conf/ddp.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: MULTI_GPU 4 | downcast_bf16: 'no' 5 | enable_cpu_affinity: true 6 | gpu_ids: all 7 | machine_rank: 0 8 | main_process_ip: null 9 | main_process_port: null 10 | main_training_function: main 11 | mixed_precision: bf16 12 | num_machines: 1 13 | num_processes: 8 14 | rdzv_backend: static 15 | same_network: true 16 | tpu_env: [] 17 | tpu_use_cluster: false 18 | tpu_use_sudo: false 19 | use_cpu: false 20 | -------------------------------------------------------------------------------- /projects/cocomix/conf/fsdp_bf16.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: FSDP 4 | downcast_bf16: 'yes' 5 | fsdp_config: 6 | fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP 7 | fsdp_backward_prefetch: BACKWARD_PRE 8 | fsdp_cpu_ram_efficient_loading: true 9 | fsdp_forward_prefetch: false 10 | fsdp_offload_params: false 11 | fsdp_sharding_strategy: FULL_SHARD 12 | fsdp_state_dict_type: SHARDED_STATE_DICT 13 | fsdp_sync_module_states: true 14 | fsdp_use_orig_params: true # true for torch.compile 15 | machine_rank: 0 16 | main_process_port: 12345 17 | main_training_function: main 18 | mixed_precision: bf16 19 | num_machines: 1 20 | num_processes: 8 21 | rdzv_backend: static 22 | same_network: true 23 | tpu_env: [] 24 | tpu_use_cluster: false 25 | tpu_use_sudo: false 26 | use_cpu: false
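As a quick cross-check of the batch-size and token-budget comments in `conf/config.yaml` and the `conf/setup/*.yaml` files above, here is a minimal sketch (our own helper, not part of the repository) that recomputes those numbers from the config values:

```python
# Minimal sanity check (not part of the repo) for the token-budget comments
# in conf/config.yaml and conf/setup/*.yaml.
def token_budget(block_size, update_batch_size, grad_acc_steps, train_steps):
    tokens_per_step = block_size * update_batch_size * grad_acc_steps
    return tokens_per_step, tokens_per_step * train_steps

# config.yaml defaults: 1024 * 256 * 2 = 524,288 tokens/step; 600k steps ~ 315B (~300B) tokens.
print(token_budget(1024, 256, 2, 600_000))
# gpt2_69m_* / gpt2_386m_* setups: 1024 * 512 * 1 = 524,288 tokens/step; 40k steps ~ 21B (~20B) tokens.
print(token_budget(1024, 512, 1, 40_000))
```

The per-GPU micro batch size is then `update_batch_size // num_gpus`, as noted in the config comments.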
-------------------------------------------------------------------------------- /projects/cocomix/conf/setup/gpt2_1b_cocomix.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | mode: 'cocomix' 4 | n_embd: 2048 5 | n_layer: 24 6 | n_head: 16 7 | compile_dynamo_cache_size_limit: 512 8 | 9 | # optimization 10 | lr: 2e-4 11 | lr_schedule: 'cosine_with_min_lr' # 'cosine' 'constant_with_warmup' 'constant', 12 | beta1: 0.9 13 | beta2: 0.95 14 | grad_clip_thresh: 1. 15 | warmup_steps: 65 16 | min_lr: 2e-5 17 | eps: 1e-8 18 | mixed_precision: null 19 | weight_decay: 0.1 20 | train_steps: 20000 # 20k steps ~ 20B 21 | 22 | # total batch size = 1024 (context length) * 1024 (update_batch_size) * 1 (grad_acc_steps) = (~1.0M) 23 | # total number of tokens = train_steps * total batch size = 20k * 1.0M = 20B tokens 24 | update_batch_size: 1024 # micro batch size is update_batch_size // num_gpus 25 | grad_acc_steps: 1 26 | block_size: 1024 27 | 28 | # saving/evaluation/logging frequency 29 | save_step_freq: 1000 30 | eval_step_freq: 500 31 | log_step_freq: 50 32 | val_datasets: ['openwebtext'] # measuring ppl 33 | batch_size_eval: 256 34 | eval_limit: 1000 35 | 36 | # sae 37 | insert_layer_index: 5 # CoCoMix model's layer that predict and insert the concept 38 | sae_layer_index: 5 # SAE layer that is used for concept extraction 39 | -------------------------------------------------------------------------------- /projects/cocomix/conf/setup/gpt2_1b_ntp.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | mode: 'ntp' 4 | n_embd: 2096 5 | n_layer: 24 6 | n_head: 16 7 | 8 | # optimization 9 | lr: 2e-4 10 | lr_schedule: 'cosine_with_min_lr' # 'cosine' 'constant_with_warmup' 'constant', 11 | beta1: 0.9 12 | beta2: 0.95 13 | grad_clip_thresh: 1. 14 | warmup_steps: 65 15 | min_lr: 2e-5 16 | eps: 1e-8 17 | mixed_precision: null 18 | weight_decay: 0.1 19 | train_steps: 20000 # 20k steps ~ 20B 20 | 21 | # total batch size = 1024 (context length) * 1024 (update_batch_size) * 1 (grad_acc_steps) = (~1.0M) 22 | # total number of tokens = train_steps * total batch size = 20k * 1.0M = 20B tokens 23 | update_batch_size: 1024 # micro batch size is update_batch_size // num_gpus 24 | grad_acc_steps: 1 25 | block_size: 1024 26 | 27 | # saving/evaluation/logging frequency 28 | save_step_freq: 1000 29 | eval_step_freq: 500 30 | log_step_freq: 50 31 | val_datasets: ['openwebtext'] # measuring ppl 32 | batch_size_eval: 256 33 | eval_limit: 1000 34 | -------------------------------------------------------------------------------- /projects/cocomix/conf/setup/gpt2_386m_cocomix.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | mode: 'cocomix' 4 | n_embd: 1024 5 | n_layer: 24 6 | n_head: 16 7 | compile_dynamo_cache_size_limit: 512 8 | 9 | # optimization 10 | lr: 3e-4 11 | lr_schedule: 'cosine_with_min_lr' # 'cosine' 'constant_with_warmup' 'constant', 12 | beta1: 0.9 13 | beta2: 0.95 14 | grad_clip_thresh: 1. 
15 | warmup_steps: 130 16 | min_lr: 3e-5 17 | eps: 1e-8 18 | mixed_precision: null 19 | weight_decay: 0.1 20 | train_steps: 40000 # 40k steps ~ 20B 21 | 22 | # total batch size = 1024 (context length) * 512 (update_batch_size) * 1 (grad_acc_steps) = (~0.5M) 23 | # total number of tokens = train_steps * total batch size = 40k * 0.5M = 20B tokens 24 | update_batch_size: 512 # micro batch size is update_batch_size // num_gpus 25 | grad_acc_steps: 1 26 | block_size: 1024 27 | 28 | # saving/evaluation/logging frequency 29 | save_step_freq: 2000 30 | eval_step_freq: 1000 31 | log_step_freq: 50 32 | val_datasets: ['openwebtext'] # measuring ppl 33 | batch_size_eval: 256 34 | eval_limit: 1000 35 | 36 | # sae 37 | insert_layer_index: 5 # CoCoMix model's layer that predict and insert the concept 38 | sae_layer_index: 5 # SAE layer that is used for concept extraction -------------------------------------------------------------------------------- /projects/cocomix/conf/setup/gpt2_386m_ntp.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | mode: 'ntp' 4 | n_embd: 1072 5 | n_layer: 24 6 | n_head: 16 7 | 8 | # optimization 9 | lr: 3e-4 10 | lr_schedule: 'cosine_with_min_lr' # 'cosine' 'constant_with_warmup' 'constant', 11 | beta1: 0.9 12 | beta2: 0.95 13 | grad_clip_thresh: 1. 14 | warmup_steps: 130 15 | min_lr: 3e-5 16 | eps: 1e-8 17 | mixed_precision: null 18 | weight_decay: 0.1 19 | train_steps: 40000 # 40k steps ~ 20B 20 | 21 | # total batch size = 1024 (context length) * 512 (update_batch_size) * 1 (grad_acc_steps) = (~0.5M) 22 | # total number of tokens = train_steps * total batch size = 40k * 0.5M = 20B tokens 23 | update_batch_size: 512 # micro batch size is update_batch_size // num_gpus 24 | grad_acc_steps: 1 25 | block_size: 1024 26 | 27 | # saving/evaluation/logging frequency 28 | save_step_freq: 2000 29 | eval_step_freq: 1000 30 | log_step_freq: 50 31 | val_datasets: ['openwebtext'] # measuring ppl 32 | batch_size_eval: 256 33 | eval_limit: 1000 34 | -------------------------------------------------------------------------------- /projects/cocomix/conf/setup/gpt2_69m_cocomix.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | mode: 'cocomix' 4 | n_embd: 512 5 | n_layer: 8 6 | n_head: 8 7 | compile_dynamo_cache_size_limit: 512 8 | 9 | # optimization 10 | lr: 6e-4 11 | lr_schedule: 'cosine_with_min_lr' # 'cosine' 'constant_with_warmup' 'constant', 12 | beta1: 0.9 13 | beta2: 0.95 14 | grad_clip_thresh: 1. 
15 | warmup_steps: 130 16 | min_lr: 6e-5 17 | eps: 1e-8 18 | mixed_precision: null 19 | weight_decay: 0.1 20 | train_steps: 40000 # 40k steps ~ 20B 21 | 22 | # total batch size = 1024 (context length) * 512 (update_batch_size) * 1 (grad_acc_steps) = (~0.5M) 23 | # total number of tokens = train_steps * total batch size = 40k * 0.5M = 20B tokens 24 | update_batch_size: 512 # micro batch size is update_batch_size // num_gpus 25 | grad_acc_steps: 1 26 | block_size: 1024 27 | 28 | # saving/evaluation/logging frequency 29 | save_step_freq: 2000 30 | eval_step_freq: 1000 31 | log_step_freq: 50 32 | val_datasets: ['openwebtext'] # measuring ppl 33 | batch_size_eval: 256 34 | eval_limit: 1000 35 | 36 | # sae 37 | insert_layer_index: 3 # CoCoMix model's layer that predict and insert the concept 38 | sae_layer_index: 5 # SAE layer that is used for concept extraction 39 | -------------------------------------------------------------------------------- /projects/cocomix/conf/setup/gpt2_69m_ntp.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | mode: 'ntp' 4 | n_embd: 624 5 | n_layer: 8 6 | n_head: 8 7 | 8 | # optimization 9 | lr: 6e-4 10 | lr_schedule: 'cosine_with_min_lr' # 'cosine' 'constant_with_warmup' 'constant', 11 | beta1: 0.9 12 | beta2: 0.95 13 | grad_clip_thresh: 1. 14 | warmup_steps: 130 15 | min_lr: 6e-5 16 | eps: 1e-8 17 | mixed_precision: null 18 | weight_decay: 0.1 19 | train_steps: 40000 # 40k steps ~ 20B 20 | 21 | # total batch size = 1024 (context length) * 512 (update_batch_size) * 1 (grad_acc_steps) = (~0.5M) 22 | # total number of tokens = train_steps * total batch size = 40k * 0.5M = 20B tokens 23 | update_batch_size: 512 # micro batch size is update_batch_size // num_gpus 24 | grad_acc_steps: 1 25 | block_size: 1024 26 | 27 | # saving/evaluation/logging frequency 28 | save_step_freq: 2000 29 | eval_step_freq: 1000 30 | log_step_freq: 50 31 | val_datasets: ['openwebtext'] # measuring ppl 32 | batch_size_eval: 256 33 | eval_limit: 1000 34 | -------------------------------------------------------------------------------- /projects/cocomix/data/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | -------------------------------------------------------------------------------- /projects/cocomix/data/data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | """ 7 | 8 | import os 9 | import random 10 | 11 | import numpy as np 12 | import torch 13 | from torch.utils.data import DataLoader, Dataset 14 | 15 | 16 | class PreprocessedDataset(Dataset): 17 | def __init__(self, data_dir, block_size=1024, split="train"): 18 | self.data = np.memmap( 19 | os.path.join(data_dir, f"{split}.bin"), dtype=np.uint16, mode="r" 20 | ) 21 | self.block_size = block_size 22 | self.split = split 23 | 24 | self.data_len = len(self.data) // self.block_size # drop last block 25 | self.remain = len(self.data) % self.block_size 26 | 27 | def __len__(self): 28 | return self.data_len 29 | 30 | def __getitem__(self, idx): 31 | if self.split == "train": 32 | if idx < self.data_len - 1: 33 | random_shift = random.randint(0, self.block_size) 34 | else: 35 | random_shift = random.randint(0, self.remain) 36 | x = torch.from_numpy( 37 | ( 38 | self.data[ 39 | idx * self.block_size 40 | + random_shift : (idx + 1) * self.block_size 41 | + random_shift 42 | ] 43 | ).astype(np.int64) 44 | ) 45 | else: 46 | x = torch.from_numpy( 47 | (self.data[idx * self.block_size : (idx + 1) * self.block_size]).astype( 48 | np.int64 49 | ) 50 | ) 51 | attention_mask = torch.ones_like(x) 52 | return {"input_ids": x, "labels": x, "attention_mask": attention_mask} 53 | 54 | 55 | def get_train_dataloader(cfg): 56 | kwargs = {} 57 | if cfg.dataset == "openwebtext": 58 | train_dataset = PreprocessedDataset( 59 | cfg.data_dir, block_size=cfg.block_size, split="train" 60 | ) 61 | else: 62 | print(f"dataset [{cfg.dataset}] not supported for evaluation") 63 | raise NotImplementedError 64 | 65 | batch_size = cfg.update_batch_size // cfg.world_size 66 | cfg.n_epochs = ( 67 | cfg.train_steps 68 | * cfg.update_batch_size 69 | * cfg.grad_acc_steps 70 | // len(train_dataset) 71 | + 1 72 | ) 73 | train_dataloader = DataLoader( 74 | train_dataset, 75 | shuffle=True, 76 | batch_size=batch_size, 77 | pin_memory=True, 78 | num_workers=cfg.num_workers, 79 | **kwargs, 80 | ) 81 | return train_dataloader 82 | 83 | 84 | def get_val_dataloaders(cfg): 85 | if cfg.val_datasets is None: 86 | return None 87 | 88 | val_dataloaders = {} 89 | for val_name in cfg.val_datasets: 90 | 91 | kwargs = {} 92 | if val_name == "openwebtext": 93 | val_dataset = PreprocessedDataset( 94 | cfg.data_dir, block_size=cfg.block_size, split="val" 95 | ) 96 | else: 97 | continue 98 | 99 | batch_size = cfg.batch_size_eval // cfg.world_size 100 | val_dataloader = DataLoader( 101 | val_dataset, shuffle=False, batch_size=batch_size, pin_memory=True, **kwargs 102 | ) 103 | val_dataloaders[val_name] = val_dataloader 104 | return val_dataloaders 105 | -------------------------------------------------------------------------------- /projects/cocomix/data/openwebtext_preprocess/prepare.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | # saves the openwebtext dataset to a binary file for training. 
following was helpful: 9 | # https://github.com/HazyResearch/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py 10 | 11 | import os 12 | 13 | import numpy as np 14 | import tiktoken 15 | from datasets import load_dataset # huggingface datasets 16 | from tqdm import tqdm 17 | 18 | # number of workers in .map() call 19 | # good number to use is ~order number of cpu cores // 2 20 | num_proc = 32 21 | 22 | # number of workers in load_dataset() call 23 | # best number might be different from num_proc above as it also depends on NW speed. 24 | # it is better than 1 usually though 25 | num_proc_load_dataset = num_proc 26 | 27 | enc = tiktoken.get_encoding("gpt2") 28 | 29 | if __name__ == "__main__": 30 | # takes 54GB in huggingface .cache dir, about 8M documents (8,013,769) 31 | dataset = load_dataset("openwebtext", num_proc=num_proc_load_dataset) 32 | 33 | # owt by default only contains the 'train' split, so create a test split 34 | split_dataset = dataset["train"].train_test_split( 35 | test_size=0.0005, seed=2357, shuffle=True 36 | ) 37 | split_dataset["val"] = split_dataset.pop("test") # rename the test split to val 38 | 39 | # this results in: 40 | # >>> split_dataset 41 | # DatasetDict({ 42 | # train: Dataset({ 43 | # features: ['text'], 44 | # num_rows: 8009762 45 | # }) 46 | # val: Dataset({ 47 | # features: ['text'], 48 | # num_rows: 4007 49 | # }) 50 | # }) 51 | 52 | # we now want to tokenize the dataset. first define the encoding function (gpt2 bpe) 53 | def process(example): 54 | ids = enc.encode_ordinary( 55 | example["text"] 56 | ) # encode_ordinary ignores any special tokens 57 | ids.append(enc.eot_token) # add the end of text token, e.g. 50256 for gpt2 bpe 58 | # note: I think eot should be prepended not appended... hmm. it's called "eot" though... 59 | out = {"ids": ids, "len": len(ids)} 60 | return out 61 | 62 | # tokenize the dataset 63 | tokenized = split_dataset.map( 64 | process, 65 | remove_columns=["text"], 66 | desc="tokenizing the splits", 67 | num_proc=num_proc, 68 | ) 69 | 70 | # concatenate all the ids in each dataset into one large file we can use for training 71 | for split, dset in tokenized.items(): 72 | arr_len = np.sum(dset["len"], dtype=np.uint64) 73 | filename = os.path.join(os.path.dirname(__file__), f"{split}.bin") 74 | dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16) 75 | arr = np.memmap(filename, dtype=dtype, mode="w+", shape=(arr_len,)) 76 | total_batches = 1024 77 | 78 | idx = 0 79 | for batch_idx in tqdm(range(total_batches), desc=f"writing {filename}"): 80 | # Batch together samples for faster write 81 | batch = dset.shard( 82 | num_shards=total_batches, index=batch_idx, contiguous=True 83 | ).with_format("numpy") 84 | arr_batch = np.concatenate(batch["ids"]) 85 | # Write into mmap 86 | arr[idx : idx + len(arr_batch)] = arr_batch 87 | idx += len(arr_batch) 88 | arr.flush() 89 | 90 | # train.bin is ~17GB, val.bin ~8.5MB 91 | # train has ~9B tokens (9,035,582,198) 92 | # val has ~4M tokens (4,434,897) 93 | 94 | # to read the bin files later, e.g. 
with numpy: 95 | # m = np.memmap('train.bin', dtype=np.uint16, mode='r') 96 | -------------------------------------------------------------------------------- /projects/cocomix/data/openwebtext_preprocess/readme.md: -------------------------------------------------------------------------------- 1 | 2 | ## openwebtext dataset 3 | 4 | after running `prepare.py` (preprocess) we get: 5 | 6 | - train.bin is ~17GB, val.bin ~8.5MB 7 | - train has ~9B tokens (9,035,582,198) 8 | - val has ~4M tokens (4,434,897) 9 | 10 | this came from 8,013,769 documents in total. 11 | 12 | references: 13 | 14 | - OpenAI's WebText dataset is discussed in [GPT-2 paper](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) 15 | - [OpenWebText](https://skylion007.github.io/OpenWebTextCorpus/) dataset 16 | -------------------------------------------------------------------------------- /projects/cocomix/main.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import hydra 9 | import torch 10 | import torch._dynamo 11 | import torch.distributed as dist 12 | from accelerate import Accelerator 13 | from data.data import get_train_dataloader, get_val_dataloaders 14 | from models import get_base_lm, get_concept_extractor 15 | from omegaconf import OmegaConf 16 | from train import setup as train_setup 17 | from train.trainer import trainer 18 | from utils import Logger, set_random_seed 19 | 20 | 21 | @hydra.main(config_path="conf", config_name="config", version_base="1.3.2") 22 | def main(cfg): 23 | """Use huggingface accelerator (automatically use distributed)""" 24 | accelerator = Accelerator( 25 | gradient_accumulation_steps=cfg.grad_acc_steps, 26 | ) 27 | accelerator.wait_for_everyone() 28 | 29 | """ distributed related config """ 30 | num_gpus = dist.get_world_size() 31 | cfg.distributed = num_gpus > 1 32 | cfg.world_size = num_gpus 33 | 34 | """ fixing randomness """ 35 | set_random_seed(cfg.seed) 36 | torch.backends.cudnn.deterministic = True 37 | torch.backends.cudnn.benchmark = True 38 | torch.backends.cuda.matmul.allow_tf32 = True 39 | torch.backends.cudnn.allow_tf32 = True 40 | 41 | """ if torch compile""" 42 | if cfg.use_torch_compile: 43 | torch._dynamo.config.cache_size_limit = cfg.compile_dynamo_cache_size_limit 44 | 45 | """ define dataset, data loader, and tokenizer """ 46 | train_loader = get_train_dataloader(cfg) 47 | val_loaders = get_val_dataloaders(cfg) 48 | 49 | """ define concept_extractor """ 50 | concept_extractor = get_concept_extractor(cfg, accelerator) 51 | 52 | """ define base model """ 53 | base_lm = get_base_lm(cfg, accelerator) 54 | 55 | """ define train and test type """ 56 | train_func, fname, wandb_name = train_setup(cfg.mode, cfg) 57 | 58 | """ define logger """ 59 | logger = Logger( 60 | fname, 61 | cfg, 62 | main_process=accelerator.is_main_process, 63 | use_wandb=cfg.wandb_log, 64 | wandb_name=wandb_name, 65 | log_path=cfg.log_path, 66 | ) 67 | logger.log(OmegaConf.to_yaml(cfg)) 68 | 69 | """ train """ 70 | trainer( 71 | cfg, 72 | train_func, 73 | base_lm, 74 | train_loader, 75 | val_loaders, 76 | logger, 77 | accelerator, 78 | concept_extractor, 79 | ) 80 | 81 | """ close tensorboard """ 82 | logger.close_writer() 83 | 84 | 85 | if __name__ == "__main__": 86 | main() 87 | 
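The OpenWebText preprocessing covered above (`prepare.py` and its `readme.md`) ends with a hint about reading the `.bin` shards with `np.memmap`. Here is a minimal sketch (our own, not part of the repository) that spot-checks a preprocessed shard by decoding a few tokens back to text; the path is an example and should point at your own `data_dir`:

```python
# Spot-check a preprocessed OpenWebText shard (assumes prepare.py has been run;
# the path below is an example -- adjust it to your data_dir).
import numpy as np
import tiktoken

enc = tiktoken.get_encoding("gpt2")
data = np.memmap("./data/openwebtext_preprocess/val.bin", dtype=np.uint16, mode="r")

print(f"val split: {len(data):,} tokens")  # readme.md above reports ~4.4M val tokens
first_block = data[:128].astype(np.int64).tolist()
print(enc.decode(first_block))             # should print readable English text
```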
-------------------------------------------------------------------------------- /projects/cocomix/models/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import os 9 | 10 | import torch 11 | from models.concept_extractor import TransformerLensSAE 12 | from models.modeling_gpt2_cocomix import GPT2CoCoMixLMHeadModel 13 | from transformers import AutoConfig, GPT2LMHeadModel 14 | 15 | 16 | def get_base_lm(cfg, accelerator): 17 | """define base model""" 18 | config = AutoConfig.from_pretrained(cfg.base_model) 19 | if cfg.vocab_size is not None: 20 | config.vocab_size = cfg.vocab_size 21 | if "gpt2" in cfg.base_model: 22 | if cfg.n_embd is not None: 23 | config.n_embd = cfg.n_embd 24 | if cfg.n_layer is not None: 25 | config.n_layer = cfg.n_layer 26 | if cfg.n_head is not None: 27 | config.n_head = cfg.n_head 28 | if cfg.mode == "cocomix": 29 | config._attn_implementation = "flash_attention_2" 30 | base_lm = GPT2CoCoMixLMHeadModel( 31 | config, cfg.concept_dim, cfg.insert_layer_index, cfg.concept_num 32 | ) 33 | else: # just next token prediction 34 | config._attn_implementation = "sdpa" 35 | base_lm = GPT2LMHeadModel(config) 36 | else: 37 | raise NotImplementedError 38 | 39 | base_lm = accelerator.prepare(base_lm) # Accelerate does FSDP, DDP, etc. 40 | 41 | if cfg.use_torch_compile: 42 | base_lm = torch.compile(base_lm) 43 | 44 | return base_lm 45 | 46 | 47 | def get_concept_extractor(cfg, accelerator): 48 | concept_extractor = None 49 | if cfg.mode in ["cocomix"]: 50 | if "gpt2" in cfg.pretrained_model: 51 | concept_extractor = TransformerLensSAE( 52 | layer_index=cfg.sae_layer_index, location=cfg.sae_location 53 | ) 54 | ddp_local_rank = int(os.environ["LOCAL_RANK"]) 55 | local_device = f"cuda:{ddp_local_rank}" 56 | concept_extractor = concept_extractor.to(local_device) 57 | concept_extractor.base_model = concept_extractor.base_model.to(local_device) 58 | concept_extractor.autoencoder = concept_extractor.autoencoder.to( 59 | local_device 60 | ) 61 | else: 62 | raise NotImplementedError 63 | 64 | if cfg.use_torch_compile and concept_extractor is not None: 65 | concept_extractor = torch.compile(concept_extractor) 66 | 67 | return concept_extractor 68 | -------------------------------------------------------------------------------- /projects/cocomix/models/concept_extractor.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | """ 9 | Full definition of a GPT Language Model, all of it in this single file. 
10 | References: 11 | 1) the official GPT-2 TensorFlow implementation released by OpenAI: 12 | https://github.com/openai/gpt-2/blob/master/src/model.py 13 | 2) huggingface/transformers PyTorch implementation: 14 | https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py 15 | """ 16 | 17 | from dataclasses import dataclass 18 | 19 | import blobfile as bf 20 | import models.sparse_autoencoder as sparse_autoencoder 21 | import torch 22 | import torch.nn as nn 23 | import transformer_lens 24 | from torch.nn import CrossEntropyLoss 25 | from transformer_lens import ActivationCache 26 | 27 | 28 | class TransformerLensSAE(nn.Module): 29 | def __init__(self, layer_index=6, location="resid_post_mlp"): 30 | super().__init__() 31 | 32 | # define sparse autoencoder 33 | self.layer_index = layer_index 34 | base_model = transformer_lens.HookedTransformer.from_pretrained("gpt2") 35 | 36 | self.base_model = base_model 37 | self.base_model.eval() 38 | 39 | self.transformer_lens_loc = { 40 | "mlp_post_act": f"blocks.{layer_index}.mlp.hook_post", 41 | "resid_delta_attn": f"blocks.{layer_index}.hook_attn_out", 42 | "resid_post_attn": f"blocks.{layer_index}.hook_resid_mid", 43 | "resid_delta_mlp": f"blocks.{layer_index}.hook_mlp_out", 44 | "resid_post_mlp": f"blocks.{layer_index}.hook_resid_post", 45 | }[location] 46 | 47 | with bf.BlobFile( 48 | sparse_autoencoder.paths.v5_32k(location, layer_index), mode="rb" 49 | ) as f: 50 | state_dict = torch.load(f) 51 | autoencoder = sparse_autoencoder.Autoencoder.from_state_dict(state_dict) 52 | 53 | self.autoencoder = autoencoder 54 | 55 | def get_cache_fwd_and_bwd(self, tokens, labels, new_act=None): 56 | # filter_not_qkv_input = lambda name: "_input" not in name 57 | 58 | self.base_model.reset_hooks() 59 | cache = {} 60 | 61 | def forward_cache_hook(act, hook): 62 | if new_act is not None: 63 | cache[hook.name] = new_act.detach() 64 | return new_act # activation patching 65 | cache[hook.name] = act.detach() 66 | 67 | self.base_model.add_hook(self.transformer_lens_loc, forward_cache_hook, "fwd") 68 | 69 | grad_cache = {} 70 | 71 | def backward_cache_hook(act, hook): 72 | grad_cache[hook.name] = act.detach() 73 | 74 | self.base_model.add_hook(self.transformer_lens_loc, backward_cache_hook, "bwd") 75 | logits = self.base_model(tokens) 76 | loss = self.compute_loss(logits, labels) 77 | loss.backward() 78 | self.base_model.reset_hooks() 79 | return ( 80 | loss.item(), 81 | ActivationCache(cache, self.base_model), 82 | ActivationCache(grad_cache, self.base_model), 83 | ) 84 | 85 | def compute_loss(self, logits, labels): 86 | labels = labels.to(logits.device) 87 | # Shift so that tokens < n predict n 88 | shift_logits = logits[..., :-1, :].contiguous() 89 | shift_labels = labels[..., 1:].contiguous() 90 | # Flatten the tokens 91 | loss_fct = CrossEntropyLoss() 92 | loss = loss_fct( 93 | shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) 94 | ) 95 | return loss 96 | 97 | def compute_attribute(self, idx, labels): 98 | assert ( 99 | labels is not None 100 | ), "Attribute-based latent tokenization requires labels." 
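        # Everything below computes an attribution score for each SAE latent:
        # the frozen GPT-2 is run with forward/backward hooks at the chosen
        # residual location, giving the activation x and its gradient grad_x;
        # x is encoded by the sparse autoencoder, and each latent is scored as
        #     attribute = (grad_x @ W_dec) * latent_activations,
        # a first-order estimate of how much each active latent affects the
        # next-token loss. The top-K latents by this score become the concept
        # labels used by the CoCoMix loss (see train/train_func/cocomix.py).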
101 | 102 | _, act, grad = self.get_cache_fwd_and_bwd(idx, labels) 103 | 104 | x = act[self.transformer_lens_loc] 105 | grad_x = grad[self.transformer_lens_loc] 106 | 107 | latent_activations, _ = self.autoencoder.encode(x) 108 | w_dec = self.autoencoder.decoder.weight 109 | attribute = torch.matmul(grad_x, w_dec) * latent_activations 110 | 111 | return attribute 112 | 113 | def forward(self, input_ids, labels=None): 114 | if labels is None: 115 | labels = input_ids 116 | attribute = self.compute_attribute(input_ids, labels) 117 | return attribute 118 | -------------------------------------------------------------------------------- /projects/cocomix/models/sparse_autoencoder/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | from . import paths 9 | from .model import Autoencoder 10 | 11 | __all__ = ["Autoencoder"] 12 | -------------------------------------------------------------------------------- /projects/cocomix/models/sparse_autoencoder/loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import torch 9 | 10 | 11 | def autoencoder_loss( 12 | reconstruction: torch.Tensor, 13 | original_input: torch.Tensor, 14 | latent_activations: torch.Tensor, 15 | l1_weight: float, 16 | ) -> torch.Tensor: 17 | """ 18 | :param reconstruction: output of Autoencoder.decode (shape: [batch, n_inputs]) 19 | :param original_input: input of Autoencoder.encode (shape: [batch, n_inputs]) 20 | :param latent_activations: output of Autoencoder.encode (shape: [batch, n_latents]) 21 | :param l1_weight: weight of L1 loss 22 | :return: loss (shape: [1]) 23 | """ 24 | return ( 25 | normalized_mean_squared_error(reconstruction, original_input) 26 | + normalized_L1_loss(latent_activations, original_input) * l1_weight 27 | ) 28 | 29 | 30 | def normalized_mean_squared_error( 31 | reconstruction: torch.Tensor, 32 | original_input: torch.Tensor, 33 | ) -> torch.Tensor: 34 | """ 35 | :param reconstruction: output of Autoencoder.decode (shape: [batch, n_inputs]) 36 | :param original_input: input of Autoencoder.encode (shape: [batch, n_inputs]) 37 | :return: normalized mean squared error (shape: [1]) 38 | """ 39 | return ( 40 | ((reconstruction - original_input) ** 2).mean(dim=1) 41 | / (original_input**2).mean(dim=1) 42 | ).mean() 43 | 44 | 45 | def normalized_L1_loss( 46 | latent_activations: torch.Tensor, 47 | original_input: torch.Tensor, 48 | ) -> torch.Tensor: 49 | """ 50 | :param latent_activations: output of Autoencoder.encode (shape: [batch, n_latents]) 51 | :param original_input: input of Autoencoder.encode (shape: [batch, n_inputs]) 52 | :return: normalized L1 loss (shape: [1]) 53 | """ 54 | return (latent_activations.abs().sum(dim=1) / original_input.norm(dim=1)).mean() 55 | -------------------------------------------------------------------------------- /projects/cocomix/models/sparse_autoencoder/paths.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 
3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | 9 | def v1(location, layer_index): 10 | """ 11 | Details: 12 | - Number of autoencoder latents: 32768 13 | - Number of training tokens: ~64M 14 | - Activation function: ReLU 15 | - L1 regularization strength: 0.01 16 | - Layer normed inputs: false 17 | - NeuronRecord files: 18 | `az://openaipublic/sparse-autoencoder/gpt2-small/{location}/collated_activations/{layer_index}/{latent_index}.json` 19 | """ 20 | assert location in ["mlp_post_act", "resid_delta_mlp"] 21 | assert layer_index in range(12) 22 | return f"az://openaipublic/sparse-autoencoder/gpt2-small/{location}/autoencoders/{layer_index}.pt" 23 | 24 | 25 | def v4(location, layer_index): 26 | """ 27 | Details: 28 | same as v1 29 | """ 30 | assert location in ["mlp_post_act", "resid_delta_mlp"] 31 | assert layer_index in range(12) 32 | return f"az://openaipublic/sparse-autoencoder/gpt2-small/{location}_v4/autoencoders/{layer_index}.pt" 33 | 34 | 35 | def v5_32k(location, layer_index): 36 | """ 37 | Details: 38 | - Number of autoencoder latents: 2**15 = 32768 39 | - Number of training tokens: TODO 40 | - Activation function: TopK(32) 41 | - L1 regularization strength: n/a 42 | - Layer normed inputs: true 43 | """ 44 | assert location in [ 45 | "resid_delta_attn", 46 | "resid_delta_mlp", 47 | "resid_post_attn", 48 | "resid_post_mlp", 49 | ] 50 | assert layer_index in range(12) 51 | # note: it's actually 2**15 and 2**17 ~= 131k 52 | return f"az://openaipublic/sparse-autoencoder/gpt2-small/{location}_v5_32k/autoencoders/{layer_index}.pt" 53 | 54 | 55 | def v5_128k(location, layer_index): 56 | """ 57 | Details: 58 | - Number of autoencoder latents: 2**17 = 131072 59 | - Number of training tokens: TODO 60 | - Activation function: TopK(32) 61 | - L1 regularization strength: n/a 62 | - Layer normed inputs: true 63 | """ 64 | assert location in [ 65 | "resid_delta_attn", 66 | "resid_delta_mlp", 67 | "resid_post_attn", 68 | "resid_post_mlp", 69 | ] 70 | assert layer_index in range(12) 71 | # note: it's actually 2**15 and 2**17 ~= 131k 72 | return f"az://openaipublic/sparse-autoencoder/gpt2-small/{location}_v5_128k/autoencoders/{layer_index}.pt" 73 | 74 | 75 | # NOTE: we have larger autoencoders (up to 8M, with varying n and k) trained on layer 8 resid_post_mlp 76 | # we may release them in the future 77 | -------------------------------------------------------------------------------- /projects/cocomix/requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.34.2 2 | blobfile==2.1.1 3 | dataset==1.6.2 4 | datasets==2.21.0 5 | einops==0.8.0 6 | flash-attn==2.5.8 7 | huggingface-hub==0.24.6 8 | hydra-core==1.3.2 9 | -e git+https://github.com/EleutherAI/lm-evaluation-harness@9a092f374bdc6d6032ae2b878b7a49b97801ab69#egg=lm_eval 10 | transformer-lens==2.4.1 11 | typeguard==4.4.1 12 | transformers==4.44.2 13 | wandb==0.18.7 14 | -------------------------------------------------------------------------------- /projects/cocomix/slurm_bash/slurm_multi.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | #SBATCH --job-name=JOB_NAME 9 | #SBATCH -D . 
10 | 11 | #SBATCH --output=OUTPUT_LOG_PATH 12 | #SBATCH --error=ERROR_LOG_PATH 13 | 14 | #SBATCH --account=SLURM_ACCOUNT 15 | #SBATCH --qos=SLURM_QOS 16 | ## number of nodes 17 | #SBATCH --nodes=NUM_NODE 18 | #SBATCH --ntasks-per-node=1 # number of MP tasks 19 | #SBATCH --gres=gpu:8 # number of GPUs per node 20 | #SBATCH --cpus-per-task=CPU_NUM 21 | #SBATCH --time=MAX_TIME 22 | 23 | # Initialize Conda for bash using Miniconda3 24 | source ~/miniconda3/etc/profile.d/conda.sh 25 | 26 | # Activate the Conda environment 27 | conda activate cocomix 28 | 29 | # Variables 30 | NUM_PROCESSES=8 # GPUs per node 31 | NUM_MACHINES=$SLURM_JOB_NUM_NODES # Total nodes 32 | MACHINE_RANK=$SLURM_NODEID # Rank of the current node 33 | head_node_ip=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) 34 | MAIN_PROCESS_PORT=12345 # Choose an open port 35 | 36 | # Set any necessary environment variables 37 | # export NCCL_DEBUG=INFO 38 | export MASTER_ADDR=$head_node_ip 39 | export MASTER_PORT=$MAIN_PROCESS_PORT 40 | export WORLD_SIZE=$(($NUM_PROCESSES * $NUM_MACHINES)) 41 | export RANK=$(($MACHINE_RANK * $NUM_PROCESSES)) 42 | 43 | # Print for debugging 44 | echo "Main process IP: $head_node_ip" 45 | echo "Machine rank: $MACHINE_RANK" 46 | echo "World size: $WORLD_SIZE" 47 | 48 | # Print Slurm environment variables 49 | echo "SLURM_NTASKS=$SLURM_NTASKS" 50 | echo "SLURM_NNODES=$SLURM_NNODES" 51 | echo "SLURM_TASKS_PER_NODE=$SLURM_TASKS_PER_NODE" 52 | 53 | export LAUNCHER="accelerate launch \ 54 | --config_file ./conf/fsdp_bf16.yaml \ 55 | --num_processes $WORLD_SIZE \ 56 | --num_machines $NUM_MACHINES \ 57 | --rdzv_backend c10d \ 58 | --main_process_ip $head_node_ip \ 59 | --main_process_port $MAIN_PROCESS_PORT \ 60 | " 61 | export SCRIPT="main.py" 62 | export SCRIPT_ARGS="$@" 63 | 64 | # This step is necessary because accelerate launch does not handle multiline arguments properly 65 | export CMD="$LAUNCHER $SCRIPT $SCRIPT_ARGS" 66 | echo "Running the following command:" 67 | echo $CMD 68 | srun --ntasks=$SLURM_NTASKS $CMD -------------------------------------------------------------------------------- /projects/cocomix/train/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | """ 7 | 8 | 9 | def setup(mode, cfg): 10 | 11 | base_model = cfg.base_model 12 | if cfg.n_embd is not None: 13 | base_model = f"{base_model}_embd{cfg.n_embd}" 14 | if cfg.n_layer is not None: 15 | base_model = f"{base_model}_L{cfg.n_layer}" 16 | if cfg.n_head is not None: 17 | base_model = f"{base_model}_H{cfg.n_head}" 18 | 19 | fname = ( 20 | f"{cfg.dataset}/{base_model}/{mode}" 21 | f"_bs{int(cfg.update_batch_size*cfg.grad_acc_steps)}" 22 | ) 23 | wandb_name = ( 24 | f"{cfg.dataset}_{base_model}_{mode}" 25 | f"_bs{int(cfg.update_batch_size*cfg.grad_acc_steps)}" 26 | ) 27 | 28 | fname += f"_ctx{cfg.block_size}" 29 | wandb_name += f"_ctx{cfg.block_size}" 30 | 31 | if mode == "ntp": 32 | from train.train_func.ntp import train_step 33 | elif mode == "cocomix": 34 | from train.train_func.cocomix import train_step 35 | 36 | fname += f"_lam{cfg.lam_concept}" 37 | wandb_name += f"_lam{cfg.lam_concept}" 38 | else: 39 | raise NotImplementedError() 40 | 41 | fname += f"_seed_{cfg.seed}" 42 | wandb_name += f"_seed_{cfg.seed}" 43 | if cfg.suffix is not None: 44 | fname += f"_{cfg.suffix}" 45 | wandb_name += f"_{cfg.suffix}" 46 | 47 | return train_step, fname, wandb_name 48 | -------------------------------------------------------------------------------- /projects/cocomix/train/train_func/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | -------------------------------------------------------------------------------- /projects/cocomix/train/train_func/cocomix.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | """ 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn.functional as F 11 | from utils import metric_synchronize_between_processes 12 | 13 | 14 | def train_step( 15 | cfg, 16 | base_lm, 17 | optimizer, 18 | scheduler, 19 | accelerator, 20 | batch, 21 | logger, 22 | metrics_dic, 23 | concept_extractor, 24 | ): 25 | 26 | # compute loss 27 | with accelerator.accumulate(base_lm): 28 | base_lm.train() 29 | 30 | extracted_concept = concept_extractor(input_ids=batch["input_ids"]) 31 | outputs, concept_logit = base_lm( 32 | input_ids=batch["input_ids"], 33 | labels=batch["input_ids"], 34 | get_concept_logit=True, 35 | ) 36 | 37 | concept_labels = torch.topk(extracted_concept, k=cfg.topK_attri, dim=-1)[1] 38 | loss_concept = torch.tensor(0.0).to(base_lm.device) 39 | for i in range(cfg.topK_attri): 40 | loss_concept += ( 41 | 1 42 | / cfg.topK_attri 43 | * F.cross_entropy( 44 | concept_logit.view(-1, concept_logit.size(-1)), 45 | concept_labels[:, :, i].contiguous().view(-1), 46 | ) 47 | ) 48 | 49 | loss = outputs.loss 50 | 51 | metrics_dic["loss"].append(loss.item()) 52 | metrics_dic["loss_concept"].append(loss_concept.item()) 53 | 54 | loss_total = loss + cfg.lam_concept * loss_concept 55 | accelerator.backward(loss_total) 56 | 57 | if accelerator.sync_gradients: 58 | # clip gradient when using sync gradients 59 | grad_norm = accelerator.clip_grad_norm_( 60 | base_lm.parameters(), cfg.grad_clip_thresh 61 | ) 62 | metrics_dic["grad_norm"].append(grad_norm) 63 | 64 | # log metrics when using sync gradients (i.e., actual gradient update) 65 | if cfg.global_step % cfg.log_step_freq == 0: 66 | metric_synchronize_between_processes( 67 | metrics_dic, accelerator 68 | ) # sync metrics across processes 69 | log_metrics = { 70 | "train": {f"{k}": np.mean(v) for k, v in metrics_dic.items()}, 71 | "lr": optimizer.param_groups[0]["lr"], 72 | } 73 | logger.wandb_log(log_metrics, step=cfg.global_step) 74 | for k, v in metrics_dic.items(): 75 | logger.log(f"Step {cfg.global_step} Train {k}: {np.mean(v)}") 76 | 77 | metrics_dic.clear() 78 | cfg.global_step += 1 79 | 80 | optimizer.step() 81 | if accelerator.sync_gradients: 82 | scheduler.step() 83 | optimizer.zero_grad() 84 | -------------------------------------------------------------------------------- /projects/cocomix/train/train_func/ntp.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | """ 7 | 8 | import numpy as np 9 | from utils import metric_synchronize_between_processes 10 | 11 | 12 | def train_step( 13 | cfg, base_lm, optimizer, scheduler, accelerator, batch, logger, metrics_dic 14 | ): 15 | 16 | # compute loss 17 | with accelerator.accumulate(base_lm): 18 | base_lm.train() 19 | 20 | outputs = base_lm(input_ids=batch["input_ids"], labels=batch["input_ids"]) 21 | loss = outputs.loss 22 | metrics_dic["loss"].append(loss.item()) 23 | 24 | accelerator.backward(loss) 25 | 26 | if accelerator.sync_gradients: 27 | # clip gradient when using sync gradients 28 | grad_norm = accelerator.clip_grad_norm_( 29 | base_lm.parameters(), cfg.grad_clip_thresh 30 | ) 31 | metrics_dic["grad_norm"].append(grad_norm) 32 | 33 | # log metrics when using sync gradients (i.e., actual gradient update) 34 | if cfg.global_step % cfg.log_step_freq == 0: 35 | metric_synchronize_between_processes( 36 | metrics_dic, accelerator 37 | ) # sync metrics across processes 38 | log_metrics = { 39 | "train": {f"{k}": np.mean(v) for k, v in metrics_dic.items()}, 40 | "lr": optimizer.param_groups[0]["lr"], 41 | } 42 | logger.wandb_log(log_metrics, step=cfg.global_step) 43 | for k, v in metrics_dic.items(): 44 | logger.log(f"Step {cfg.global_step} Train {k}: {np.mean(v)}") 45 | logger.log( 46 | f'Step {cfg.global_step} Train lr: {optimizer.param_groups[0]["lr"]}' 47 | ) 48 | 49 | metrics_dic.clear() 50 | cfg.global_step += 1 51 | 52 | optimizer.step() 53 | if accelerator.sync_gradients: 54 | scheduler.step() 55 | optimizer.zero_grad() 56 | -------------------------------------------------------------------------------- /projects/cocomix/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | """ 7 | 8 | import os 9 | import random 10 | import sys 11 | from datetime import datetime 12 | 13 | import numpy as np 14 | import torch 15 | import torch.distributed as dist 16 | import wandb 17 | from omegaconf import OmegaConf 18 | from tqdm.auto import tqdm 19 | 20 | 21 | class Logger(object): 22 | """Reference: https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514""" 23 | 24 | def __init__( 25 | self, 26 | fn, 27 | cfg, 28 | main_process=True, 29 | use_wandb=False, 30 | wandb_name=None, 31 | log_path=None, 32 | ): 33 | self.main_process = main_process 34 | self.log_path = "./logs/" if log_path is None else log_path 35 | self.logdir = None 36 | self.cfg = cfg 37 | self.use_wandb = use_wandb 38 | 39 | if self.main_process: 40 | logdir = self.log_path + fn 41 | self.logdir = logdir 42 | self.set_dir(logdir) 43 | 44 | if self.use_wandb: 45 | wandb.login(key=cfg.wandb_key) 46 | wandb.config = OmegaConf.to_container( 47 | cfg, resolve=True, throw_on_missing=True 48 | ) 49 | wandb.init( 50 | project=cfg.wandb_project, 51 | name=wandb_name, 52 | dir=logdir, 53 | entity=cfg.wandb_entity, 54 | settings=wandb.Settings(start_method="fork"), 55 | ) 56 | 57 | # distribute logdir to other processes 58 | if torch.distributed.is_initialized(): 59 | if self.main_process: 60 | object_list = [self.logdir] 61 | else: 62 | object_list = [None] 63 | dist.broadcast_object_list(object_list, src=0) 64 | self.logdir = object_list[0] 65 | 66 | def set_dir(self, logdir, log_fn="log.txt"): 67 | os.makedirs(logdir, exist_ok=True) 68 | self.log_file = open(os.path.join(logdir, log_fn), "a") 69 | with open(os.path.join(logdir, "config.yaml"), "w+") as fp: 70 | OmegaConf.save(config=self.cfg, f=fp.name) 71 | 72 | def close_writer(self): 73 | if self.main_process and self.use_wandb: 74 | wandb.finish() 75 | 76 | def log(self, string): 77 | if self.main_process: 78 | self.log_file.write("[%s] %s" % (datetime.now(), string) + "\n") 79 | self.log_file.flush() 80 | 81 | print("[%s] %s" % (datetime.now(), string)) 82 | sys.stdout.flush() 83 | 84 | def log_dirname(self, string): 85 | if self.main_process: 86 | self.log_file.write("%s (%s)" % (string, self.logdir) + "\n") 87 | self.log_file.flush() 88 | 89 | print("%s (%s)" % (string, self.logdir)) 90 | sys.stdout.flush() 91 | 92 | def wandb_log(self, log_dict, step=None, commit=None): 93 | if self.main_process and self.use_wandb: 94 | wandb.log(log_dict, step=step, commit=commit) 95 | 96 | 97 | def set_random_seed(seed): 98 | seed = int(seed) 99 | random.seed(seed) 100 | np.random.seed(seed) 101 | torch.manual_seed(seed) 102 | torch.cuda.manual_seed(seed) 103 | torch.cuda.manual_seed_all(seed) 104 | 105 | 106 | def is_dist_avail_and_initialized(): 107 | if not dist.is_available(): 108 | return False 109 | if not dist.is_initialized(): 110 | return False 111 | return True 112 | 113 | 114 | def metric_synchronize_between_processes(metrics, accelerator=None): 115 | if accelerator is not None: 116 | for k, v in metrics.items(): 117 | t = torch.tensor([v], dtype=torch.float64, device=accelerator.device) 118 | gathered_items = accelerator.gather_for_metrics(t) 119 | metrics[k] = gathered_items.mean().item() 120 | else: 121 | if is_dist_avail_and_initialized(): 122 | for k, v in metrics.items(): 123 | t = torch.tensor([v], dtype=torch.float64, device="cuda") 124 | dist.barrier() 125 | dist.all_reduce(t) 126 | t /= dist.get_world_size() 127 | t = t.tolist() 128 | metrics[k] = t[0] 129 | 130 | 131 | def logging_path_check(cfg): 132 | from train import setup as 
train_setup 133 | 134 | _, fname, _ = train_setup(cfg.mode, cfg) 135 | log_path = "./logs/" if cfg.log_path is None else cfg.log_path 136 | os.makedirs(log_path, exist_ok=True) 137 | logdir = log_path + fname 138 | os.makedirs(logdir, exist_ok=True) 139 | 140 | 141 | # Function to create a tqdm progress bar for distributed training 142 | def tqdm_distributed(main_process, iterator, *args, **kwargs): 143 | if main_process: 144 | return tqdm(iterator, *args, **kwargs) 145 | else: 146 | return iterator # No progress bar for non-main processes 147 | -------------------------------------------------------------------------------- /projects/cope/README.md: -------------------------------------------------------------------------------- 1 | # Contextual Position Encoding (CoPE): Learning to Count What's Important 2 | 3 | - Contextual Position Encoding (CoPE) is a new position encoding method that allows positions to be conditioned on context by incrementing position only on certain tokens determined by the model. 4 | - CoPE allows more general position addressing such as attending to the i-th particular word, noun, or sentence. 5 | - In particular, CoPE computes gate values conditioned on the context first, then uses that to assign positions to tokens using a cumulative sum. This allows positions to be contextualized, and represent the count of different units like words, verbs or sentences. CoPE operates on each attention head and so can attend to different position types on each. 6 | 7 |

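To make the mechanism concrete, here is a minimal PyTorch sketch of the CoPE position logits for a single attention head. Names and shapes are illustrative only; the actual implementation (including causal masking of the gates and the `cope+rel` variants) is `ContextPosSelfAttn` in `src/cope/context_position.py`.

```python
import math
import torch


def cope_position_logits(query, key, pos_emb):
    """Illustrative sketch: query, key are B x L x H, pos_emb is 1 x H x npos."""
    npos = pos_emb.size(-1)
    # 1) Gates in [0, 1] decide which keys count as a position increment.
    gates = torch.sigmoid(
        torch.bmm(query, key.transpose(-1, -2)) / math.sqrt(query.size(-1))
    )  # B x L(q) x L(k)
    # 2) The contextual position of key j w.r.t. query i is the sum of gates
    #    over keys j..i, implemented as a reversed cumulative sum and clamped
    #    to the number of learned position embeddings.
    positions = gates.flip(-1).cumsum(dim=-1).flip(-1).clamp(max=npos - 1)
    # 3) Positions are fractional, so interpolate between the logits of the
    #    two nearest integer position embeddings.
    fixed = torch.matmul(query, pos_emb)  # B x L x npos
    w = positions - positions.floor()
    ceil_logits = fixed.gather(-1, positions.ceil().long())
    floor_logits = fixed.gather(-1, positions.floor().long())
    return ceil_logits * w + floor_logits * (1 - w)  # added to attention logits


q, k = torch.randn(2, 8, 16), torch.randn(2, 8, 16)
emb = torch.randn(1, 16, 9) * 16**-0.5
print(cope_position_logits(q, k, emb).shape)  # torch.Size([2, 8, 8])
```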
8 | 9 | ## Papers 10 | 11 | This work is based on the following paper: [Contextual Position Encoding: Learning to Count What's Important](https://arxiv.org/pdf/2405.18719). 12 | 13 | ## Setup 14 | 15 | The following setup is recommended to reproduce the experiments: 16 | 17 | 1. Create a conda environment 18 | 19 | ```bash 20 | conda create --name cope python=3.9 21 | conda activate cope 22 | ``` 23 | 24 | 2. Install dependencies. Our experiments were run with `transformers` version 4.42.4; you can pin it in `requirements.txt` to reproduce the results in the paper: 25 | ```bash 26 | conda install pytorch=2.2 pytorch-cuda=12.1 -y --strict-channel-priority --override-channels -c pytorch -c nvidia -c conda-forge 27 | pip install -r requirements.txt 28 | ``` 29 | 30 | ## Run model training and evaluation 31 | 32 | We created a script that reproduces the 3- and 5-variable runs for the Counting Task described in the [paper](https://arxiv.org/pdf/2405.18719). Simply run it on a GPU node: 33 | 34 | ```bash 35 | bash run.sh 36 | ``` 37 | 38 | In the paper we reported the average test error rates over 3 random seeds. 39 | 40 |

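For reference, each Counting Task sample produced by `scripts/count_data_gen.py` is one JSON line with `context`, `question`, and `answer` fields, where the answer is the current value of the queried variable. An illustrative, hand-written 3-variable example (not an actual generated sample):

```
{"context": "a = 0 ; pass; a ++; b = 0 ; pass; pass; a ++; b ++;", "question": "print a", "answer": "2"}
```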
41 | 42 | ## Contributors 43 | Olga Golovneva, Tianlu Wang, Janice Lan, Jason Weston, Sainbayar Sukhbaatar 44 | 45 | ## Citation 46 | If you use our model in your own work, please cite with the following BibTex entry: 47 | ``` 48 | @article{golovneva2024contextual, 49 | title={Contextual Position Encoding: Learning to Count What's Important}, 50 | author={Golovneva, Olga and Wang, Tianlu and Weston, Jason and Sukhbaatar, Sainbayar}, 51 | journal={arXiv preprint arXiv:2405.18719}, 52 | year={2024} 53 | } -------------------------------------------------------------------------------- /projects/cope/eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | from argparse import ArgumentParser 9 | 10 | import torch 11 | from src.main import add_train_args, eval 12 | 13 | if __name__ == "__main__": 14 | parser = ArgumentParser() 15 | parser.add_argument( 16 | "--checkpoint-path", type=str, help="a path to the checkpoint file" 17 | ) 18 | args, train_args = parser.parse_known_args() 19 | 20 | if args.checkpoint_path is None: 21 | train_parser = ArgumentParser() 22 | add_train_args(train_parser) 23 | cfg = train_parser.parse_args(train_args) 24 | else: 25 | cfg = torch.load(args.checkpoint_path, map_location=torch.device("cpu"))["cfg"] 26 | if len(train_args) > 0: 27 | train_parser = ArgumentParser() 28 | add_train_args(train_parser) 29 | cfg = train_parser.parse_args(train_args, namespace=cfg) 30 | cfg.log_plot = False 31 | cfg.distributed = False 32 | cfg.checkpoint = args.checkpoint_path 33 | cfg.init_model_from = None 34 | 35 | eval(cfg) 36 | -------------------------------------------------------------------------------- /projects/cope/figures/CoPE.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/RAM/c567dfb73c0bc60c7ca1f114b72cd3bf7d9442bd/projects/cope/figures/CoPE.png -------------------------------------------------------------------------------- /projects/cope/figures/counting_task.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/RAM/c567dfb73c0bc60c7ca1f114b72cd3bf7d9442bd/projects/cope/figures/counting_task.png -------------------------------------------------------------------------------- /projects/cope/requirements.txt: -------------------------------------------------------------------------------- 1 | submitit==1.5.1 2 | transformers 3 | wandb==0.17.4 4 | -------------------------------------------------------------------------------- /projects/cope/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | 9 | # Example set of commands that generates data, trains, and evaluates. 
10 | 11 | function gen_data { 12 | nvars=$1 # number of variables 13 | 14 | cmd="python scripts/count_data_gen.py --nvars $nvars --out-path count_data/" 15 | 16 | echo $cmd 17 | $cmd 18 | } 19 | 20 | function launch_train { 21 | pos_emb=$1 # abs, rel, or cape 22 | data=$2 # dir containing train/val/test jsonl data 23 | seed=$3 # random seed 24 | 25 | log_name="${pos_emb}_${data}_seed=${seed}" 26 | 27 | MODEL_ARGS="--nlayers 4 --hid-sz 256 --nheads 4 --block-size 512" 28 | GENERAL_ARGS="--model simpledec \ 29 | --tokenizer simple \ 30 | --emb-tie \ 31 | --nepochs 100 \ 32 | --drop 0.1 \ 33 | --batch-sz 64 \ 34 | --lr 0.00007 \ 35 | --train-on answer \ 36 | --post-norm \ 37 | --log-plot \ 38 | --log-plot-dir wandb_logs/" 39 | 40 | CUSTOM_ARGS="--data count_data/$data \ 41 | --seed $seed \ 42 | --pos-emb $pos_emb \ 43 | --log-name $log_name \ 44 | --checkpoint checkpoints/${log_name}.pt" 45 | 46 | cmd="python train.py ${GENERAL_ARGS} ${MODEL_ARGS} ${CUSTOM_ARGS}" 47 | 48 | echo $cmd 49 | $cmd 50 | } 51 | 52 | function launch_eval { 53 | pos_emb=$1 # abs, rel, or cape 54 | data=$2 # dir containing train/val/test jsonl data 55 | seed=$3 # seed used for train 56 | 57 | log_name="${pos_emb}_${data}_seed=${seed}" 58 | 59 | cmd="python eval.py --data count_data/$data \ 60 | --checkpoint-path checkpoints/${log_name}.pt \ 61 | --eval-on test" 62 | 63 | echo $cmd 64 | $cmd 65 | } 66 | 67 | set -eo pipefail 68 | 69 | gen_data 3 70 | launch_train cope count_var3_step512_train10k 1 71 | launch_eval cope count_var3_step512_train10k 1 72 | 73 | launch_train abs count_var3_step512_train10k 1 74 | launch_eval abs count_var3_step512_train10k 1 75 | 76 | launch_train rel count_var3_step512_train10k 1 77 | launch_eval rel count_var3_step512_train10k 1 78 | 79 | 80 | gen_data 5 81 | launch_train cope count_var5_step512_train10k 10 82 | launch_eval cope count_var5_step512_train10k 10 83 | 84 | launch_train abs count_var5_step512_train10k 10 85 | launch_eval abs count_var5_step512_train10k 10 86 | 87 | launch_train rel count_var5_step512_train10k 10 88 | launch_eval rel count_var5_step512_train10k 10 89 | -------------------------------------------------------------------------------- /projects/cope/scripts/count_data_gen.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | """ 7 | 8 | import argparse 9 | import json 10 | import os 11 | import random 12 | import string 13 | from enum import Enum 14 | from typing import Dict, List 15 | 16 | VARIABLE_NAMES = list(string.ascii_lowercase) 17 | MAX_VAL = 10 18 | 19 | 20 | class Ops(Enum): 21 | SET = 1 22 | PRINT = 2 23 | DUMMY = 3 24 | INCREMENT = 4 25 | 26 | 27 | OP_WEIGHTS = { 28 | Ops.SET: 1, 29 | Ops.DUMMY: 50, 30 | Ops.INCREMENT: 7, 31 | } 32 | 33 | 34 | class GenerationFail(Exception): 35 | """Dummy class for raising errors""" 36 | 37 | pass 38 | 39 | 40 | def sample_op(): 41 | op = random.choices(list(OP_WEIGHTS.keys()), OP_WEIGHTS.values(), k=1)[0] 42 | return op 43 | 44 | 45 | def sample_existing_var(var_vals): 46 | if len(var_vals) == 0: 47 | raise GenerationFail 48 | var = random.choice(list(var_vals.keys())) 49 | return var, var_vals[var] 50 | 51 | 52 | def generate_op(args, op: Ops, var_names: List[str], var_vals: Dict[str, int]): 53 | if op == Ops.SET: 54 | var = random.choice(var_names) 55 | val = 0 56 | var_vals[var] = val 57 | text = f"{var} = {val} ;" 58 | elif op == Ops.PRINT: 59 | var, val = sample_existing_var(var_vals) 60 | text = f"print {var} {val} ;" 61 | elif op == Ops.DUMMY: 62 | text = f"pass;" 63 | elif op == Ops.INCREMENT: 64 | var, val = sample_existing_var(var_vals) 65 | if val + 1 > MAX_VAL: 66 | raise GenerationFail 67 | text = f"{var} ++;" 68 | var_vals[var] = val + 1 69 | return text 70 | 71 | 72 | def generate_statement(var_names: List[str], var_vals: Dict[str, int]): 73 | for trial in range(100): 74 | try: 75 | op = sample_op() 76 | text = generate_op(args, op, var_names, var_vals) 77 | return text 78 | except GenerationFail: 79 | continue 80 | raise ValueError("Failed to generate too many times!") 81 | 82 | 83 | def try_generate_sample(nvars: int): 84 | nsteps = random.randint(3, args.max_steps) 85 | state = {} 86 | var_names = VARIABLE_NAMES[:nvars] 87 | context = [] 88 | for step in range(nsteps): 89 | text = generate_statement(var_names, state) 90 | context.append(text) 91 | context = " ".join(context) 92 | text = generate_op(args, Ops.PRINT, var_names, state) 93 | question = " ".join(text.split()[:-2]) 94 | answer = text.split()[-2] 95 | return {"context": context, "question": question, "answer": answer} 96 | 97 | 98 | def generate_sample(nvars: int): 99 | for trial in range(1000): 100 | try: 101 | return try_generate_sample(nvars) 102 | except GenerationFail: 103 | continue 104 | raise ValueError("Failed to generate too many times!") 105 | 106 | 107 | def generate_data(args, nsamples: int): 108 | data = [] 109 | for i in range(nsamples): 110 | sample = generate_sample(args.nvars) 111 | # print(sample) 112 | data.append(sample) 113 | return data 114 | 115 | 116 | def save_json(data: List, file_name: str): 117 | """Save data to specified json file""" 118 | dirname = os.path.dirname(file_name) 119 | if not os.path.exists(dirname): 120 | os.makedirs(dirname) 121 | print(f"saving {file_name}") 122 | with open(file_name, "w") as fp: 123 | for i, sample in enumerate(data): 124 | json.dump(sample, fp) 125 | fp.write("\n") 126 | 127 | 128 | if __name__ == "__main__": 129 | parser = argparse.ArgumentParser() 130 | parser.add_argument("--nvars", type=int, default=3) 131 | parser.add_argument("--ntrain", type=int, default=10000) 132 | parser.add_argument("--ntest", type=int, default=1000) 133 | parser.add_argument("--max-steps", type=int, default=512) 134 | parser.add_argument( 135 | "--out-path", 136 | type=str, 137 | default="./data/context_position/counting", 138 | ) 139 | 
args = parser.parse_args() 140 | 141 | train_data = generate_data(args, args.ntrain) 142 | 143 | valid_data = generate_data(args, args.ntest) 144 | test_data = generate_data(args, args.ntest) 145 | 146 | data_name = "count_var{}_step{}_train{}k".format( 147 | args.nvars, 148 | args.max_steps, 149 | int(args.ntrain / 1000), 150 | ) 151 | 152 | save_json(train_data, os.path.join(args.out_path, data_name, "train.jsonl")) 153 | save_json(valid_data, os.path.join(args.out_path, data_name, "valid.jsonl")) 154 | save_json(test_data, os.path.join(args.out_path, data_name, "test.jsonl")) 155 | -------------------------------------------------------------------------------- /projects/cope/src/cope/context_position.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import math 9 | from argparse import ArgumentParser, BooleanOptionalAction 10 | from typing import Optional 11 | 12 | import torch 13 | import torch.nn as nn 14 | 15 | 16 | def add_args(parser: ArgumentParser): 17 | parser.add_argument( 18 | "--cope-sep-key", 19 | action=BooleanOptionalAction, 20 | default=False, 21 | help="use separate key projection for computing gates", 22 | ) 23 | parser.add_argument( 24 | "--cope-val-gates", 25 | action=BooleanOptionalAction, 26 | default=False, 27 | help="use values for computing gates", 28 | ) 29 | parser.add_argument( 30 | "--cope-npos", 31 | type=int, 32 | help="the number of positions. If not set, will use block size + 1.", 33 | ) 34 | parser.add_argument( 35 | "--cope-nodiv", 36 | action=BooleanOptionalAction, 37 | default=False, 38 | help="do not divide by sqrt(div)", 39 | ) 40 | parser.add_argument( 41 | "--cope-layers", 42 | type=int, 43 | default=1, 44 | help="use CoPE only every K layers.", 45 | ) 46 | parser.add_argument( 47 | "--cope-shared", 48 | action=BooleanOptionalAction, 49 | default=False, 50 | help="share embeddings accross layers", 51 | ) 52 | 53 | 54 | class ContextPosSelfAttn(nn.Module): 55 | """Self-attention layer with contextual position embedding (CoPE)""" 56 | 57 | def __init__(self, cfg, lay_ind): 58 | super().__init__() 59 | self.cfg = cfg 60 | self.lay_ind = lay_ind 61 | self.dropout = nn.Dropout(cfg.drop) 62 | self.rel_only = False 63 | if self.cfg.cope_layers > 1: 64 | assert cfg.pos_emb == "cope+rel" 65 | if (self.lay_ind + 1) % self.cfg.cope_layers > 0: 66 | # do not use cope on this layer 67 | self.rel_only = True 68 | 69 | if cfg.cope_npos is not None: 70 | self.npos = cfg.cope_npos 71 | else: 72 | # need 1 extra position because position 0 is possible 73 | self.npos = cfg.block_size + 1 74 | 75 | if not self.rel_only: 76 | self.pos_emb = nn.parameter.Parameter( 77 | torch.zeros(1, cfg.head_dim, self.npos) 78 | ) 79 | 80 | def forward( 81 | self, 82 | query: torch.Tensor, 83 | key: torch.Tensor, 84 | key_cope: torch.Tensor, 85 | val: torch.Tensor, 86 | attn_mask: Optional[torch.Tensor] = None, 87 | ): 88 | # query, key, val : B x L x H 89 | B, L, H = key.size() 90 | # attn_mask : B x L 91 | 92 | if self.rel_only: 93 | pos_logits = 0 94 | gate_logits = None 95 | else: 96 | if self.cfg.cope_val_gates: 97 | gate_logits = torch.bmm(query, val.transpose(-1, -2)) # B x L(q) x L(k) 98 | else: 99 | gate_logits = torch.bmm( 100 | query, key_cope.transpose(-1, -2) 101 | ) # B x L(q) x L(k) 102 | gates = torch.sigmoid(gate_logits / 
math.sqrt(self.cfg.head_dim)) 103 | if attn_mask is not None: 104 | gates = gates * attn_mask 105 | positions = gates.flip(-1).cumsum(dim=-1).flip(-1) # B x L x L 106 | positions = positions.clamp(max=self.npos - 1) 107 | 108 | # NOW what to do with these fractional positions? 109 | 110 | # let's compute for discrete fixed positions (1,2, .., T) first 111 | pos_logits_fixed = torch.matmul(query, self.pos_emb) # B x L x npos 112 | 113 | # now we need to intrapolate floor(p) and ceil(p) for position p 114 | positions_ceil = positions.ceil().long() # yes, no gradient here 115 | positions_floor = positions.floor().long() # yes, no gradient here 116 | pos_logits_ceil = pos_logits_fixed.gather(-1, positions_ceil) # B x L x L 117 | pos_logits_floor = pos_logits_fixed.gather(-1, positions_floor) # B x L x L 118 | 119 | # almost there, need to do weighted sum of these two 120 | pos_ceil_weight = positions - positions_floor # this is differentiable 121 | pos_logits = pos_logits_ceil * pos_ceil_weight + pos_logits_floor * ( 122 | 1 - pos_ceil_weight 123 | ) 124 | 125 | if self.cfg.cope_sep_key or self.cfg.cope_val_gates or gate_logits is None: 126 | attn_logits = torch.bmm(query, key.transpose(-1, -2)) # B x L x L 127 | else: 128 | attn_logits = gate_logits 129 | 130 | if self.cfg.pos_emb == "cope+rel": 131 | # relative position only works for self-attention, not for enc-dec attention 132 | attn_logits += self.rel_pos_emb(query, key) 133 | 134 | if self.cfg.cope_nodiv: 135 | attn_logits /= math.sqrt(self.cfg.head_dim) 136 | attn_logits += pos_logits 137 | else: 138 | attn_logits += pos_logits 139 | attn_logits /= math.sqrt(self.cfg.head_dim) 140 | 141 | if attn_mask is not None: 142 | attn_logits += attn_mask.log() 143 | 144 | attn = torch.softmax(attn_logits, dim=-1) 145 | self.attn_saved = attn 146 | 147 | attn = self.dropout(attn) 148 | 149 | out = torch.bmm(attn, val) # B x L x H 150 | return out 151 | -------------------------------------------------------------------------------- /projects/cope/src/data/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import argparse 9 | import logging 10 | 11 | from torch.utils.data import DataLoader, DistributedSampler 12 | 13 | from . import simple 14 | from .data_collator import DataCollatorForDecoder 15 | 16 | 17 | def add_args(parser: argparse.ArgumentParser): 18 | parser = parser.add_argument_group("data") 19 | parser.add_argument("--data", required=True) 20 | parser.add_argument("--train-file", default="train.jsonl", type=str) 21 | parser.add_argument("--val-file", default="valid.jsonl", type=str) 22 | parser.add_argument("--test-file", default="test.jsonl", type=str) 23 | parser.add_argument( 24 | "--fixed-len", 25 | type=int, 26 | default=-1, 27 | help="cut and pad data samples into a fixed length. 
reduces memory fragmentation", 28 | ) 29 | 30 | 31 | def get_data(cfg, tokenizer): 32 | # Convert text to numbers 33 | train_data, val_data, test_data = simple.get_data(cfg) 34 | tokenizer.build_vocab(train_data, val_data, test_data) 35 | 36 | cfg.nvocab = tokenizer.vocab_size 37 | logging.info(f"nvocab = {cfg.nvocab}") 38 | 39 | return train_data, val_data, test_data, tokenizer 40 | 41 | 42 | def get_loader(cfg, data, tokenizer, eval=False): 43 | """Get data loader and sampler for training data.""" 44 | collator = DataCollatorForDecoder(cfg, tokenizer, cfg.fixed_len) 45 | 46 | if cfg.distributed: 47 | sampler = DistributedSampler( 48 | data, 49 | num_replicas=cfg.world_size, 50 | rank=cfg.rank, 51 | shuffle=not eval, 52 | seed=cfg.seed, # must be the same for all workers 53 | drop_last=True, 54 | ) 55 | assert cfg.batch_sz % cfg.world_size == 0 56 | if eval and len(data) % cfg.world_size != 0: 57 | logging.warning( 58 | "eval data size is not divisible by ngpus, so some samples will be omitted!" 59 | ) 60 | new_batch_sz = cfg.batch_sz // cfg.world_size 61 | loader = DataLoader( 62 | data, 63 | batch_size=new_batch_sz, 64 | sampler=sampler, 65 | pin_memory=True, 66 | collate_fn=collator, 67 | ) 68 | else: 69 | loader = DataLoader( 70 | data, 71 | batch_size=cfg.batch_sz, 72 | shuffle=not eval, 73 | drop_last=False, 74 | collate_fn=collator, 75 | ) 76 | 77 | return loader 78 | -------------------------------------------------------------------------------- /projects/cope/src/data/constants.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | CONTEXT_KEY = "context" 9 | QUESTION_KEY = "question" 10 | ANSWER_KEY = "answer" 11 | REF_COT_LOGPROB_KEY = "ref_cot_logprobs" 12 | 13 | 14 | TOKEN_TYPE_PAD = 0 15 | TOKEN_TYPE_CONTEXT = 1 16 | TOKEN_TYPE_QUESTION = 2 17 | TOKEN_TYPE_ANSWER = 3 18 | -------------------------------------------------------------------------------- /projects/cope/src/data/data_collator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | """ 7 | 8 | import torch 9 | 10 | from .constants import ( 11 | ANSWER_KEY, 12 | CONTEXT_KEY, 13 | QUESTION_KEY, 14 | TOKEN_TYPE_ANSWER, 15 | TOKEN_TYPE_CONTEXT, 16 | TOKEN_TYPE_PAD, 17 | TOKEN_TYPE_QUESTION, 18 | ) 19 | 20 | 21 | class DataCollatorForDecoder: 22 | def __init__(self, cfg, tokenizer, fixed_len: int, training=True): 23 | self.cfg = cfg 24 | self.tokenizer = tokenizer 25 | # Training argument is not put in use right now 26 | self.training = training 27 | self.fixed_len = fixed_len 28 | 29 | def pad_right(self, enc_seqs, enc_seqs_type): 30 | # Right padding for batching 31 | 32 | if self.fixed_len > 0: 33 | max_len = self.fixed_len + 1 34 | else: 35 | max_len = max([len(enc_seq) for enc_seq in enc_seqs]) 36 | 37 | # trim from left for longer sequences 38 | max_len = min(max_len, self.cfg.block_size) 39 | enc_seqs = [x[-max_len:] for x in enc_seqs] 40 | enc_seqs_type = [x[-max_len:] for x in enc_seqs_type] 41 | 42 | enc_seqs = [ 43 | (enc_seq + [self.tokenizer.pad_ind] * (max_len - len(enc_seq))) 44 | for enc_seq in enc_seqs 45 | ] 46 | enc_seqs_type = [ 47 | (enc_seq_type + [TOKEN_TYPE_PAD] * (max_len - len(enc_seq_type))) 48 | for enc_seq_type in enc_seqs_type 49 | ] 50 | enc_seqs = torch.LongTensor(enc_seqs) 51 | enc_seqs_type = torch.LongTensor(enc_seqs_type) 52 | 53 | return enc_seqs, enc_seqs_type 54 | 55 | def __call__(self, instances): 56 | enc_seqs = [[] for _ in range(len(instances))] 57 | enc_seqs_type = [[] for _ in range(len(instances))] 58 | for i, instance in enumerate(instances): 59 | for sample_key, token_type in zip( 60 | [ 61 | CONTEXT_KEY, 62 | QUESTION_KEY, 63 | ANSWER_KEY, 64 | ], 65 | [ 66 | TOKEN_TYPE_CONTEXT, 67 | TOKEN_TYPE_QUESTION, 68 | TOKEN_TYPE_ANSWER, 69 | ], 70 | ): 71 | if sample_key in instance and instance[sample_key] != "": 72 | # encode separately so we know which token belongs to what 73 | add_bos = len(enc_seqs[i]) == 0 # add bos at the start only 74 | enc_key = self.tokenizer.encode( 75 | instance[sample_key], add_bos=add_bos 76 | ) 77 | enc_seqs[i] += enc_key 78 | enc_seqs_type[i] += [token_type] * len(enc_key) 79 | 80 | enc_seqs, enc_seqs_type = self.pad_right(enc_seqs, enc_seqs_type) 81 | 82 | # remove answer from x 83 | dec_x = enc_seqs[:, :-1] 84 | dec_x_type = enc_seqs_type[:, :-1] 85 | 86 | # shift y right by 1 87 | dec_y = enc_seqs[:, 1:] 88 | dec_y_type = enc_seqs_type[:, 1:] 89 | 90 | out = { 91 | "dec_x": dec_x, 92 | "dec_y": dec_y, 93 | "dec_x_type": dec_x_type, 94 | "dec_y_type": dec_y_type, 95 | } 96 | return out 97 | -------------------------------------------------------------------------------- /projects/cope/src/data/simple.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | """ 7 | 8 | import json 9 | import os 10 | 11 | from torch.utils.data import Dataset 12 | 13 | 14 | class SimpleDataset(Dataset): 15 | def __init__(self, cfg, data_path): 16 | super().__init__() 17 | self.cfg = cfg 18 | with open(data_path, "r") as fp: 19 | data_list = list(fp) 20 | 21 | self.data = [] 22 | for sample in data_list: 23 | sample = json.loads(sample) 24 | 25 | if "context" in sample and isinstance(sample["context"], list): 26 | sample["context"] = " ".join(sample["context"]) 27 | 28 | self.data.append(sample) 29 | 30 | def __len__(self): 31 | return len(self.data) 32 | 33 | def __getitem__(self, idx): 34 | return self.data[idx] 35 | 36 | 37 | def get_data(cfg): 38 | train_data = SimpleDataset(cfg, os.path.join(cfg.data, cfg.train_file)) 39 | val_data = SimpleDataset(cfg, os.path.join(cfg.data, cfg.val_file)) 40 | if os.path.exists(os.path.join(cfg.data, cfg.test_file)): 41 | test_data = SimpleDataset(cfg, os.path.join(cfg.data, cfg.test_file)) 42 | else: 43 | test_data = None 44 | return train_data, val_data, test_data 45 | -------------------------------------------------------------------------------- /projects/cope/src/data/tokenizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | from abc import ABC, abstractmethod 9 | 10 | from .constants import ANSWER_KEY, CONTEXT_KEY, QUESTION_KEY 11 | 12 | 13 | class Tokenizer(ABC): 14 | pad_ind: int 15 | 16 | def __init__(self): 17 | pass 18 | 19 | @abstractmethod 20 | def encode(self, text, **kwargs): 21 | pass 22 | 23 | @abstractmethod 24 | def decode(self, inds, **kwargs): 25 | pass 26 | 27 | @property 28 | @abstractmethod 29 | def vocab_size(self): 30 | pass 31 | 32 | def get_state(self): 33 | return None 34 | 35 | def set_state(self, state): 36 | pass 37 | 38 | 39 | class SimpleTokenizer(Tokenizer): 40 | def __init__(self): 41 | self.word2ind = {} 42 | self.ind2word = [] 43 | self._allow_add_word = True 44 | self.add_word("") 45 | self.pad_ind = self.word2ind[""] 46 | self.add_word("") 47 | self.sta_ind = self.word2ind[""] 48 | 49 | def add_word(self, word): 50 | assert self._allow_add_word 51 | assert word not in self.word2ind 52 | self.word2ind[word] = len(self.ind2word) 53 | self.ind2word.append(word) 54 | 55 | def encode(self, text, **kwargs): 56 | text = text.strip() 57 | text = text.lower() 58 | text = text.replace(".", " .") 59 | text = text.replace(",", " ,") 60 | text = text.replace("?", " ?") 61 | 62 | words = text.split() 63 | inds = [] 64 | for w in words: 65 | if w not in self.word2ind: 66 | self.add_word(w) 67 | inds.append(self.word2ind[w]) 68 | return inds 69 | 70 | def decode(self, inds, **kwargs): 71 | words = [self.ind2word[i] for i in inds] 72 | return words 73 | 74 | @property 75 | def vocab_size(self): 76 | return len(self.ind2word) 77 | 78 | def build_vocab(self, train_data, val_data, test_data): 79 | for data in [train_data, val_data, test_data]: 80 | if data is None: 81 | continue 82 | for sample in data: 83 | if CONTEXT_KEY in sample: 84 | self.encode(sample[CONTEXT_KEY]) 85 | if QUESTION_KEY in sample: 86 | self.encode(sample[QUESTION_KEY]) 87 | if ANSWER_KEY in sample: 88 | self.encode(sample[ANSWER_KEY]) 89 | self._allow_add_word = False 90 | 91 | def get_state(self): 92 | state = {} 93 | state["word2ind"] = self.word2ind 94 | state["ind2word"] = 
self.ind2word 95 | return state 96 | 97 | def set_state(self, state): 98 | self.word2ind = state["word2ind"] 99 | self.ind2word = state["ind2word"] 100 | 101 | 102 | def get_tokenizer(cfg): 103 | return SimpleTokenizer() 104 | -------------------------------------------------------------------------------- /projects/cope/src/models/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | from argparse import ArgumentParser, BooleanOptionalAction, Namespace 9 | 10 | import torch 11 | 12 | from ..data.tokenizer import Tokenizer 13 | from ..utils.distributed import DummyWrapper 14 | from . import simple_transformer 15 | from .base import Model 16 | 17 | 18 | def add_args(parser: ArgumentParser): 19 | group = parser.add_argument_group("Model") 20 | group.add_argument( 21 | "--model", 22 | choices=[ 23 | "simpledec", 24 | ], # todo remove flag 25 | default="simpledec", 26 | ) 27 | group.add_argument( 28 | "--tokenizer", 29 | choices=["simple"], 30 | default="simple", 31 | help="if not specified, infer from the model", 32 | ) # todo remove flag 33 | group.add_argument( 34 | "--gpt2-add-special-tokens", 35 | nargs="+", 36 | help="add new special tokens to GPT2 tokenizer", 37 | ) 38 | group.add_argument("--untrained", action=BooleanOptionalAction, default=False) 39 | group.add_argument("--drop", type=float, default=0.1) 40 | group.add_argument( 41 | "--block-size", 42 | type=int, 43 | help="the maximum number of tokens that the model can process. Most models have it predefined.", 44 | ) 45 | simple_transformer.add_args(group) # type: ignore 46 | 47 | 48 | def set_block_size(cfg): 49 | assert cfg.block_size is not None # todo simplify 50 | 51 | 52 | def build(cfg, tokenizer: Tokenizer) -> Model: 53 | return simple_transformer.TransformerDecoder(cfg, tokenizer) 54 | -------------------------------------------------------------------------------- /projects/cope/src/models/base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | """ 7 | 8 | from abc import ABC, abstractmethod 9 | from typing import List, Optional 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch import Tensor 14 | 15 | 16 | class Model(nn.Module, ABC): 17 | @abstractmethod 18 | def forward( 19 | self, 20 | x: torch.Tensor, 21 | pad_mask: Optional[Tensor] = None, 22 | **kargs, 23 | ): 24 | pass 25 | 26 | @abstractmethod 27 | def generate( 28 | self, 29 | prompts: List[Tensor], 30 | train=False, 31 | **kargs, 32 | ): 33 | pass 34 | 35 | 36 | def build_value_head(cfg, value_layers, embed_dim, dtype=None): 37 | value_out = nn.Linear(embed_dim, 1, dtype=dtype) 38 | value_out.weight.data.fill_(0) 39 | value_out.bias.data.fill_(0) 40 | if value_layers > 1: 41 | value_mods = [] 42 | for _ in range(value_layers - 1): 43 | value_mods.append(nn.Linear(embed_dim, embed_dim, dtype=dtype)) 44 | value_mods.append(nn.ReLU()) 45 | value_mods.append(value_out) 46 | value_head = nn.Sequential(*value_mods) 47 | else: 48 | value_head = value_out 49 | 50 | return value_head 51 | -------------------------------------------------------------------------------- /projects/cope/src/models/relative_position.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | from typing import Optional 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from torch import Tensor 14 | 15 | 16 | def rel2abs(X): 17 | """convert from relative position to absolute""" 18 | # X = B x L x M 19 | B, L, M = X.size() 20 | if L == 1: 21 | raise NotImplementedError("TODO") 22 | return X 23 | X = F.pad(X, (0, L)) # B x L x M+L 24 | X = X.view(B, -1) # B x LM+LL 25 | X = X[:, :-L] # B x LM+LL-L 26 | X = X.view(B, L, M + L - 1) 27 | return X 28 | 29 | 30 | class RelPosEmb(nn.Module): 31 | def __init__( 32 | self, emb_dim: int, past_len: int, npos_max: Optional[int], extend=False 33 | ) -> None: 34 | super().__init__() 35 | self.emb_dim = emb_dim 36 | self.total_len = past_len + 1 # add 1 for position zero 37 | self.extend = extend 38 | if npos_max is None: 39 | self.npos = self.total_len 40 | else: 41 | self.npos = npos_max 42 | self.emb = nn.parameter.Parameter( 43 | torch.randn(1, emb_dim, self.npos) * emb_dim**-0.5 44 | ) 45 | 46 | def forward(self, query: Tensor, key: Tensor): 47 | B, Lq, H = query.size() 48 | B, Lk, H = key.size() 49 | 50 | pos_logits = torch.matmul(query, self.emb) # B x Lq x npos 51 | 52 | if self.npos < Lk: 53 | # there must be fewer positions than the context size 54 | assert self.npos < self.total_len 55 | assert Lk == self.total_len 56 | if self.extend: 57 | # use the last position as for those far away positions 58 | extend_len = Lk - self.npos 59 | pos_logits = F.pad(pos_logits, [extend_len, 0], mode="replicate") 60 | 61 | if Lq == 1: 62 | # used during generation 63 | pass 64 | else: 65 | pos_logits = rel2abs(pos_logits) 66 | 67 | if pos_logits.size(2) > Lk: 68 | # this could happen because early tokens will not use all its rel-pos embeddings 69 | # trim from left to match with the number of keys 70 | pos_logits = pos_logits[:, :, -Lk:] 71 | elif pos_logits.size(2) < Lk: 72 | # this should not happen even if npos is set shorter than the block size 73 | assert False 74 | return pos_logits 75 | -------------------------------------------------------------------------------- 
/projects/cope/src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | from dataclasses import dataclass 9 | from typing import Dict, List, Optional 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | from torch import Tensor 14 | 15 | 16 | def find_pos(x: Tensor, id: int): 17 | """position of the first occurance id""" 18 | assert x.ndim == 1 19 | index = (x == id).nonzero() 20 | if index.size(0) == 0: 21 | return None 22 | return index[0].item() 23 | 24 | 25 | def stack_with_pad(x: List[Tensor], pad_value: float, from_left=False) -> Tensor: 26 | max_len = max([z.size(0) for z in x]) 27 | assert x[0].ndim == 1 28 | x_padded = [] 29 | for i in range(len(x)): 30 | if from_left: 31 | pads = [max_len - len(x[i]), 0] 32 | else: 33 | pads = [0, max_len - len(x[i])] 34 | x_padded.append(F.pad(x[i], pads, value=pad_value)) 35 | return torch.stack(x_padded) 36 | 37 | 38 | @dataclass 39 | class Batch: 40 | x: Tensor 41 | x_type: Tensor 42 | y: Tensor 43 | y_type: Optional[Tensor] 44 | dict: Optional[Dict] = None 45 | rationales: Optional[list] = None 46 | ref_cot_logprobs: Optional[list] = None 47 | misc = {} 48 | 49 | @classmethod 50 | def from_dict(cls, d: Dict, device=None): 51 | y_type = d["dec_y_type"].to(device=device) if "dec_y_type" in d else None 52 | 53 | if "rationales" in d: 54 | rationales = [ 55 | [r.to(device=device) for r in rats] for rats in d["rationales"] 56 | ] 57 | else: 58 | rationales = None 59 | 60 | if "ref_cot_logprobs" in d: 61 | ref_cot_logprobs = d["ref_cot_logprobs"].to(device=device) 62 | else: 63 | ref_cot_logprobs = None 64 | 65 | return cls( 66 | d["dec_x"].to(device=device), 67 | d["dec_x_type"].to(device=device), 68 | d["dec_y"].to(device=device), 69 | y_type, 70 | d, 71 | rationales, 72 | ref_cot_logprobs, 73 | ) 74 | 75 | @property 76 | def size(self): 77 | return self.x.size(0) 78 | -------------------------------------------------------------------------------- /projects/cope/src/utils/checkpoint.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | """ 7 | 8 | import logging 9 | import os 10 | from argparse import ArgumentParser, BooleanOptionalAction 11 | 12 | import torch 13 | from torch import distributed 14 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 15 | 16 | """Checkpointing utils""" 17 | 18 | 19 | def add_args(parser: ArgumentParser): 20 | group = parser.add_argument_group("Checkpoint") 21 | parser.add_argument("--checkpoint", type=str, default=None) 22 | parser.add_argument( 23 | "--checkpoint-freq", type=int, default=1, help="how often to save checkpoint" 24 | ) 25 | parser.add_argument( 26 | "--checkpoint-keep", 27 | type=int, 28 | help="keep old checkpoints from every Nth epoch in addition to the last", 29 | ) 30 | group.add_argument("--init-model-from", type=str) 31 | group.add_argument( 32 | "--init-model-no-strict", default=False, action=BooleanOptionalAction 33 | ) 34 | 35 | 36 | def save( 37 | cfg, 38 | model, 39 | optimizer, 40 | logger, 41 | tokenizer, 42 | epoch, 43 | ): 44 | """Save checkpoint""" 45 | epoch_num = epoch + 1 46 | if cfg.checkpoint_freq > 1: 47 | if epoch_num % cfg.checkpoint_freq != 0: 48 | return 49 | 50 | if cfg.distributed: 51 | distributed.barrier() 52 | 53 | # need to run all workers for FSDP 54 | model_state_dict = model.state_dict() 55 | if cfg.distributed and cfg.fsdp: 56 | optim_state_dict = FSDP.optim_state_dict(model, optimizer) 57 | else: 58 | optim_state_dict = optimizer.state_dict() 59 | 60 | if cfg.rank > 0: 61 | return 62 | 63 | path = cfg.checkpoint 64 | dirname = os.path.dirname(path) 65 | if not os.path.exists(dirname): 66 | os.makedirs(dirname) 67 | logging.info(f"saving epoch {epoch_num} to {path}") 68 | 69 | state_dict = {} 70 | state_dict["cfg"] = cfg 71 | state_dict["model"] = model_state_dict 72 | state_dict["logger"] = logger.get_state() 73 | state_dict["optimizer"] = optim_state_dict 74 | state_dict["tokenizer"] = tokenizer.get_state() 75 | 76 | if cfg.checkpoint_keep is not None: 77 | if epoch_num % cfg.checkpoint_keep == 0: 78 | torch.save(state_dict, path + f".{epoch_num}") 79 | 80 | # if best_epoch is True, save the model as "*_best.pt" in the save path 81 | if cfg.valid_metric is not None and logger.metrics[cfg.valid_metric].cur_epoch_best: 82 | torch.save(state_dict, path + ".best") 83 | 84 | torch.save(state_dict, path) 85 | torch.save(cfg, path + ".cfg") 86 | logging.info("done saving.") 87 | 88 | 89 | def load_model(state_dict, model, strict=True): 90 | """Load model from a given state dictionary""" 91 | model.load_state_dict(state_dict["model"], strict=strict) 92 | 93 | 94 | def load_checkpoint(cfg, model, optimizer, logger, tokenizer): 95 | """Load checkpoint""" 96 | state_dict = torch.load(cfg.checkpoint, map_location=torch.device("cpu")) 97 | load_model(state_dict, model) 98 | optim_state_dict = state_dict["optimizer"] 99 | if cfg.distributed and cfg.fsdp: 100 | try: 101 | optim_state_dict = FSDP.optim_state_dict_to_load( 102 | optim_state_dict, model, optimizer 103 | ) 104 | except Exception: 105 | # for loading old checkpoints 106 | optim_state_dict = state_dict["optimizer"] 107 | optimizer.load_state_dict(optim_state_dict) 108 | logger.set_state(state_dict["logger"]) 109 | if "tokenizer" in state_dict: 110 | tokenizer.set_state(state_dict["tokenizer"]) 111 | 112 | 113 | def load(cfg, model, optimizer, logger, tokenizer): 114 | if cfg.checkpoint is not None and os.path.exists(cfg.checkpoint): 115 | logging.info(f"loading checkpoint from {cfg.checkpoint}") 116 | load_checkpoint(cfg, model, optimizer, logger, tokenizer) 117 | elif 
cfg.init_model_from is not None: 118 | f = torch.load(cfg.init_model_from, map_location=torch.device("cpu")) 119 | logging.info(f"loading model from {cfg.init_model_from}") 120 | load_model(f, model, strict=not cfg.init_model_no_strict) 121 | -------------------------------------------------------------------------------- /projects/cope/src/utils/distributed.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import contextlib 9 | import functools 10 | import os 11 | from argparse import ArgumentParser, BooleanOptionalAction 12 | 13 | import submitit 14 | import torch 15 | from torch import distributed 16 | 17 | 18 | def add_cmd_args(parser: ArgumentParser): 19 | parser.add_argument( 20 | "--distributed", 21 | action=BooleanOptionalAction, 22 | default=False, 23 | help="distributed training", 24 | ) 25 | parser.add_argument( 26 | "--submitit", action=BooleanOptionalAction, default=False, help="using submitit" 27 | ) 28 | parser.add_argument("--rank", type=int, default=0, help="") 29 | parser.add_argument("--local-rank", type=int, default=0, help="") 30 | parser.add_argument("--world-size", type=int, default=1, help="") 31 | parser.add_argument("--dist-init", type=str, help="distrixbuted training") 32 | parser.add_argument( 33 | "--fsdp", action=BooleanOptionalAction, default=False, help="using fsdp" 34 | ) 35 | 36 | 37 | class DummyWrapper(torch.nn.Module): 38 | def __init__(self, mod): 39 | super(DummyWrapper, self).__init__() 40 | self.module = mod 41 | self.no_sync = contextlib.nullcontext 42 | 43 | def forward(self, *cfg, **kwcfg): 44 | return self.module(*cfg, **kwcfg) 45 | 46 | def load_state_dict(self, state_dict, strict=True): 47 | try: 48 | super().load_state_dict(state_dict, strict) 49 | except RuntimeError: 50 | # FSDP saves without "module" prefix 51 | self.module.load_state_dict(state_dict, strict) 52 | 53 | 54 | def init(cfg): 55 | if cfg.submitit: 56 | job_env = submitit.JobEnvironment() 57 | cfg.local_rank = job_env.local_rank 58 | cfg.rank = job_env.global_rank 59 | cfg.world_size = job_env.num_tasks 60 | distributed.init_process_group( 61 | backend="nccl", 62 | init_method=cfg.dist_init, 63 | rank=job_env.global_rank, 64 | world_size=job_env.num_tasks, 65 | ) 66 | else: 67 | init_file = os.getcwd() + "/dist_init" 68 | distributed.init_process_group( 69 | backend="nccl", 70 | init_method=f"file://{init_file}", 71 | world_size=cfg.world_size, 72 | rank=cfg.rank, 73 | ) 74 | cfg.local_rank = distributed.get_rank() 75 | cfg.device = torch.device("cuda", cfg.local_rank) 76 | if cfg.rank > 0: 77 | cfg.log_plot = False 78 | 79 | 80 | def wrap_model(cfg, model): 81 | if cfg.distributed: 82 | if cfg.fsdp: 83 | from torch.distributed.fsdp import FullStateDictConfig 84 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 85 | from torch.distributed.fsdp import ShardingStrategy, StateDictType 86 | from torch.distributed.fsdp.api import FullOptimStateDictConfig 87 | from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy 88 | from transformers.models.llama.modeling_llama import LlamaDecoderLayer 89 | 90 | assert cfg.model.startswith("llama") 91 | llama_auto_wrap_policy = functools.partial( 92 | transformer_auto_wrap_policy, 93 | transformer_layer_cls={ 94 | LlamaDecoderLayer, 95 | }, 96 | ) 97 | model = FSDP( 98 | 
model, 99 | sharding_strategy=ShardingStrategy.FULL_SHARD, 100 | auto_wrap_policy=llama_auto_wrap_policy, 101 | device_id=cfg.local_rank, 102 | ) 103 | 104 | FSDP.set_state_dict_type( 105 | model, 106 | StateDictType.FULL_STATE_DICT, 107 | FullStateDictConfig(offload_to_cpu=True, rank0_only=False), 108 | FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=False), 109 | ) 110 | else: 111 | model = model.to(cfg.device) 112 | model = torch.nn.parallel.DistributedDataParallel( 113 | model, 114 | device_ids=[cfg.local_rank], 115 | output_device=cfg.local_rank, 116 | find_unused_parameters=False, 117 | ) 118 | else: 119 | model = DummyWrapper(model) 120 | model = model.to(cfg.device) 121 | return model 122 | 123 | 124 | def cleanup(): 125 | pass 126 | # distributed.destroy_process_group() 127 | -------------------------------------------------------------------------------- /projects/cope/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | from argparse import ArgumentParser 9 | 10 | from src.main import add_train_args, train 11 | 12 | if __name__ == "__main__": 13 | parser = ArgumentParser() 14 | add_train_args(parser) 15 | cfg = parser.parse_args() 16 | train(cfg) 17 | -------------------------------------------------------------------------------- /projects/length_instruct/README.md: -------------------------------------------------------------------------------- 1 | # Following Length Constraints in Instructions 2 | 3 | Arxiv link: [Yuan et al. (2024)](https://arxiv.org/abs/2406.17744). 4 | 5 |

6 | 7 | ## Abstract 8 | 9 | Aligned instruction following models can better fulfill user requests than their unaligned counterparts. However, it has been shown that there is a length bias in evaluation of such models, and that training algorithms tend to exploit this bias by learning longer responses. In this work we show how to train models that can be controlled at inference time with instructions containing desired length constraints. Such models are superior in length instructed evaluations, outperforming standard instruction following models such as GPT4, Llama 3 and Mixtral. 10 | 11 | ## AlpacaEval-LI \& MT-Bench-LI: New Length-Instructed Benchmarks 12 | To evaluate the ability of current instruction following models to follow length instructions, we build length-instructed (LI) benchmarks, AlpacaEval-LI and MT-Bench-LI. 13 | 14 | In this paper, we also develop a method called Length-Instruction Fine-Tuning (LIFT) for improving instruction following models at length instruction following. Results reveal that: 15 |
  16 |   1. Current state-of-the-art LLMs fail to follow length instructions
  17 |   2. LIFT-DPO models perform well on AlpacaEval-LI and MT-Bench-LI
  18 |   3. LIFT-DPO models show no performance degradation on standard AlpacaEval 2
  19 |   4. LIFT-DPO can follow out-of-distribution length instructions better than existing methods
  20 | 
21 | 22 | 23 | ## AlpacaEval-LI \& MT-Bench-LI: New Length-Instructed Benchmarks 24 | We introduce two new length-instructed benchmarks **AlpacaEval-LI** and **MT-Bench-LI**. 25 | 26 | ### AlpacaEval-LI 27 | The AlpacaEval-LI consists of 802 length-instructed prompts from [Alpaca Eval](https://huggingface.co/datasets/tatsu-lab/alpaca_eval/tree/main). Download [AlpacaEval-LI prompts from s3](https://dl.fbaipublicfiles.com/length_instruct/v1/data/AlpacaEval-LI/length_instructed_alpaca_baseline.json). Below is a length-instructed example from the AlpacaEval-LI evaluation set. 28 | ``` 29 | { 30 | "dataset":"helpful_base", 31 | "instruction":"Answer the following instruction using 56 words or less.\n\nWho created the Superman cartoon character?", 32 | "output":"Superman, the iconic comic book superhero, was created by writer Jerry Siegel and artist Joe Shuster. Superman first appeared in Action Comics #1, which was published by Detective Comics, Inc. (later DC Comics) in June 1938. The character's immense popularity established him as one of the most enduring and recognizable figures in the superhero genre.", 33 | "generator":"AlpacaEval-LI", 34 | "max_len":56, 35 | "original_instruction":"Who created the Superman cartoon character?" 36 | } 37 | ``` 38 | In particular, 39 | - `original_instruction`: the original prompt from alpaca eval 40 | - `instruction`: the length-instructed prompt. 41 | - `max_len`: target length limit (in words). 42 | - `output`: the reference output to pairwise compare against. 43 | 44 | Command to Run AlpacaEval-LI Evaluation: Coming soon. 45 | 46 | ### MT-Bench-LI 47 | The MT-Bench-LI consists of 240 length-instructed prompts from [MT-Bench](https://github.com/lm-sys/FastChat/tree/main/fastchat/llm_judge). See [downloadable MT-Bench-LI questions from s3](https://dl.fbaipublicfiles.com/length_instruct/v1/data/MT-Bench-LI/question.jsonl) for more details. Below is an example of length-instructed MT-Bench-LI prompt: 48 | ``` 49 | { 50 | "question_id": 162, 51 | "category": "writing", 52 | "turns": ["Answer the following instruction using 180 words or less.\n\nDraft a professional email seeking your supervisor's feedback on the 'Quarterly Financial Report' you prepared. Ask specifically about the data analysis, presentation style, and the clarity of conclusions drawn. Keep the email short and to the point."], 53 | "max_len": [180], 54 | "original_turns": ["Draft a professional email seeking your supervisor's feedback on the 'Quarterly Financial Report' you prepared. Ask specifically about the data analysis, presentation style, and the clarity of conclusions drawn. Keep the email short and to the point.", "Take a moment to evaluate and critique your own response."] 55 | } 56 | ``` 57 | In particular, 58 | - `original_turns`: the original multi-turn questions from MT-Bench. 59 | - `turns`: the length-instructed question. Only turn1 question is included for length-instructed evaluation. 60 | - `max_len`: target length limit (in words) for turn1 question. 61 | 62 | To run pairwise MT-Bench-LI evaluation, please download [MT-Bench-LI pairwise baseline answers](https://dl.fbaipublicfiles.com/length_instruct/v1/data/MT-Bench-LI/length_instructed_turn1_baseline.jsonl), as well as [MT-Bench-LI reference answers](https://dl.fbaipublicfiles.com/length_instruct/v1/data/MT-Bench-LI/reference_answer.jsonl). 63 | 64 | Command to Run MT-Bench-LI Evaluation: Coming soon. 
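For either benchmark, compliance with the length instruction can be spot-checked offline before running the full pairwise evaluation. Below is a minimal sketch for AlpacaEval-LI; it is illustrative only and not the official scoring script. It assumes your model's answers are saved as a JSON list of records with an `output` field, aligned one-to-one with the downloaded prompt file, and it counts length as whitespace-separated words, which is an assumption about how `max_len` is measured.

```python
# Illustrative length-compliance check for AlpacaEval-LI (not the official evaluator).
import json

with open("length_instructed_alpaca_baseline.json") as f:
    prompts = json.load(f)                        # downloaded from the s3 link above
with open("model_outputs.json") as f:             # hypothetical file holding your model's answers
    responses = json.load(f)

assert len(prompts) == len(responses), "responses must be aligned with the prompt file"

violations = 0
for prompt, response in zip(prompts, responses):
    n_words = len(response["output"].split())     # length counted in whitespace words (assumption)
    if n_words > prompt["max_len"]:
        violations += 1

print(f"length violations: {violations}/{len(prompts)} ({violations / len(prompts):.1%})")
```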
65 | 66 | ## Citation 67 | If you use the dataset or models in your own work, please cite with the following BibTex entry: 68 | ``` 69 | @misc{yuan2024followinglengthconstraintsinstructions, 70 | title={Following Length Constraints in Instructions}, 71 | author={Weizhe Yuan and Ilia Kulikov and Ping Yu and Kyunghyun Cho and Sainbayar Sukhbaatar and Jason Weston and Jing Xu}, 72 | year={2024}, 73 | eprint={2406.17744}, 74 | archivePrefix={arXiv}, 75 | primaryClass={cs.CL} 76 | url={https://arxiv.org/abs/2406.17744}, 77 | } 78 | ``` 79 | -------------------------------------------------------------------------------- /projects/length_instruct/fig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/RAM/c567dfb73c0bc60c7ca1f114b72cd3bf7d9442bd/projects/length_instruct/fig.png -------------------------------------------------------------------------------- /projects/mta/README.md: -------------------------------------------------------------------------------- 1 | # Multi-Token Attention 2 | 3 | - Motivation: soft attention looks at two tokens at a time to weigh their importance. But often it’s not enough! Suppose you are reading a history book, and you want to find what happened in Rome in 1417. You need to match both city and date *mentioned together*. 4 | - The high level goal is to make it possible to use the similarities of multiple vector pairs to determine where attention must focus. 5 | - We add convolutions for keys, queries, and attention heads to allow conditioning on neighboring tokens! 6 | 7 |

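As a rough illustration of the key-query convolution idea, the sketch below convolves the pre-softmax attention scores over neighboring query and key positions so that one attention weight can depend on several query-key similarity pairs. This is not the implementation in `mta_transformer.py`: the kernel sizes are arbitrary, the causal mask and the head-mixing (talking-heads) convolution are omitted, and the identity initialization is only meant to show that plain attention is recovered as a special case.

```python
# Illustrative key-query convolution over attention logits (not the repo implementation).
import torch
import torch.nn.functional as F

B, H, L, D = 2, 4, 16, 32                                  # batch, heads, seq len, head dim
q, k = torch.randn(B, H, L, D), torch.randn(B, H, L, D)

logits = torch.einsum("bhqd,bhkd->bhqk", q, k) / D**0.5    # standard scaled QK^T scores

cq, ck = 3, 5                                              # illustrative kernel sizes
kernel = torch.zeros(H, 1, cq, ck)
kernel[:, :, -1, ck // 2] = 1.0                            # identity init: reproduces plain attention
padded = F.pad(logits, (ck // 2, ck // 2, cq - 1, 0))      # query window sees only current/earlier rows
mixed = F.conv2d(padded, kernel, groups=H)                 # one (cq x ck) filter per head

attn = torch.softmax(mixed, dim=-1)                        # causal masking omitted for brevity
print(attn.shape)                                          # torch.Size([2, 4, 16, 16])
```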
8 | 9 | ## Papers 10 | 11 | This work is based on the following paper: [Multi-Token Attention](https://arxiv.org/pdf/2504.00927). 12 | 13 | ## Setup 14 | 15 | 1. Create conda environment following [Lingua instructions](https://github.com/facebookresearch/lingua) 16 | 2. Add RAM installation: 17 | 18 | ```bash 19 | cd ~/RAM 20 | 21 | pip install -e . 22 | ``` 23 | 24 | 3. Pull submodules: 25 | 26 | ```bash 27 | git submodule update --init --recursive 28 | ``` 29 | 30 | ## Run model training 31 | 32 | 1. Activate your environment: 33 | 34 | ```bash 35 | conda activate 36 | export PYTHONPATH=/RAM/public_repos/lingua 37 | ``` 38 | 39 | 2. Start distributed training: 40 | ```bash 41 | cd projects/mta/ 42 | 43 | python -m lingua.stool script=train config=./configs/300M_mta.yaml nodes=4 qos=lowest 44 | ``` 45 | 46 | 300M-830M configurations assume 4 nodes for training, 1B configuration assums 8 nodes. 47 | 48 | 3. Generate text completion: 49 | ```bash 50 | python -m generate ckpt=/consolidated dump_dir=/tmp max_gen_len=16 51 | ``` 52 | 53 | 54 | ## Contributors 55 | Olga Golovneva, Tianlu Wang, Jason Weston, Sainbayar Sukhbaatar 56 | 57 | ## Citation 58 | If you use our model in your own work, please cite with the following BibTex entry: 59 | ``` 60 | @article{golovneva2025multi, 61 | title={Multi-Token Attention}, 62 | author={Golovneva, Olga and Wang, Tianlu and Weston, Jason and Sukhbaatar, Sainbayar}, 63 | journal={arXiv preprint arXiv:2504.00927}, 64 | year={2025} 65 | } -------------------------------------------------------------------------------- /projects/mta/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | """ 7 | -------------------------------------------------------------------------------- /projects/mta/configs/1B_baseline.yaml: -------------------------------------------------------------------------------- 1 | # use local train.py; 8 nodes 2 | 3 | dump_dir: /checkpoints/1B_baseline 4 | name: &name "1B_baseline" 5 | steps: 200000 6 | grad_acc_steps: 1 7 | probe_freq: 100 8 | 9 | seed: 777 10 | optim: 11 | lr: 0.00015 12 | weight_decay: 0.05 13 | warmup: 187 # 375 14 | scheduler: linear 15 | beta2: 0.98 16 | lr_min_ratio: 0.0 17 | 18 | distributed: 19 | fsdp_type: full_shard 20 | compile: true 21 | model_dtype: bf16 22 | matmul_allow_tf32: false 23 | selective_activation_checkpointing: false 24 | tp_size: 1 25 | 26 | model: 27 | dim: 2048 28 | n_layers: 24 29 | n_heads: 16 30 | rope_theta: 100_000 31 | ffn_dim_multiplier: 1.0 32 | multiple_of: 256 33 | weight_tying: true 34 | mta: 35 | use_mta: false 36 | 37 | data: 38 | root_dir: /slim_pajama/train/ 39 | sources: 40 | commoncrawl: 0.726 41 | c4: 0.081 42 | github: 0.049 43 | book: 0.021 44 | arxiv: 0.023 45 | wikipedia: 0.05 46 | stackexchange: 0.05 47 | # bsx = 2048 * 8 * nnodes * b_s 48 | # 4 nodes, 4 b_s = 0.25M tok per batch 49 | batch_size: 4 50 | prefetch_size: 1024 51 | seq_len: 2048 52 | n_views: 2 53 | load_async: true 54 | tokenizer: 55 | name: tiktoken 56 | path: /Llama-3.1-70B-Instruct/original/tokenizer.model 57 | 58 | profiling: 59 | run: true 60 | mem_warmup: 0 61 | mem_steps: 4 62 | profile_warmup: 10 63 | profile_steps: 4 64 | 65 | checkpoint: 66 | dump: 67 | every: 10000 68 | keep: 1 69 | eval: 70 | every: 10000 71 | keep: 1 72 | 73 | logging: 74 | freq: 10 75 | wandb: # specify monitoring if needed 76 | project: lingua 77 | entity: new_attention 78 | resume: allow 79 | id: *name 80 | 81 | # sync eval 82 | eval: 83 | ppl_files: 84 | - /slim_pajama/valid/arxiv/data.chunk.00.jsonl 85 | - /slim_pajama/valid/book/data.chunk.00.jsonl 86 | - /slim_pajama/valid/c4/data.chunk.00.jsonl 87 | - /slim_pajama/valid/commoncrawl/data.chunk.00.jsonl 88 | - /slim_pajama/valid/github/data.chunk.00.jsonl 89 | - /slim_pajama/valid/stackexchange/data.chunk.00.jsonl 90 | - /slim_pajama/valid/wikipedia/data.chunk.00.jsonl 91 | ppl_seq_len: 2048 92 | ppl_batch_size: 4 93 | ppl_n_batches: 256 94 | generator: 95 | max_tokens: 2048 96 | dtype: bf16 97 | -------------------------------------------------------------------------------- /projects/mta/configs/1B_mta.yaml: -------------------------------------------------------------------------------- 1 | # use local train.py 8 nodes 2 | 3 | dump_dir: /checkpoints/1B_mta 4 | name: &name "1B_mta" 5 | steps: 200000 6 | grad_acc_steps: 1 7 | probe_freq: 100 8 | 9 | seed: 777 10 | optim: 11 | lr: 0.00015 12 | weight_decay: 0.05 13 | warmup: 187 14 | scheduler: linear 15 | beta2: 0.98 16 | lr_min_ratio: 0.0 17 | 18 | distributed: 19 | fsdp_type: full_shard 20 | compile: true 21 | model_dtype: bf16 22 | matmul_allow_tf32: false 23 | selective_activation_checkpointing: false 24 | tp_size: 1 25 | 26 | model: 27 | dim: 2048 28 | n_layers: 24 29 | n_heads: 16 30 | rope_theta: 100_000 31 | ffn_dim_multiplier: 1.0 32 | multiple_of: 256 33 | weight_tying: true 34 | mta: 35 | use_mta: true 36 | # before sm 37 | pre_sm_linear_head: true 38 | query_kernel_size: 6 39 | key_kernel_size: 11 40 | # after sm 41 | head_kernel_size: 16 42 | after_sm_query_kernel_size: 6 43 | after_sm_key_kernel_size: 11 44 | # common 45 | init_method: identity 46 | pad_key: "both" 47 | mta_layers: "2,6,10,14,18,22" 48 | group_norm: 
true 49 | layer_norm_rescale: false 50 | add_gating: true 51 | gate_1d: true 52 | 53 | data: 54 | root_dir: /slim_pajama/train/ 55 | sources: 56 | commoncrawl: 0.726 57 | c4: 0.081 58 | github: 0.049 59 | book: 0.021 60 | arxiv: 0.023 61 | wikipedia: 0.05 62 | stackexchange: 0.05 63 | # bsx = 2048 * 8 * nnodes * b_s 64 | # 4 nodes, 4 b_s = 0.25M tok per batch 65 | batch_size: 4 66 | prefetch_size: 1024 67 | seq_len: 2048 68 | n_views: 2 69 | load_async: true 70 | tokenizer: 71 | name: tiktoken 72 | path: /Llama-3.1-70B-Instruct/original/tokenizer.model 73 | 74 | profiling: 75 | run: true 76 | mem_warmup: 0 77 | mem_steps: 4 78 | profile_warmup: 10 79 | profile_steps: 4 80 | 81 | checkpoint: 82 | dump: 83 | every: 40000 84 | keep: 1 85 | eval: 86 | every: 40000 87 | keep: 1 88 | 89 | logging: 90 | freq: 10 91 | wandb: # specify monitoring if needed 92 | project: lingua 93 | entity: new_attention 94 | resume: allow 95 | id: *name 96 | 97 | # sync eval 98 | eval: 99 | ppl_files: 100 | - /slim_pajama/valid/arxiv/data.chunk.00.jsonl 101 | - /slim_pajama/valid/book/data.chunk.00.jsonl 102 | - /slim_pajama/valid/c4/data.chunk.00.jsonl 103 | - /slim_pajama/valid/commoncrawl/data.chunk.00.jsonl 104 | - /slim_pajama/valid/github/data.chunk.00.jsonl 105 | - /slim_pajama/valid/stackexchange/data.chunk.00.jsonl 106 | - /slim_pajama/valid/wikipedia/data.chunk.00.jsonl 107 | ppl_seq_len: 2048 108 | ppl_batch_size: 4 109 | ppl_n_batches: 256 110 | generator: 111 | max_tokens: 2048 112 | dtype: bf16 113 | -------------------------------------------------------------------------------- /projects/mta/configs/1B_talking_heads.yaml: -------------------------------------------------------------------------------- 1 | # use local train.py 8 nodes 2 | 3 | dump_dir: /checkpoints/1B_talking_heads 4 | name: &name "1B_talking_heads" 5 | steps: 200000 6 | grad_acc_steps: 1 7 | probe_freq: 100 8 | 9 | seed: 777 10 | optim: 11 | lr: 0.00015 12 | weight_decay: 0.05 13 | warmup: 187 14 | scheduler: linear 15 | beta2: 0.98 16 | lr_min_ratio: 0.0 17 | 18 | distributed: 19 | fsdp_type: full_shard 20 | compile: true 21 | model_dtype: bf16 22 | matmul_allow_tf32: false 23 | selective_activation_checkpointing: false 24 | tp_size: 1 25 | 26 | model: 27 | dim: 2048 28 | n_layers: 24 29 | n_heads: 16 30 | rope_theta: 100_000 31 | ffn_dim_multiplier: 1.0 32 | multiple_of: 256 33 | weight_tying: true 34 | mta: 35 | use_mta: true 36 | query_kernel_size: null 37 | head_kernel_size: 16 38 | init_method: identity 39 | pre_sm_linear_head: true 40 | 41 | data: 42 | root_dir: /slim_pajama/train/ 43 | sources: 44 | commoncrawl: 0.726 45 | c4: 0.081 46 | github: 0.049 47 | book: 0.021 48 | arxiv: 0.023 49 | wikipedia: 0.05 50 | stackexchange: 0.05 51 | # bsx = 2048 * 8 * nnodes * b_s 52 | # 4 nodes, 4 b_s = 0.25M tok per batch 53 | batch_size: 4 54 | prefetch_size: 1024 55 | seq_len: 2048 56 | n_views: 2 57 | load_async: true 58 | tokenizer: 59 | name: tiktoken 60 | path: /Llama-3.1-70B-Instruct/original/tokenizer.model 61 | 62 | profiling: 63 | run: true 64 | mem_warmup: 0 65 | mem_steps: 4 66 | profile_warmup: 10 67 | profile_steps: 4 68 | 69 | checkpoint: 70 | dump: 71 | every: 40000 72 | keep: 1 73 | eval: 74 | every: 40000 75 | keep: 1 76 | 77 | logging: 78 | freq: 10 79 | wandb: # specify monitoring if needed 80 | project: lingua 81 | entity: new_attention 82 | resume: allow 83 | id: *name 84 | 85 | # sync eval 86 | eval: 87 | ppl_files: 88 | - /slim_pajama/valid/arxiv/data.chunk.00.jsonl 89 | - 
/slim_pajama/valid/book/data.chunk.00.jsonl 90 | - /slim_pajama/valid/c4/data.chunk.00.jsonl 91 | - /slim_pajama/valid/commoncrawl/data.chunk.00.jsonl 92 | - /slim_pajama/valid/github/data.chunk.00.jsonl 93 | - /slim_pajama/valid/stackexchange/data.chunk.00.jsonl 94 | - /slim_pajama/valid/wikipedia/data.chunk.00.jsonl 95 | ppl_seq_len: 2048 96 | ppl_batch_size: 4 97 | ppl_n_batches: 256 98 | generator: 99 | max_tokens: 2048 100 | dtype: bf16 101 | -------------------------------------------------------------------------------- /projects/mta/configs/300M_baseline.yaml: -------------------------------------------------------------------------------- 1 | # use local train.py 2 | 3 | dump_dir: /checkpoints/300M_baseline 4 | name: &name "300M_baseline" 5 | steps: 400000 6 | grad_acc_steps: 1 7 | probe_freq: 100 8 | 9 | seed: 777 10 | optim: 11 | lr: 0.00015 12 | weight_decay: 0.05 13 | warmup: 375 14 | scheduler: linear 15 | beta2: 0.98 16 | lr_min_ratio: 0.0 17 | 18 | distributed: 19 | fsdp_type: full_shard 20 | compile: true 21 | model_dtype: bf16 22 | matmul_allow_tf32: false 23 | selective_activation_checkpointing: false 24 | tp_size: 1 25 | 26 | model: 27 | dim: 1024 28 | n_layers: 20 29 | n_heads: 16 30 | rope_theta: 100_000 31 | ffn_dim_multiplier: 1.0 32 | multiple_of: 256 33 | weight_tying: true 34 | mta: 35 | use_mta: false 36 | 37 | data: 38 | root_dir: /slim_pajama/train/ 39 | sources: 40 | commoncrawl: 0.726 41 | c4: 0.081 42 | github: 0.049 43 | book: 0.021 44 | arxiv: 0.023 45 | wikipedia: 0.05 46 | stackexchange: 0.05 47 | # bsx = 2048 * 8 * nnodes * b_s 48 | # 4 nodes, 4 b_s = 0.25M tok per batch 49 | batch_size: 4 50 | prefetch_size: 1024 51 | seq_len: 2048 52 | n_views: 2 53 | load_async: true 54 | tokenizer: 55 | name: tiktoken 56 | path: /Llama-3.1-70B-Instruct/original/tokenizer.model 57 | 58 | profiling: 59 | run: true 60 | mem_warmup: 0 61 | mem_steps: 4 62 | profile_warmup: 10 63 | profile_steps: 4 64 | 65 | checkpoint: 66 | dump: 67 | every: 40000 68 | keep: 1 69 | eval: 70 | every: 40000 71 | keep: 1 72 | 73 | logging: 74 | freq: 10 75 | wandb: # specify monitoring if needed 76 | project: lingua 77 | entity: new_attention 78 | resume: allow 79 | id: *name 80 | 81 | # sync eval 82 | eval: 83 | ppl_files: 84 | - /slim_pajama/valid/arxiv/data.chunk.00.jsonl 85 | - /slim_pajama/valid/book/data.chunk.00.jsonl 86 | - /slim_pajama/valid/c4/data.chunk.00.jsonl 87 | - /slim_pajama/valid/commoncrawl/data.chunk.00.jsonl 88 | - /slim_pajama/valid/github/data.chunk.00.jsonl 89 | - /slim_pajama/valid/stackexchange/data.chunk.00.jsonl 90 | - /slim_pajama/valid/wikipedia/data.chunk.00.jsonl 91 | ppl_seq_len: 2048 92 | ppl_batch_size: 4 93 | ppl_n_batches: 256 94 | generator: 95 | max_tokens: 2048 96 | dtype: bf16 97 | -------------------------------------------------------------------------------- /projects/mta/configs/300M_mta.yaml: -------------------------------------------------------------------------------- 1 | # use local train.py 2 | 3 | dump_dir: /checkpoints/300M_mta_4 4 | name: &name "300M_mta_4_h1002" 5 | steps: 400000 6 | grad_acc_steps: 1 7 | probe_freq: 100 8 | 9 | seed: 777 10 | optim: 11 | lr: 0.00015 12 | weight_decay: 0.05 13 | warmup: 375 14 | scheduler: linear 15 | beta2: 0.98 16 | lr_min_ratio: 0.0 17 | 18 | distributed: 19 | fsdp_type: full_shard 20 | compile: true 21 | model_dtype: bf16 22 | matmul_allow_tf32: false 23 | selective_activation_checkpointing: false 24 | tp_size: 1 25 | 26 | model: 27 | dim: 1024 28 | n_layers: 20 29 | n_heads: 16 30 | rope_theta: 
100_000 31 | ffn_dim_multiplier: 1.0 32 | multiple_of: 256 33 | weight_tying: true 34 | mta: 35 | use_mta: true 36 | # before sm 37 | pre_sm_linear_head: true 38 | query_kernel_size: 6 39 | key_kernel_size: 11 40 | # after sm 41 | head_kernel_size: 16 42 | after_sm_query_kernel_size: 6 43 | after_sm_key_kernel_size: 11 44 | # common 45 | init_method: identity 46 | pad_key: "both" 47 | mta_layers: "2,6,10,14,18" 48 | group_norm: true 49 | layer_norm_rescale: false 50 | add_gating: true 51 | gate_1d: true 52 | 53 | data: 54 | root_dir: /slim_pajama/train/ 55 | sources: 56 | commoncrawl: 0.726 57 | c4: 0.081 58 | github: 0.049 59 | book: 0.021 60 | arxiv: 0.023 61 | wikipedia: 0.05 62 | stackexchange: 0.05 63 | # bsx = 2048 * 8 * nnodes * b_s 64 | # 4 nodes, 4 b_s = 0.25M tok per batch 65 | batch_size: 4 66 | prefetch_size: 1024 67 | seq_len: 2048 68 | n_views: 2 69 | load_async: true 70 | tokenizer: 71 | name: tiktoken 72 | path: /Llama-3.1-70B-Instruct/original/tokenizer.model 73 | 74 | profiling: 75 | run: true 76 | mem_warmup: 0 77 | mem_steps: 4 78 | profile_warmup: 10 79 | profile_steps: 4 80 | 81 | checkpoint: 82 | dump: 83 | every: 40000 84 | keep: 1 85 | eval: 86 | every: 40000 87 | keep: 1 88 | 89 | logging: 90 | freq: 10 91 | wandb: # specify monitoring if needed 92 | project: lingua 93 | entity: new_attention 94 | resume: allow 95 | id: *name 96 | 97 | # sync eval 98 | eval: 99 | ppl_files: 100 | - /slim_pajama/valid/arxiv/data.chunk.00.jsonl 101 | - /slim_pajama/valid/book/data.chunk.00.jsonl 102 | - /slim_pajama/valid/c4/data.chunk.00.jsonl 103 | - /slim_pajama/valid/commoncrawl/data.chunk.00.jsonl 104 | - /slim_pajama/valid/github/data.chunk.00.jsonl 105 | - /slim_pajama/valid/stackexchange/data.chunk.00.jsonl 106 | - /slim_pajama/valid/wikipedia/data.chunk.00.jsonl 107 | ppl_seq_len: 2048 108 | ppl_batch_size: 4 109 | ppl_n_batches: 256 110 | generator: 111 | max_tokens: 2048 112 | dtype: bf16 113 | -------------------------------------------------------------------------------- /projects/mta/configs/300M_talking_heads.yaml: -------------------------------------------------------------------------------- 1 | # use local train.py 2 | 3 | dump_dir: /checkpoints/300M_talking_heads 4 | name: &name "300M_talking_heads" 5 | steps: 400000 6 | grad_acc_steps: 1 7 | probe_freq: 100 8 | 9 | seed: 777 10 | optim: 11 | lr: 0.00015 12 | weight_decay: 0.05 13 | warmup: 375 14 | scheduler: linear 15 | beta2: 0.98 16 | lr_min_ratio: 0.0 17 | 18 | distributed: 19 | fsdp_type: full_shard 20 | compile: true 21 | model_dtype: bf16 22 | matmul_allow_tf32: false 23 | selective_activation_checkpointing: false 24 | tp_size: 1 25 | 26 | model: 27 | dim: 1024 28 | n_layers: 20 29 | n_heads: 16 30 | rope_theta: 100_000 31 | ffn_dim_multiplier: 1.0 32 | multiple_of: 256 33 | weight_tying: true 34 | mta: 35 | use_mta: true 36 | query_kernel_size: null 37 | head_kernel_size: 16 38 | init_method: identity 39 | pre_sm_linear_head: true 40 | 41 | data: 42 | root_dir: /slim_pajama/train/ 43 | sources: 44 | commoncrawl: 0.726 45 | c4: 0.081 46 | github: 0.049 47 | book: 0.021 48 | arxiv: 0.023 49 | wikipedia: 0.05 50 | stackexchange: 0.05 51 | # bsx = 2048 * 8 * nnodes * b_s 52 | # 4 nodes, 4 b_s = 0.25M tok per batch 53 | batch_size: 4 54 | prefetch_size: 1024 55 | seq_len: 2048 56 | n_views: 2 57 | load_async: true 58 | tokenizer: 59 | name: tiktoken 60 | path: /Llama-3.1-70B-Instruct/original/tokenizer.model 61 | 62 | profiling: 63 | run: true 64 | mem_warmup: 0 65 | mem_steps: 4 66 | profile_warmup: 10 67 | 
profile_steps: 4 68 | 69 | checkpoint: 70 | dump: 71 | every: 40000 72 | keep: 1 73 | eval: 74 | every: 40000 75 | keep: 1 76 | 77 | logging: 78 | freq: 10 79 | wandb: # specify monitoring if needed 80 | project: lingua 81 | entity: new_attention 82 | resume: allow 83 | id: *name 84 | 85 | # sync eval 86 | eval: 87 | ppl_files: 88 | - /slim_pajama/valid/arxiv/data.chunk.00.jsonl 89 | - /slim_pajama/valid/book/data.chunk.00.jsonl 90 | - /slim_pajama/valid/c4/data.chunk.00.jsonl 91 | - /slim_pajama/valid/commoncrawl/data.chunk.00.jsonl 92 | - /slim_pajama/valid/github/data.chunk.00.jsonl 93 | - /slim_pajama/valid/stackexchange/data.chunk.00.jsonl 94 | - /slim_pajama/valid/wikipedia/data.chunk.00.jsonl 95 | ppl_seq_len: 2048 96 | ppl_batch_size: 4 97 | ppl_n_batches: 256 98 | generator: 99 | max_tokens: 2048 100 | dtype: bf16 101 | -------------------------------------------------------------------------------- /projects/mta/configs/550M_baseline.yaml: -------------------------------------------------------------------------------- 1 | # use local train.py 2 | 3 | dump_dir: /checkpoints/550M_baseline 4 | name: &name "550M_baseline_h100" 5 | steps: 400000 6 | grad_acc_steps: 1 7 | probe_freq: 100 8 | 9 | seed: 777 10 | optim: 11 | lr: 0.00015 12 | weight_decay: 0.05 13 | warmup: 375 14 | scheduler: linear 15 | beta2: 0.98 16 | lr_min_ratio: 0.0 17 | 18 | distributed: 19 | fsdp_type: full_shard 20 | compile: true 21 | model_dtype: bf16 22 | matmul_allow_tf32: false 23 | selective_activation_checkpointing: false 24 | tp_size: 1 25 | 26 | model: 27 | dim: 1280 28 | n_layers: 24 29 | n_heads: 10 30 | rope_theta: 100_000 31 | ffn_dim_multiplier: 1.0 32 | multiple_of: 256 33 | weight_tying: true 34 | mta: 35 | use_mta: false 36 | 37 | data: 38 | root_dir: /slim_pajama/train/ 39 | sources: 40 | commoncrawl: 0.726 41 | c4: 0.081 42 | github: 0.049 43 | book: 0.021 44 | arxiv: 0.023 45 | wikipedia: 0.05 46 | stackexchange: 0.05 47 | # bsx = 2048 * 8 * nnodes * b_s 48 | # 4 nodes, 4 b_s = 0.25M tok per batch 49 | batch_size: 4 50 | prefetch_size: 1024 51 | seq_len: 2048 52 | n_views: 2 53 | load_async: true 54 | tokenizer: 55 | name: tiktoken 56 | path: /Llama-3.1-70B-Instruct/original/tokenizer.model 57 | 58 | profiling: 59 | run: true 60 | mem_warmup: 0 61 | mem_steps: 4 62 | profile_warmup: 10 63 | profile_steps: 4 64 | 65 | checkpoint: 66 | dump: 67 | every: 40000 68 | keep: 1 69 | eval: 70 | every: 40000 71 | keep: 1 72 | 73 | logging: 74 | freq: 10 75 | wandb: # specify monitoring if needed 76 | project: lingua 77 | entity: new_attention 78 | resume: allow 79 | id: *name 80 | 81 | # sync eval 82 | eval: 83 | ppl_files: 84 | - /slim_pajama/valid/arxiv/data.chunk.00.jsonl 85 | - /slim_pajama/valid/book/data.chunk.00.jsonl 86 | - /slim_pajama/valid/c4/data.chunk.00.jsonl 87 | - /slim_pajama/valid/commoncrawl/data.chunk.00.jsonl 88 | - /slim_pajama/valid/github/data.chunk.00.jsonl 89 | - /slim_pajama/valid/stackexchange/data.chunk.00.jsonl 90 | - /slim_pajama/valid/wikipedia/data.chunk.00.jsonl 91 | ppl_seq_len: 2048 92 | ppl_batch_size: 4 93 | ppl_n_batches: 256 94 | generator: 95 | max_tokens: 2048 96 | dtype: bf16 97 | -------------------------------------------------------------------------------- /projects/mta/configs/550M_mta.yaml: -------------------------------------------------------------------------------- 1 | # use local train.py 2 | 3 | dump_dir: /checkpoints/550M_mta 4 | name: &name "550M_mta" 5 | steps: 400000 6 | grad_acc_steps: 1 7 | probe_freq: 100 8 | 9 | seed: 777 10 | optim: 11 | 
lr: 0.00015 12 | weight_decay: 0.05 13 | warmup: 375 14 | scheduler: linear 15 | beta2: 0.98 16 | lr_min_ratio: 0.0 17 | 18 | distributed: 19 | fsdp_type: full_shard 20 | compile: true 21 | model_dtype: bf16 22 | matmul_allow_tf32: false 23 | selective_activation_checkpointing: false 24 | tp_size: 1 25 | 26 | model: 27 | dim: 1280 28 | n_layers: 24 29 | n_heads: 10 30 | rope_theta: 100_000 31 | ffn_dim_multiplier: 1.0 32 | multiple_of: 256 33 | weight_tying: true 34 | mta: 35 | use_mta: true 36 | # before sm 37 | pre_sm_linear_head: true 38 | query_kernel_size: 6 39 | key_kernel_size: 11 40 | # after sm 41 | head_kernel_size: 10 42 | after_sm_query_kernel_size: 6 43 | after_sm_key_kernel_size: 11 44 | # common 45 | init_method: identity 46 | pad_key: "both" 47 | mta_layers: "2,6,10,14,18,22" 48 | group_norm: true 49 | layer_norm_rescale: false 50 | add_gating: true 51 | gate_1d: true 52 | 53 | data: 54 | root_dir: /slim_pajama/train/ 55 | sources: 56 | commoncrawl: 0.726 57 | c4: 0.081 58 | github: 0.049 59 | book: 0.021 60 | arxiv: 0.023 61 | wikipedia: 0.05 62 | stackexchange: 0.05 63 | # bsx = 2048 * 8 * nnodes * b_s 64 | # 4 nodes, 4 b_s = 0.25M tok per batch 65 | batch_size: 4 66 | prefetch_size: 1024 67 | seq_len: 2048 68 | n_views: 2 69 | load_async: true 70 | tokenizer: 71 | name: tiktoken 72 | path: /Llama-3.1-70B-Instruct/original/tokenizer.model 73 | 74 | profiling: 75 | run: true 76 | mem_warmup: 0 77 | mem_steps: 4 78 | profile_warmup: 10 79 | profile_steps: 4 80 | 81 | checkpoint: 82 | dump: 83 | every: 40000 84 | keep: 1 85 | eval: 86 | every: 40000 87 | keep: 1 88 | 89 | logging: 90 | freq: 10 91 | wandb: # specify monitoring if needed 92 | project: lingua 93 | entity: new_attention 94 | resume: allow 95 | id: *name 96 | 97 | # sync eval 98 | eval: 99 | ppl_files: 100 | - /slim_pajama/valid/arxiv/data.chunk.00.jsonl 101 | - /slim_pajama/valid/book/data.chunk.00.jsonl 102 | - /slim_pajama/valid/c4/data.chunk.00.jsonl 103 | - /slim_pajama/valid/commoncrawl/data.chunk.00.jsonl 104 | - /slim_pajama/valid/github/data.chunk.00.jsonl 105 | - /slim_pajama/valid/stackexchange/data.chunk.00.jsonl 106 | - /slim_pajama/valid/wikipedia/data.chunk.00.jsonl 107 | ppl_seq_len: 2048 108 | ppl_batch_size: 4 109 | ppl_n_batches: 256 110 | generator: 111 | max_tokens: 2048 112 | dtype: bf16 113 | -------------------------------------------------------------------------------- /projects/mta/configs/550M_talking_heads.yaml: -------------------------------------------------------------------------------- 1 | # use local train.py 2 | 3 | dump_dir: /checkpoints/550M_talking_heads 4 | name: &name "550M_talking_heads" 5 | steps: 400000 6 | grad_acc_steps: 1 7 | probe_freq: 100 8 | 9 | seed: 777 10 | optim: 11 | lr: 0.00015 12 | weight_decay: 0.05 13 | warmup: 375 14 | scheduler: linear 15 | beta2: 0.98 16 | lr_min_ratio: 0.0 17 | 18 | distributed: 19 | fsdp_type: full_shard 20 | compile: true 21 | model_dtype: bf16 22 | matmul_allow_tf32: false 23 | selective_activation_checkpointing: false 24 | tp_size: 1 25 | 26 | model: 27 | dim: 1280 28 | n_layers: 24 29 | n_heads: 10 30 | rope_theta: 100_000 31 | ffn_dim_multiplier: 1.0 32 | multiple_of: 256 33 | weight_tying: true 34 | mta: 35 | use_mta: true 36 | query_kernel_size: null 37 | head_kernel_size: 10 38 | init_method: identity 39 | pre_sm_linear_head: true 40 | 41 | data: 42 | root_dir: /slim_pajama/train/ 43 | sources: 44 | commoncrawl: 0.726 45 | c4: 0.081 46 | github: 0.049 47 | book: 0.021 48 | arxiv: 0.023 49 | wikipedia: 0.05 50 | 
stackexchange: 0.05 51 | # bsx = 2048 * 8 * nnodes * b_s 52 | # 4 nodes, 4 b_s = 0.25M tok per batch 53 | batch_size: 4 54 | prefetch_size: 1024 55 | seq_len: 2048 56 | n_views: 2 57 | load_async: true 58 | tokenizer: 59 | name: tiktoken 60 | path: /Llama-3.1-70B-Instruct/original/tokenizer.model 61 | 62 | profiling: 63 | run: true 64 | mem_warmup: 0 65 | mem_steps: 4 66 | profile_warmup: 10 67 | profile_steps: 4 68 | 69 | checkpoint: 70 | dump: 71 | every: 40000 72 | keep: 1 73 | eval: 74 | every: 40000 75 | keep: 1 76 | 77 | logging: 78 | freq: 10 79 | wandb: # specify monitoring if needed 80 | project: lingua 81 | entity: new_attention 82 | resume: allow 83 | id: *name 84 | 85 | # sync eval 86 | eval: 87 | ppl_files: 88 | - /slim_pajama/valid/arxiv/data.chunk.00.jsonl 89 | - /slim_pajama/valid/book/data.chunk.00.jsonl 90 | - /slim_pajama/valid/c4/data.chunk.00.jsonl 91 | - /slim_pajama/valid/commoncrawl/data.chunk.00.jsonl 92 | - /slim_pajama/valid/github/data.chunk.00.jsonl 93 | - /slim_pajama/valid/stackexchange/data.chunk.00.jsonl 94 | - /slim_pajama/valid/wikipedia/data.chunk.00.jsonl 95 | ppl_seq_len: 2048 96 | ppl_batch_size: 4 97 | ppl_n_batches: 256 98 | generator: 99 | max_tokens: 2048 100 | dtype: bf16 101 | -------------------------------------------------------------------------------- /projects/mta/configs/830M_baseline.yaml: -------------------------------------------------------------------------------- 1 | # use local train.py 2 | 3 | dump_dir: /checkpoints/830M_baseline 4 | name: &name "830M_baseline_h1001" 5 | steps: 400000 6 | grad_acc_steps: 1 7 | probe_freq: 100 8 | 9 | seed: 777 10 | optim: 11 | lr: 0.00015 12 | weight_decay: 0.05 13 | warmup: 375 14 | scheduler: linear 15 | beta2: 0.98 16 | lr_min_ratio: 0.0 17 | 18 | distributed: 19 | fsdp_type: full_shard 20 | compile: true 21 | model_dtype: bf16 22 | matmul_allow_tf32: false 23 | selective_activation_checkpointing: false 24 | tp_size: 1 25 | 26 | model: 27 | dim: 1536 28 | n_layers: 24 29 | n_heads: 16 30 | rope_theta: 100_000 31 | ffn_dim_multiplier: 1.0 32 | multiple_of: 256 33 | weight_tying: true 34 | mta: 35 | use_mta: false 36 | 37 | data: 38 | root_dir: /slim_pajama/train/ 39 | sources: 40 | commoncrawl: 0.726 41 | c4: 0.081 42 | github: 0.049 43 | book: 0.021 44 | arxiv: 0.023 45 | wikipedia: 0.05 46 | stackexchange: 0.05 47 | # bsx = 2048 * 8 * nnodes * b_s 48 | # 4 nodes, 4 b_s = 0.25M tok per batch 49 | batch_size: 4 50 | prefetch_size: 1024 51 | seq_len: 2048 52 | n_views: 2 53 | load_async: true 54 | tokenizer: 55 | name: tiktoken 56 | path: /Llama-3.1-70B-Instruct/original/tokenizer.model 57 | 58 | profiling: 59 | run: true 60 | mem_warmup: 0 61 | mem_steps: 4 62 | profile_warmup: 10 63 | profile_steps: 4 64 | 65 | checkpoint: 66 | dump: 67 | every: 10000 68 | keep: -1 69 | eval: 70 | every: 10000 71 | keep: 1 72 | 73 | logging: 74 | freq: 10 75 | wandb: # specify monitoring if needed 76 | project: lingua 77 | entity: new_attention 78 | resume: allow 79 | id: *name 80 | 81 | # sync eval 82 | eval: 83 | ppl_files: 84 | - /slim_pajama/valid/arxiv/data.chunk.00.jsonl 85 | - /slim_pajama/valid/book/data.chunk.00.jsonl 86 | - /slim_pajama/valid/c4/data.chunk.00.jsonl 87 | - /slim_pajama/valid/commoncrawl/data.chunk.00.jsonl 88 | - /slim_pajama/valid/github/data.chunk.00.jsonl 89 | - /slim_pajama/valid/stackexchange/data.chunk.00.jsonl 90 | - /slim_pajama/valid/wikipedia/data.chunk.00.jsonl 91 | ppl_seq_len: 2048 92 | ppl_batch_size: 4 93 | ppl_n_batches: 256 94 | generator: 95 | max_tokens: 2048 96 | 
dtype: bf16 97 | -------------------------------------------------------------------------------- /projects/mta/configs/830M_mta.yaml: -------------------------------------------------------------------------------- 1 | # use local train.py 2 | 3 | dump_dir: /checkpoints/830M_mta_gate1d_seed42 4 | name: &name "830M_mta_gate1d_seed42" 5 | steps: 400000 6 | grad_acc_steps: 1 7 | probe_freq: 100 8 | 9 | seed: 42 #777 10 | optim: 11 | lr: 0.00015 12 | weight_decay: 0.05 13 | warmup: 375 14 | scheduler: linear 15 | beta2: 0.98 16 | lr_min_ratio: 0.0 17 | 18 | distributed: 19 | fsdp_type: full_shard 20 | compile: true 21 | model_dtype: bf16 22 | matmul_allow_tf32: false 23 | selective_activation_checkpointing: false 24 | tp_size: 1 25 | 26 | model: 27 | dim: 1536 28 | n_layers: 24 29 | n_heads: 16 30 | rope_theta: 100_000 31 | ffn_dim_multiplier: 1.0 32 | multiple_of: 256 33 | weight_tying: true 34 | mta: 35 | use_mta: true 36 | # before sm 37 | pre_sm_linear_head: true 38 | query_kernel_size: 6 39 | key_kernel_size: 11 40 | # after sm 41 | head_kernel_size: 16 42 | after_sm_query_kernel_size: 6 43 | after_sm_key_kernel_size: 11 44 | # common 45 | init_method: identity 46 | pad_key: "both" 47 | mta_layers: "2,6,10,14,18,22" 48 | group_norm: true 49 | layer_norm_rescale: false 50 | add_gating: true 51 | gate_1d: true 52 | 53 | data: 54 | root_dir: /slim_pajama/train/ 55 | sources: 56 | commoncrawl: 0.726 57 | c4: 0.081 58 | github: 0.049 59 | book: 0.021 60 | arxiv: 0.023 61 | wikipedia: 0.05 62 | stackexchange: 0.05 63 | # bsx = 2048 * 8 * nnodes * b_s 64 | # 4 nodes, 4 b_s = 0.25M tok per batch 65 | batch_size: 4 66 | prefetch_size: 1024 67 | seq_len: 2048 68 | n_views: 2 69 | load_async: true 70 | tokenizer: 71 | name: tiktoken 72 | path: /Llama-3.1-70B-Instruct/original/tokenizer.model 73 | 74 | profiling: 75 | run: true 76 | mem_warmup: 0 77 | mem_steps: 4 78 | profile_warmup: 10 79 | profile_steps: 4 80 | 81 | checkpoint: 82 | dump: 83 | every: 10000 84 | keep: 1 85 | eval: 86 | every: 10000 87 | keep: 1 88 | 89 | logging: 90 | freq: 10 91 | wandb: # specify monitoring if needed 92 | project: lingua 93 | entity: new_attention 94 | resume: allow 95 | id: *name 96 | 97 | # sync eval 98 | eval: 99 | ppl_files: 100 | - /slim_pajama/valid/arxiv/data.chunk.00.jsonl 101 | - /slim_pajama/valid/book/data.chunk.00.jsonl 102 | - /slim_pajama/valid/c4/data.chunk.00.jsonl 103 | - /slim_pajama/valid/commoncrawl/data.chunk.00.jsonl 104 | - /slim_pajama/valid/github/data.chunk.00.jsonl 105 | - /slim_pajama/valid/stackexchange/data.chunk.00.jsonl 106 | - /slim_pajama/valid/wikipedia/data.chunk.00.jsonl 107 | ppl_seq_len: 2048 108 | ppl_batch_size: 4 109 | ppl_n_batches: 256 110 | generator: 111 | max_tokens: 2048 112 | dtype: bf16 113 | -------------------------------------------------------------------------------- /projects/mta/configs/830M_talking_heads.yaml: -------------------------------------------------------------------------------- 1 | # use local train.py 2 | 3 | dump_dir: /checkpoints/830M_talking_heads_h1001 4 | name: &name "830M_talking_heads_h1001" 5 | steps: 400000 6 | grad_acc_steps: 1 7 | probe_freq: 100 8 | 9 | seed: 777 10 | optim: 11 | lr: 0.00015 12 | weight_decay: 0.05 13 | warmup: 375 14 | scheduler: linear 15 | beta2: 0.98 16 | lr_min_ratio: 0.0 17 | 18 | distributed: 19 | fsdp_type: full_shard 20 | compile: true 21 | model_dtype: bf16 22 | matmul_allow_tf32: false 23 | selective_activation_checkpointing: false 24 | tp_size: 1 25 | 26 | model: 27 | dim: 1536 28 | n_layers: 24 29 | 
n_heads: 16 30 | rope_theta: 100_000 31 | ffn_dim_multiplier: 1.0 32 | multiple_of: 256 33 | weight_tying: true 34 | mta: 35 | use_mta: true 36 | query_kernel_size: null 37 | head_kernel_size: 16 38 | init_method: identity 39 | pre_sm_linear_head: true 40 | 41 | data: 42 | root_dir: /slim_pajama/train/ 43 | sources: 44 | commoncrawl: 0.726 45 | c4: 0.081 46 | github: 0.049 47 | book: 0.021 48 | arxiv: 0.023 49 | wikipedia: 0.05 50 | stackexchange: 0.05 51 | # bsx = 2048 * 8 * nnodes * b_s 52 | # 4 nodes, 4 b_s = 0.25M tok per batch 53 | batch_size: 4 54 | prefetch_size: 1024 55 | seq_len: 2048 56 | n_views: 2 57 | load_async: true 58 | tokenizer: 59 | name: tiktoken 60 | path: /Llama-3.1-70B-Instruct/original/tokenizer.model 61 | 62 | profiling: 63 | run: true 64 | mem_warmup: 0 65 | mem_steps: 4 66 | profile_warmup: 10 67 | profile_steps: 4 68 | 69 | checkpoint: 70 | dump: 71 | every: 10000 72 | keep: 1 73 | eval: 74 | every: 10000 75 | keep: 1 76 | 77 | logging: 78 | freq: 10 79 | wandb: # specify monitoring if needed 80 | project: lingua 81 | entity: new_attention 82 | resume: allow 83 | id: *name 84 | 85 | # sync eval 86 | eval: 87 | ppl_files: 88 | - /slim_pajama/valid/arxiv/data.chunk.00.jsonl 89 | - /slim_pajama/valid/book/data.chunk.00.jsonl 90 | - /slim_pajama/valid/c4/data.chunk.00.jsonl 91 | - /slim_pajama/valid/commoncrawl/data.chunk.00.jsonl 92 | - /slim_pajama/valid/github/data.chunk.00.jsonl 93 | - /slim_pajama/valid/stackexchange/data.chunk.00.jsonl 94 | - /slim_pajama/valid/wikipedia/data.chunk.00.jsonl 95 | ppl_seq_len: 2048 96 | ppl_batch_size: 4 97 | ppl_n_batches: 256 98 | generator: 99 | max_tokens: 2048 100 | dtype: bf16 101 | -------------------------------------------------------------------------------- /projects/mta/figures/attn_schema.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/RAM/c567dfb73c0bc60c7ca1f114b72cd3bf7d9442bd/projects/mta/figures/attn_schema.png -------------------------------------------------------------------------------- /projects/mta/tokenizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | """ 7 | 8 | from typing import List, Optional, Tuple 9 | 10 | from lingua import tokenizer 11 | 12 | 13 | class CharcterTokenizer(tokenizer.Tokenizer): 14 | def __init__(self): 15 | self.mapping = self.create_alphabet( 16 | vocab_size=26, 17 | block_separator=".", 18 | query_separator="#", 19 | ) 20 | 21 | self.bos_id = len(self.mapping) 22 | self.mapping[""] = self.bos_id 23 | self.eos_id = len(self.mapping) 24 | self.mapping[""] = self.eos_id 25 | i = 0 26 | while len(self.mapping) % 8 != 0: 27 | self.mapping[f"<|reserved_special_token_{i}|>"] = len(self.mapping) 28 | i += 1 29 | self.n_words = len(self.mapping) 30 | 31 | self.reversed_mapping = {v: k for k, v in self.mapping.items()} 32 | 33 | def create_alphabet( 34 | self, vocab_size: int, block_separator: str, query_separator: str 35 | ): 36 | alphabet = {} 37 | i = 0 38 | j = 0 39 | while len(alphabet) < vocab_size: 40 | new_char = chr(ord("a") + i) 41 | if new_char not in [block_separator, query_separator]: 42 | assert new_char not in alphabet.keys() 43 | alphabet[new_char] = j 44 | j += 1 45 | i += 1 46 | alphabet[block_separator] = j 47 | alphabet[query_separator] = j + 1 48 | alphabet["_"] = j + 2 # normally should be an argument 49 | return alphabet 50 | 51 | def encode(self, s: str, add_bos: bool, add_eos: bool): 52 | assert type(s) is str, s 53 | assert " " not in s, s 54 | t = [self.mapping[char] for char in s] 55 | if add_bos: 56 | t.insert(0, self.bos_id) 57 | if add_eos: 58 | t.append(self.eos_id) 59 | return t 60 | 61 | def decode(self, tokens: List[int]): 62 | return "".join([self.reversed_mapping[tok_id] for tok_id in tokens]) 63 | 64 | def get_token_offsets( 65 | self, text: str, tokens: Optional[List[int]] = None 66 | ) -> Tuple[List[str], List[int]]: 67 | if tokens is None: 68 | tokens = self.encode(text, False, False) 69 | 70 | decoded_chars, offsets = [], [] 71 | char_pos = 0 72 | for token in tokens: 73 | if token < self.bos_id: 74 | char = self.reversed_mapping[token] 75 | decoded_chars.append(char) 76 | offsets.append(char_pos) 77 | char_pos += len(char) 78 | 79 | return decoded_chars, offsets 80 | 81 | 82 | def build_tokenizer(name: str, path: Optional[str] = None) -> tokenizer.Tokenizer: 83 | if name == "char": 84 | return CharcterTokenizer() 85 | else: 86 | return tokenizer.build_tokenizer(name, path) 87 | -------------------------------------------------------------------------------- /projects/sd-ra-it/README.md: -------------------------------------------------------------------------------- 1 | # Training LLMs on Self-generated Demonstrations 2 | 3 | Scripts and configs for replicating the experiments from ["Post-training an LLM for RAG? Train on Self-Generated Demonstrations"](https://arxiv.org/abs/2502.10596). 4 | 5 | 6 | 7 | You may cite our work as 8 | ```bibtex 9 | @misc{finlayson2025posttraining, 10 | title={Post-training an LLM for RAG? Train on Self-Generated Demonstrations}, 11 | author={Matthew Finlayson and Ilia Kulikov and Daniel M. Bikel and Barlas Oguz and Xilun Chen and Aasish Pappu}, 12 | year={2025}, 13 | primaryClass={cs.CL}, 14 | } 15 | ``` 16 | 17 | ## Generating self-demos. 18 | 19 | 1. Obtain training data. We use the training data from the [RA-DIT paper](https://arxiv.org/abs/2310.01352), placed in directories `data/70b/train/tasks.jsonl` and `data/70b/train/oasst.jsonl` with subsampling weights of 0.9 and 0.1. 20 | 21 | 2. Generate prompts. 
Use `scripts/prompt_optimization.py`, e.g., 22 | ```sh 23 | python scripts/prompt_optimization.py \ 24 | --dataset_filename "data/70b/tasks.jsonl" \ 25 | --model "Meta-Llama-3-70B-Instruct" \ 26 | --outfile "data/prompts/base.json" \ 27 | --logfile "70B_prompt_optimization.log" \ 28 | --eval_example_count 30 \ 29 | --train_example_count 30 \ 30 | --topk 5 \ 31 | --shuffle_window 400 \ 32 | --beam_size 12 \ 33 | --tensor-parallel-size=2 \ 34 | --chat \ 35 | --steps 5 \ 36 | --rag 37 | ``` 38 | 3. Generate self-demos with `scripts/create_self_demo_train_set.sh` 39 | ```sh 40 | bash scripts/create_self_demo_train_set.sh tasks Meta-Llama-3-70B-Instruct 41 | bash scripts/create_self_demo_train_set.sh oasst Meta-Llama-3-70B-Instruct 42 | ``` 43 | 44 | ## SFT and DPO training with `fairseq2` 45 | 46 | To train a DPO model on self-demonstrations using fairseq2: 47 | 48 | ```sh 49 | srun fairseq2 lm preference_finetune dpo_checkpoints/fairseq2/self_demo \ 50 | --config-file configs/dpo_70b.yml 51 | ``` 52 | 53 | Other configs correspond to SFT and smaller scale (8B) training runs. 54 | 55 | Please refer to documentation on the library setup and examples: https://facebookresearch.github.io/fairseq2/stable/ 56 | 57 | ## Evaluation 58 | 59 | 1. Obtain eval data with retrievals. We use the evals from the [RA-DIT paper](https://arxiv.org/abs/2310.01352), which comes with retrievals and place them in `data/ra-dit/`. 60 | 2. Convert eval files to the correct format. 61 | ```sh 62 | python scripts/data/io_to_qas_format.py \ 63 | data/ra-dit/eli5/eli5-dev-kilt.jsonl \ 64 | data/ra-dit/eli5/dev.jsonl 65 | ``` 66 | 3. Run the evaluation. 67 | ```sh 68 | judge="Meta-Llama-3.1-405B-Instruct-FP8" # Set to judge model path 69 | eval_set=nq # Set to one of `mmlu zsrequestion conllyagotrunc eli5 hotpotqa nq tqa trex fever wow` 70 | strat=dpo_self_demo_70b # Set to training strategy name 71 | hf_checkpoint= # Set to Huggingface model checkpoint path 72 | pred_tpsize=8 # Tensor parallel size 73 | model_size=70b 74 | ndocs=4 75 | samples=1 76 | preds="data/${strat}/eval/preds/${eval_set}.jsonl" 77 | reward_file="data/${strat}/eval/reward/${eval_set}.jsonl" 78 | reward_file_gemma="data/${strat}/eval/reward/${eval_set}_gemma.jsonl" 79 | response_labels="data/${strat}/eval/response_labels/${eval_set}.jsonl" 80 | response_label_reasons="data/${strat}/eval/response_label_reasons/${eval_set}.jsonl" 81 | relevance="data/relevance/${model_size}/${eval_set}.jsonl" 82 | relevance_reasons="data/relevance_reasons/${model_size}/${eval_set}.jsonl" 83 | resultsfile="results/${strat}/eval/metrics/${eval_set}.json" 84 | datafile=data/ra-dit/${eval_set}/dev.jsonl 85 | 86 | # Generate outputs 87 | python scripts/generate.py \ 88 | --model=${hf_checkpoint} \ 89 | --outfile=$preds \ 90 | --samples=$samples \ 91 | --tensor-parallel-size=$pred_tpsize \ 92 | --ndocs=$ndocs \ 93 | --data $datafile 94 | 95 | # Get reward model scores 96 | python scripts/reward_model_gemma.py \ 97 | --outfile=$reward_file_gemma \ 98 | --responses=$preds \ 99 | --ndocs=$ndocs \ 100 | --data $datafile 101 | 102 | # Identify whether context contains the answer 103 | python scripts/relevance.py \ 104 | --datafile $datafile \ 105 | --reasoning_file $relevance_reasons \ 106 | --outfile $relevance \ 107 | --ndocs=$ndocs \ 108 | --tensor-parallel-size=$tpsize \ 109 | --judge=$judge \ 110 | --logfile logs/relevance_${eval_set}.log \ 111 | 112 | # Evaluate (correct/incorrect/refuse) model outputs. 
113 | python scripts/eval.py \ 114 | --preds $preds \ 115 | --datafile $datafile \ 116 | --outfile $response_labels \ 117 | --reasoning_file $response_label_reasons \ 118 | --tensor-parallel-size=$tpsize \ 119 | --judge=$judge \ 120 | --logfile logs/response_labels_${eval_set}_${strat}.log \ 121 | ``` 122 | -------------------------------------------------------------------------------- /projects/sd-ra-it/configs/dpo_70b.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | _set_: 3 | name: llama3_70b_instruct 4 | dataset: 5 | _set_: 6 | name: sdrait 7 | path: data/70b/train/self_demo/dpo 8 | max_seq_len: 4096 9 | batch_size: 1 10 | gang: 11 | _set_: 12 | tensor_parallel_size: 8 13 | trainer: 14 | fsdp: 15 | _set_: 16 | version: v1 17 | granularity: layer 18 | hsdp: false 19 | reshard_after_forward: true 20 | fp32_reduce: true 21 | _set_: 22 | dtype: bfloat16 23 | data_parallelism: fsdp 24 | mixed_precision: static 25 | gradient_accumulation: 4 26 | activation_checkpointing: true 27 | max_gradient_norm: null 28 | fp16_loss_scale: 29 | - 128.0 30 | - 0.0001 31 | torch_compile: false 32 | profile: null 33 | gradient_check: false 34 | anomaly_detection: false 35 | criterion: 36 | config: 37 | reference_model: 38 | _set_: 39 | name: llama3_70b_instruct 40 | _set_: 41 | reference_dtype: bfloat16 42 | beta: 0.1 43 | nll_scale: 0.0 44 | length_normalization: false 45 | _set_: 46 | name: dpo 47 | optimizer: 48 | config: 49 | _set_: 50 | lr: 5.5e-06 51 | lr_scheduler: 52 | config: 53 | _set_: 54 | cycle_len: null 55 | num_warmup_steps: 0 56 | cycle_mul: 1.0 57 | lr_mul: 1.0 58 | start_lr: 0.0 59 | final_lr: 1.1e-06 60 | final_lr_scale: null 61 | _set_: 62 | name: cosine_annealing 63 | regime: 64 | _set_: 65 | num_steps: 800 66 | num_data_epochs: 5 67 | checkpoint_every_n_steps: 1000 68 | checkpoint_after_n_data_epochs: 1 69 | checkpoint_every_n_data_epochs: null 70 | keep_last_n_checkpoints: 1 71 | publish_metrics_every_n_steps: 5 72 | -------------------------------------------------------------------------------- /projects/sd-ra-it/configs/dpo_8b.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | _set_: 3 | name: llama3_8b_instruct 4 | dataset: 5 | _set_: 6 | name: sdrait 7 | path: data/70b/train/self_demo/dpo 8 | max_seq_len: 4096 9 | batch_size: 1 10 | trainer: 11 | fsdp: 12 | _set_: 13 | version: v1 14 | granularity: layer 15 | hsdp: false 16 | reshard_after_forward: true 17 | fp32_reduce: true 18 | _set_: 19 | dtype: bfloat16 20 | data_parallelism: fsdp 21 | mixed_precision: static 22 | gradient_accumulation: 4 23 | activation_checkpointing: true 24 | max_gradient_norm: null 25 | fp16_loss_scale: 26 | - 128.0 27 | - 0.0001 28 | torch_compile: false 29 | profile: null 30 | gradient_check: false 31 | anomaly_detection: false 32 | criterion: 33 | config: 34 | reference_model: 35 | _set_: 36 | name: llama3_8b_instruct 37 | _set_: 38 | reference_dtype: bfloat16 39 | beta: 0.1 40 | nll_scale: 0.0 41 | length_normalization: false 42 | _set_: 43 | name: dpo 44 | optimizer: 45 | config: 46 | _set_: 47 | lr: 5.5e-06 48 | lr_scheduler: 49 | config: 50 | _set_: 51 | cycle_len: null 52 | num_warmup_steps: 0 53 | cycle_mul: 1.0 54 | lr_mul: 1.0 55 | start_lr: 0.0 56 | final_lr: 1.1e-06 57 | final_lr_scale: null 58 | _set_: 59 | name: cosine_annealing 60 | regime: 61 | _set_: 62 | num_steps: 800 63 | num_data_epochs: 5 64 | checkpoint_every_n_steps: 1000 65 | checkpoint_after_n_data_epochs: 1 66 | 
checkpoint_every_n_data_epochs: null 67 | keep_last_n_checkpoints: 1 68 | publish_metrics_every_n_steps: 5 69 | -------------------------------------------------------------------------------- /projects/sd-ra-it/configs/sft_70b.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | _set_: 3 | name: llama3_70b_instruct 4 | dataset: 5 | _set_: 6 | name: sdrait 7 | path: ra-dit/train 8 | gang: 9 | _set_: 10 | tensor_parallel_size: 8 11 | optimizer: 12 | config: 13 | _set_: 14 | lr: 5.5e-06 15 | -------------------------------------------------------------------------------- /projects/sd-ra-it/configs/sft_8b.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | _set_: 3 | name: llama3_8b_instruct 4 | dataset: 5 | _set_: 6 | name: sdrait 7 | path: ra-dit/train 8 | optimizer: 9 | config: 10 | _set_: 11 | lr: 5.5e-06 12 | -------------------------------------------------------------------------------- /projects/sd-ra-it/scripts/create_self_demo_train_set.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | set -e 9 | eval "$(conda shell.bash hook)" 10 | 11 | conda activate pytorch 12 | train_set=$1 13 | model_name=$2 14 | if [ $train_set = "oasst" ] 15 | then 16 | num=20_000 17 | elif [ $train_set = "tasks" ] 18 | then 19 | num=200_000 20 | fi 21 | if [ $model_name = "Meta-Llama-3-70B-Instruct" ] 22 | then 23 | tensor_parallel_size=4 24 | else 25 | tensor_parallel_size=1 26 | fi 27 | output_file="data/70b/train/${train_set}.jsonl" 28 | mkdir -p $(dirname $output_file) 29 | python scripts/get_demos.py \ 30 | --filename ra-dit/multisource/${train_set}.jsonl \ 31 | --output_file $output_file \ 32 | --n $num \ 33 | --prompts-per-strat 3 \ 34 | --tensor_parallel_size $tensor_parallel_size \ 35 | --model_name $model_name \ 36 | --continued \ 37 | --logfile logs/get_demos_${train_set}_70b.log 38 | -------------------------------------------------------------------------------- /projects/sd-ra-it/scripts/data/io_to_qas_format.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import json 9 | import os 10 | import sys 11 | 12 | from tqdm import tqdm 13 | 14 | infilename = sys.argv[1] 15 | outfilename = sys.argv[2] 16 | with open(infilename) as infile, open(outfilename, "w") as outfile: 17 | for line in tqdm(map(json.loads, infile)): 18 | question = line["input"] 19 | answers = [ 20 | output.get("answer") 21 | for output in line["output"] 22 | if output.get("answer") is not None 23 | ] 24 | output_line = line | dict(question=question, answers=answers) 25 | print(json.dumps(output_line), file=outfile) 26 | -------------------------------------------------------------------------------- /projects/sd-ra-it/scripts/eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | """ 7 | 8 | import collections 9 | import functools as ft 10 | import itertools as it 11 | import json 12 | import logging 13 | import os 14 | import pathlib 15 | import re 16 | from typing import Iterable 17 | 18 | import fire # type:ignore 19 | from src.utils import OfflineModel, batched, get_renderer 20 | from tqdm import tqdm 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | def classify( 26 | model: OfflineModel, 27 | line_pred_pairs: Iterable[tuple[dict, str]], 28 | verbose: bool = False, 29 | ) -> list[tuple[str, str | None]]: 30 | sys_msg: str = pathlib.Path("templates/judge/eval/sys.jinja").read_text() 31 | render = get_renderer("templates/judge/eval/usr.jinja") 32 | usr_msgs: list[str] = [ 33 | render(question=line["question"], answers=line["answers"], prediction=pred) 34 | for line, pred in line_pred_pairs 35 | ] 36 | sys_msgs: list[str] = list(it.repeat(sys_msg, len(usr_msgs))) 37 | reasoning, labels = model.parse_zero_shot_with_retries(sys_msgs, usr_msgs, parse) 38 | return list(zip(reasoning, labels)) 39 | 40 | 41 | def parse(reply: str) -> str | None: 42 | pattern: str = ( 43 | r"(\*\*)?(Label|Prediction label):(\*\*)? (?P