├── .circleci └── config.yml ├── .github └── workflows │ └── python-publish.yml ├── .gitignore ├── .vscode ├── extensions.json └── settings.json ├── CITATION.bib ├── LICENSE ├── README.md ├── experiment ├── .gitignore ├── README.md ├── configs │ ├── compressive-former-enwik8.json │ ├── compressive-former-sort.json │ ├── compressive-former.json │ ├── gpt2-enwik8.json │ ├── gpt2-synthetic.json │ ├── gpt2-xl.json │ ├── gpt2.json │ ├── infinity-gpt2-enwik8.json │ ├── infinity-gpt2-sort.json │ ├── infinity-gpt2-synthetic.json │ ├── infinity-gpt2.json │ ├── memoria-enwik8.json │ ├── memoria-gpt2-large.json │ ├── memoria-gpt2-medium.json │ ├── memoria-gpt2-sort.json │ ├── memoria-gpt2-synthetic.json │ ├── memoria-gpt2-xl.json │ ├── memoria-gpt2.json │ ├── transfo-xl-enwik8.json │ ├── transfo-xl-sort.json │ ├── transfo-xl-synthetic.json │ └── transfo-xl.json ├── eval_classification.py ├── eval_language_modeling.py ├── eval_synthetic.py ├── longseq_formers │ ├── __init__.py │ ├── data │ │ ├── __init__.py │ │ ├── enwik8.py │ │ ├── hyperpartisan.py │ │ ├── pg19.py │ │ └── wikitext103.py │ ├── dataset │ │ ├── __init__.py │ │ ├── classification.py │ │ ├── language_modeling.py │ │ └── synthetic.py │ ├── model │ │ ├── __init__.py │ │ ├── compressive_former │ │ │ ├── __init__.py │ │ │ └── modeling_compressive_transformer.py │ │ ├── gpt2_with_memoria │ │ │ ├── __init__.py │ │ │ └── modeling_gpt2_with_memoria.py │ │ ├── infinity_gpt2 │ │ │ ├── __init__.py │ │ │ ├── basis_functions.py │ │ │ ├── configuration_infinity_gpt2.py │ │ │ ├── continuous_softmax.py │ │ │ ├── continuous_sparsemax.py │ │ │ ├── long_term_attention.py │ │ │ └── modeling_infinity_gpt2.py │ │ ├── memoria_bert │ │ │ ├── __init__.py │ │ │ ├── configuration_memoria_bert.py │ │ │ └── modeling_memoria_bert.py │ │ └── memoria_roberta │ │ │ ├── __init__.py │ │ │ ├── configuration_memoria_roberta.py │ │ │ └── modeling_memoria_roberta.py │ ├── task │ │ ├── __init__.py │ │ ├── classification.py │ │ ├── language_modeling.py │ │ └── synthetic.py │ └── utils.py ├── requirements.txt ├── train_classification.py ├── train_language_modeling.py └── train_synthetic.py ├── images └── Memoria-Engrams.gif ├── memoria ├── __init__.py ├── abstractor.py ├── engram.py ├── history_manager.py ├── memoria.py ├── sparse_tensor.py ├── types.py └── utils.py ├── pyproject.toml ├── requirements-dev.txt ├── requirements.txt ├── setup.py └── tests ├── __init__.py ├── test_abstractor.py ├── test_engram.py ├── test_history_manager.py ├── test_memoria.py ├── test_sparse_tensor.py └── test_utils.py /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2.1 2 | 3 | orbs: 4 | codecov: codecov/codecov@4.0.1 5 | 6 | executors: 7 | python-executor: 8 | working_directory: ~/memoria 9 | docker: 10 | - image: circleci/python:3.10 11 | 12 | commands: 13 | install-packages: 14 | steps: 15 | - checkout 16 | 17 | - restore_cache: 18 | key: deps-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }} 19 | 20 | - run: 21 | name: Create Virtual Environment and Install Dependencies 22 | command: | 23 | virtualenv env 24 | source env/bin/activate 25 | pip install -r requirements.txt -r requirements-dev.txt 26 | 27 | - save_cache: 28 | key: deps-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }} 29 | paths: 30 | - "env" 31 | 32 | jobs: 33 | run-test: 34 | executor: python-executor 35 | steps: 36 | - install-packages 37 | 38 | - run: 39 | name: Run Tests and Coverage 40 | command: | 41 | source 
env/bin/activate 42 | pytest --cov --cov-branch --cov-report=xml 43 | 44 | - codecov/upload 45 | 46 | check-linting: 47 | executor: python-executor 48 | steps: 49 | - install-packages 50 | 51 | - run: 52 | name: Run black, isort 53 | command: | 54 | source env/bin/activate 55 | black --check memoria tests 56 | isort memoria tests 57 | 58 | workflows: 59 | main: 60 | jobs: 61 | - run-test 62 | - check-linting 63 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | permissions: 16 | contents: read 17 | 18 | jobs: 19 | deploy: 20 | 21 | runs-on: ubuntu-latest 22 | 23 | steps: 24 | - uses: actions/checkout@v4 25 | - name: Set up Python 26 | uses: actions/setup-python@v3 27 | with: 28 | python-version: '3.x' 29 | - name: Install dependencies 30 | run: | 31 | python -m pip install --upgrade pip 32 | pip install build 33 | - name: Build package 34 | run: python -m build 35 | - name: Publish package 36 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 37 | with: 38 | user: __token__ 39 | password: ${{ secrets.PYPI_API_TOKEN }} 40 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # Environments 7 | .env 8 | .venv 9 | env/ 10 | venv/ 11 | ENV/ 12 | env.bak/ 13 | venv.bak/ 14 | 15 | # personal vscode settings 16 | .vscode/launch.json 17 | .vscode/tasks.json 18 | 19 | # log file 20 | *.log 21 | 22 | # build byproducts 23 | build/ 24 | dist/ 25 | memoria_pytorch.egg-info/ 26 | -------------------------------------------------------------------------------- /.vscode/extensions.json: -------------------------------------------------------------------------------- 1 | { 2 | "recommendations": [ 3 | "ms-python.python", 4 | "ms-pyright.pyright", 5 | ] 6 | } 7 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.pythonPath": "env/bin/python", 3 | "python.formatting.provider": "black", 4 | "[python]": { 5 | "editor.codeActionsOnSave": { 6 | "source.organizeImports": "explicit" 7 | } 8 | }, 9 | "editor.formatOnSave": true, 10 | "files.trimTrailingWhitespace": true, 11 | "files.insertFinalNewline": true, 12 | "python.testing.pytestArgs": [ 13 | "./tests" 14 | ], 15 | "python.testing.unittestEnabled": false, 16 | "python.testing.pytestEnabled": true 17 | } 18 | -------------------------------------------------------------------------------- /CITATION.bib: -------------------------------------------------------------------------------- 1 | @InProceedings{pmlr-v235-park24a, 2 | title = {Memoria: Resolving Fateful Forgetting 
Problem through Human-Inspired Memory Architecture}, 3 | author = {Park, Sangjun and Bak, Jinyeong}, 4 | booktitle = {Proceedings of the 41st International Conference on Machine Learning}, 5 | pages = {39587--39615}, 6 | year = {2024}, 7 | editor = {Salakhutdinov, Ruslan and Kolter, Zico and Heller, Katherine and Weller, Adrian and Oliver, Nuria and Scarlett, Jonathan and Berkenkamp, Felix}, 8 | volume = {235}, 9 | series = {Proceedings of Machine Learning Research}, 10 | month = {21--27 Jul}, 11 | publisher = {PMLR}, 12 | pdf = {https://raw.githubusercontent.com/mlresearch/v235/main/assets/park24a/park24a.pdf}, 13 | url = {https://proceedings.mlr.press/v235/park24a.html}, 14 | abstract = {Making neural networks remember over the long term has been a longstanding issue. Although several external memory techniques have been introduced, most focus on retaining recent information in the short term. Regardless of its importance, information tends to be fatefully forgotten over time. We present Memoria, a memory system for artificial neural networks, drawing inspiration from humans and applying various neuroscientific and psychological theories. The experimental results prove the effectiveness of Memoria in the diverse tasks of sorting, language modeling, and classification, surpassing conventional techniques. Engram analysis reveals that Memoria exhibits the primacy, recency, and temporal contiguity effects which are characteristics of human memory.} 15 | } 16 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 ParkSangJun 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Memoria 2 | 3 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 4 | [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) 5 | [![Imports: isort](https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336)](https://pycqa.github.io/isort/) 6 | [![CircleCI](https://dl.circleci.com/status-badge/img/gh/cosmoquester/memoria/tree/master.svg?style=svg&circle-token=513f0f5e9a706a51509d198359fe0e016a227ce9)](https://dl.circleci.com/status-badge/redirect/gh/cosmoquester/memoria/tree/master) 7 | [![codecov](https://codecov.io/gh/cosmoquester/memoria/branch/master/graph/badge.svg?token=KZdkgkBzZG)](https://codecov.io/gh/cosmoquester/memoria) 8 | 9 | 10 | 11 | Making neural networks remember over the long term has been a longstanding issue. Although several external memory techniques have been introduced, most focus on retaining recent information in the short term. Regardless of its importance, information tends to be fatefully forgotten over time. We present Memoria, a memory system for artificial neural networks, drawing inspiration from humans and applying various neuroscientific and psychological theories. The experimental results prove the effectiveness of Memoria in the diverse tasks of sorting, language modeling, and classification, surpassing conventional techniques. Engram analysis reveals that Memoria exhibits the primacy, recency, and temporal contiguity effects which are characteristics of human memory. 12 | 13 | Memoria is an independent module that can be applied to neural network models in various ways; the experiment code of the paper is in the `experiment` directory. 14 | 15 | My paper [Memoria: Resolving Fateful Forgetting Problem through Human-Inspired Memory Architecture](https://icml.cc/virtual/2024/poster/32668) was accepted to **International Conference on Machine Learning (ICML) 2024 as a Spotlight paper**. 16 | The full text of the paper can be accessed from [OpenReview](https://openreview.net/forum?id=yTz0u4B8ug) or [ArXiv](https://arxiv.org/abs/2310.03052). 17 | 18 | ## Installation 19 | 20 | ```sh 21 | $ pip install memoria-pytorch 22 | ``` 23 | 24 | You can install Memoria with the pip command above. 25 | 26 | ## Tutorial 27 | 28 | This is a tutorial to help you understand the concept and mechanism of Memoria. 29 | 30 | #### 1. Import Memoria and Set Parameters 31 | 32 | ```python 33 | import torch 34 | from memoria import Memoria, EngramType 35 | 36 | torch.manual_seed(42) 37 | 38 | # Memoria Parameters 39 | num_reminded_stm = 4 40 | stm_capacity = 16 41 | ltm_search_depth = 5 42 | initial_lifespan = 3 43 | num_final_ltms = 4 44 | 45 | # Data Parameters 46 | batch_size = 2 47 | sequence_length = 8 48 | hidden_dim = 64 49 | ``` 50 | 51 | #### 2. Initialize Memoria and Dummy Data 52 | 53 | - Fake random data and lifespan deltas are used for simplicity. 54 | 55 | ```python 56 | memoria = Memoria( 57 | num_reminded_stm=num_reminded_stm, 58 | stm_capacity=stm_capacity, 59 | ltm_search_depth=ltm_search_depth, 60 | initial_lifespan=initial_lifespan, 61 | num_final_ltms=num_final_ltms, 62 | ) 63 | data = torch.rand(batch_size, sequence_length, hidden_dim) 64 | ``` 65 | 66 | #### 3. 
Add Data as Working Memory 67 | 68 | ```python 69 | # Add data as working memory 70 | memoria.add_working_memory(data) 71 | ``` 72 | 73 | ```python 74 | # Expected values 75 | >>> len(memoria.engrams) 76 | 16 77 | >>> memoria.engrams.data.shape 78 | torch.Size([2, 8, 64]) 79 | >>> memoria.engrams.lifespan 80 | tensor([[3., 3., 3., 3., 3., 3., 3., 3.], 81 | [3., 3., 3., 3., 3., 3., 3., 3.]]) 82 | ``` 83 | 84 | #### 4. Remind Memories 85 | 86 | - Empty memories are reminded because there are no engrams in STM/LTM yet 87 | 88 | ```python 89 | reminded_memories, reminded_indices = memoria.remind() 90 | ``` 91 | 92 | ```python 93 | # No reminded memories because there are no STM/LTM engrams yet 94 | >>> reminded_memories 95 | tensor([], size=(2, 0, 64)) 96 | >>> reminded_indices 97 | tensor([], size=(2, 0), dtype=torch.int64) 98 | ``` 99 | 100 | #### 5. Adjust Lifespan and Memories 101 | 102 | - In this step, no engrams earn lifespan because there are no reminded memories 103 | 104 | ```python 105 | memoria.adjust_lifespan_and_memories(reminded_indices, torch.zeros_like(reminded_indices)) 106 | ``` 107 | 108 | ```python 109 | # Lifespan decreases for all engrams & working memories have changed into short-term memory 110 | >>> memoria.engrams.lifespan 111 | tensor([[2., 2., 2., 2., 2., 2., 2., 2.], 112 | [2., 2., 2., 2., 2., 2., 2., 2.]]) 113 | >>> memoria.engrams.engrams_types 114 | tensor([[2, 2, 2, 2, 2, 2, 2, 2], 115 | [2, 2, 2, 2, 2, 2, 2, 2]], dtype=torch.uint8) 116 | >>> EngramType.SHORTTERM 117 | <EngramType.SHORTTERM: 2> 118 | ``` 119 | 120 | #### 6. Repeat one more time 121 | 122 | - Now that there are some engrams in STM, remind and adjustment from STM will work 123 | 124 | ```python 125 | data2 = torch.rand(batch_size, sequence_length, hidden_dim) 126 | memoria.add_working_memory(data2) 127 | ``` 128 | 129 | ```python 130 | >>> len(memoria.engrams) 131 | 32 132 | >>> memoria.engrams.lifespan 133 | tensor([[2., 2., 2., 2., 2., 2., 2., 2., 3., 3., 3., 3., 3., 3., 3., 3.], 134 | [2., 2., 2., 2., 2., 2., 2., 2., 3., 3., 3., 3., 3., 3., 3., 3.]]) 135 | ``` 136 | 137 | ```python 138 | reminded_memories, reminded_indices = memoria.remind() 139 | ``` 140 | 141 | ```python 142 | # Remind memories from STM 143 | >>> reminded_memories.shape 144 | torch.Size([2, 6, 64]) 145 | >>> reminded_indices.shape 146 | torch.Size([2, 6]) 147 | >>> reminded_indices 148 | tensor([[ 0, 6, 4, 3, 2, -1], 149 | [ 0, 7, 6, 5, 4, -1]]) 150 | ``` 151 | 152 | ```python 153 | # Increase lifespan of all the reminded engrams by 5 154 | memoria.adjust_lifespan_and_memories(reminded_indices, torch.full_like(reminded_indices, 5)) 155 | ``` 156 | 157 | ```python 158 | # Reminded engrams gained 5 lifespan, while the other engrams have gotten older 159 | >>> memoria.engrams.lifespan 161 | tensor([[6., 1., 6., 6., 6., 1., 6., 1., 2., 2., 2., 2., 2., 2., 2., 2.], 162 | [6., 1., 1., 1., 6., 6., 6., 6., 2., 2., 2., 2., 2., 2., 2., 2.]]) 163 | ``` 164 | 165 | #### 7. 
Repeat 166 | 167 | - Repeat 10 times to see the dynamics of LTM 168 | 169 | ```python 170 | # This is the default process to utilize Memoria 171 | for _ in range(10): 172 |     data = torch.rand(batch_size, sequence_length, hidden_dim) 173 |     memoria.add_working_memory(data) 174 | 175 |     reminded_memories, reminded_indices = memoria.remind() 176 | 177 |     lifespan_delta = torch.randint_like(reminded_indices, 0, 6).float() 178 | 179 |     memoria.adjust_lifespan_and_memories(reminded_indices, lifespan_delta) 180 | ``` 181 | 182 | ```python 183 | # After 10 iterations, some engrams have changed into long-term memory and gained large lifespans 184 | # Engram type zero means those engrams are deleted 185 | >>> len(memoria.engrams) 186 | 72 187 | >>> memoria.engrams.engrams_types 188 | tensor([[3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 2, 2, 2, 189 | 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], 190 | [0, 0, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 2, 2, 2, 191 | 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]], dtype=torch.uint8) 192 | >>> EngramType.LONGTERM 193 | <EngramType.LONGTERM: 3> 194 | >>> EngramType.NULL 195 | <EngramType.NULL: 0> 196 | >>> memoria.engrams.lifespan 197 | tensor([[ 9., 1., 8., 2., 16., 5., 13., 7., 7., 3., 3., 4., 3., 3., 198 | 4., 2., 2., 1., 1., 1., 1., 1., 1., 1., 2., 6., 1., 1., 199 | 2., 2., 2., 2., 2., 2., 2., 2.], 200 | [-1., -1., 3., 2., 19., 21., 11., 6., 14., 1., 5., 1., 5., 1., 201 | 5., 1., 1., 8., 2., 1., 1., 1., 2., 1., 1., 1., 1., 1., 202 | 2., 2., 2., 2., 2., 2., 2., 2.]]) 203 | ``` 204 | 205 | ## Citation 206 | 207 | ```bibtex 208 | @InProceedings{pmlr-v235-park24a, 209 | title = {Memoria: Resolving Fateful Forgetting Problem through Human-Inspired Memory Architecture}, 210 | author = {Park, Sangjun and Bak, Jinyeong}, 211 | booktitle = {Proceedings of the 41st International Conference on Machine Learning}, 212 | pages = {39587--39615}, 213 | year = {2024}, 214 | editor = {Salakhutdinov, Ruslan and Kolter, Zico and Heller, Katherine and Weller, Adrian and Oliver, Nuria and Scarlett, Jonathan and Berkenkamp, Felix}, 215 | volume = {235}, 216 | series = {Proceedings of Machine Learning Research}, 217 | month = {21--27 Jul}, 218 | publisher = {PMLR}, 219 | pdf = {https://raw.githubusercontent.com/mlresearch/v235/main/assets/park24a/park24a.pdf}, 220 | url = {https://proceedings.mlr.press/v235/park24a.html}, 221 | abstract = {Making neural networks remember over the long term has been a longstanding issue. Although several external memory techniques have been introduced, most focus on retaining recent information in the short term. Regardless of its importance, information tends to be fatefully forgotten over time. We present Memoria, a memory system for artificial neural networks, drawing inspiration from humans and applying various neuroscientific and psychological theories. The experimental results prove the effectiveness of Memoria in the diverse tasks of sorting, language modeling, and classification, surpassing conventional techniques. 
Engram analysis reveals that Memoria exhibits the primacy, recency, and temporal contiguity effects which are characteristics of human memory.} 222 | } 223 | ``` 224 | -------------------------------------------------------------------------------- /experiment/.gitignore: -------------------------------------------------------------------------------- 1 | lightning_logs/ 2 | -------------------------------------------------------------------------------- /experiment/README.md: -------------------------------------------------------------------------------- 1 | # Memoria Experiment 2 | 3 | This directory contains the model architectures, data loaders, config files, and training and evaluation scripts used to conduct the experiments in my paper. You can reproduce my research by referring to the Memoria paper, or develop your own ideas on top of it. 4 | 5 | ## Package Install 6 | 7 | You should install the required packages before running the code. 8 | 9 | ```sh 10 | $ pip install -r requirements.txt 11 | ``` 12 | 13 | ## Structure 14 | 15 | ``` 16 | experiment 17 | ├── configs 18 | └── longseq_formers 19 | ├── data 20 | ├── dataset 21 | ├── model 22 | │ ├── compressive_former 23 | │ ├── gpt2_with_memoria 24 | │ ├── infinity_gpt2 25 | │ ├── memoria_bert 26 | │ └── memoria_roberta 27 | └── task 28 | ``` 29 | - The `longseq_formers` directory is the main directory for the experiments. It contains the data loaders, task training code, and model architectures. 30 | - The `configs` directory includes config files for language modeling and the synthetic (sorting) task for multiple models. 31 | 32 | ## Models 33 | 34 | You can load models from the `longseq_formers.model` module independently of the training or evaluation scripts. 35 | 36 | ```python 37 | import torch 38 | from longseq_formers.model import MemoriaBertModel 39 | 40 | memoria_bert = MemoriaBertModel.from_pretrained("bert-base-uncased") 41 | input_ids = torch.randint(0, 10, [1,10]) 42 | outputs = memoria_bert(input_ids) 43 | ``` 44 | 45 | ```python 46 | import torch 47 | from longseq_formers.model import GPT2WithMemoriaLMHeadModel 48 | 49 | memoria_gpt2 = GPT2WithMemoriaLMHeadModel.from_pretrained("gpt2") 50 | input_ids = torch.randint(0, 10, [1,10]) 51 | outputs = memoria_gpt2(input_ids) 52 | ``` 53 | 54 | ## Train 55 | 56 | You can train a model with the training script for each task. With the `--help` option, you can see the options for training or evaluation. Because all the datasets except for the sorting task are loaded from the web, you don't have to download them separately. 
57 | 58 | ```sh 59 | $ python train_language_modeling.py --help 60 | usage: train [-h] [--model-config MODEL_CONFIG] [--model MODEL] [--model-type MODEL_TYPE] [--tokenizer TOKENIZER] [--dataset {wikitext103,pg19,enwik8}] 61 | [--batch-size BATCH_SIZE] [--valid-batch-size VALID_BATCH_SIZE] [--accumulate-grad-batches ACCUMULATE_GRAD_BATCHES] [--max-length MAX_LENGTH] [--epochs EPOCHS] 62 | [--learning-rate LEARNING_RATE] [--warmup-rate WARMUP_RATE] [--max-grad-norm MAX_GRAD_NORM] [--seed SEED] [--shuffle] [--test-ckpt {best,last}] 63 | [--output-dir OUTPUT_DIR] [--gpus GPUS] [--logging-interval LOGGING_INTERVAL] [--valid-interval VALID_INTERVAL] [--wandb-run-name WANDB_RUN_NAME] 64 | [--wandb-entity WANDB_ENTITY] [--wandb-project WANDB_PROJECT] 65 | 66 | Train & Test Language Modeling 67 | 68 | optional arguments: 69 | -h, --help show this help message and exit 70 | 71 | Train Parameter: 72 | --model-config MODEL_CONFIG 73 | huggingface model config 74 | --model MODEL huggingface model 75 | --model-type MODEL_TYPE 76 | specific model type 77 | --tokenizer TOKENIZER 78 | huggingface tokenizer 79 | --dataset {wikitext103,pg19,enwik8} 80 | dataset name 81 | --batch-size BATCH_SIZE 82 | global training batch size 83 | --valid-batch-size VALID_BATCH_SIZE 84 | validation batch size 85 | --accumulate-grad-batches ACCUMULATE_GRAD_BATCHES 86 | the number of gradident accumulation steps 87 | --max-length MAX_LENGTH 88 | max sequence length 89 | --epochs EPOCHS the number of training epochs 90 | --learning-rate LEARNING_RATE 91 | learning rate 92 | --warmup-rate WARMUP_RATE 93 | warmup step rate 94 | --max-grad-norm MAX_GRAD_NORM 95 | maximum gradient norm 96 | --seed SEED random seed 97 | --shuffle shuffle data order 98 | --test-ckpt {best,last} 99 | checkpoint type for testing 100 | 101 | Personal Options: 102 | --output-dir OUTPUT_DIR 103 | output directory path to save artifacts 104 | --gpus GPUS the number of gpus, use all devices by default 105 | --logging-interval LOGGING_INTERVAL 106 | logging interval 107 | --valid-interval VALID_INTERVAL 108 | validation interval rate 109 | 110 | Wandb Options: 111 | --wandb-run-name WANDB_RUN_NAME 112 | wanDB run name 113 | --wandb-entity WANDB_ENTITY 114 | wanDB entity name 115 | --wandb-project WANDB_PROJECT 116 | wanDB project name 117 | ``` 118 | 119 | ```sh 120 | $ python train_language_modeling.py --model gpt2 121 | [2023-09-29 21:31:29,995] ====== Arguements ====== 122 | [2023-09-29 21:31:29,995] model_config : None 123 | [2023-09-29 21:31:29,995] model : gpt2 124 | [2023-09-29 21:31:29,995] model_type : None 125 | [2023-09-29 21:31:29,995] tokenizer : None 126 | [2023-09-29 21:31:29,995] dataset : wikitext103 127 | [2023-09-29 21:31:29,995] batch_size : 8 128 | [2023-09-29 21:31:29,995] valid_batch_size : 1 129 | [2023-09-29 21:31:29,995] accumulate_grad_batches : 1 130 | [2023-09-29 21:31:29,995] max_length : 150 131 | [2023-09-29 21:31:29,995] epochs : 6 132 | [2023-09-29 21:31:29,995] learning_rate : 0.0002 133 | ... 134 | ``` 135 | - You can start training simply this command without any download. 
136 | 137 | ```sh 138 | $ python train_language_modeling.py --model-config configs/memoria-gpt2.json --tokenizer gpt2 --output-dir trained-model 139 | [2023-09-29 21:43:27,347] [+] Save output to "trained-model" 140 | [2023-09-29 21:43:27,347] ====== Arguements ====== 141 | [2023-09-29 21:43:27,347] model_config : configs/memoria-gpt2.json 142 | [2023-09-29 21:43:27,347] model : None 143 | [2023-09-29 21:43:27,347] model_type : None 144 | [2023-09-29 21:43:27,347] tokenizer : gpt2 145 | [2023-09-29 21:43:27,347] dataset : wikitext103 146 | [2023-09-29 21:43:27,347] batch_size : 8 147 | [2023-09-29 21:43:27,347] valid_batch_size : 1 148 | [2023-09-29 21:43:27,347] accumulate_grad_batches : 1 149 | [2023-09-29 21:43:27,347] max_length : 150 150 | [2023-09-29 21:43:27,347] epochs : 6 151 | [2023-09-29 21:43:27,347] learning_rate : 0.0002 152 | ... 153 | ``` 154 | - You can train the Memoria GPT-2 model simply by adding the `--model-type gpt2_with_memoria` option or the `--model-config configs/memoria-gpt2.json` option. 155 | - To save a model checkpoint, add the `--output-dir [OUTPUT-DIR]` option. In the example above, the checkpoint will be saved in the `trained-model` directory. 156 | - Refer to the help description and the Memoria paper for detailed hyperparameters. 157 | 158 | ## Evaluation 159 | 160 | ```sh 161 | $ python eval_language_modeling.py --model trained-model/checkpoint/last.ckpt 162 | [2023-09-29 21:45:03,214] ====== Arguements ====== 163 | [2023-09-29 21:45:03,214] model : trained-model/checkpoint/last.ckpt 164 | [2023-09-29 21:45:03,214] tokenizer : gpt2 165 | [2023-09-29 21:45:03,214] dataset : wikitext103 166 | [2023-09-29 21:45:03,214] valid_batch_size : 1 167 | [2023-09-29 21:45:03,214] max_length : 512 168 | [2023-09-29 21:45:03,214] seed : 42 169 | [2023-09-29 21:45:03,214] [+] Set Random Seed to 42 170 | Global seed set to 42 171 | [2023-09-29 21:45:03,237] [+] GPU: 1 172 | [2023-09-29 21:45:03,237] [+] Load Tokenizer: "gpt2" 173 | ... 174 | ``` 175 | - You should pass the saved model checkpoint with the `--model [MODEL-CHECKPOINT]` option. 
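- If you prefer to run evaluation from Python instead of the CLI, the steps below condense what `eval_language_modeling.py` (included in this directory) does internally. This is an illustrative sketch only: the checkpoint path is a placeholder taken from the example above, and `text_to_tokens` is called positionally with the batch size and maximum length, mirroring the script.

```python
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from longseq_formers.data import load_wikitext103_data
from longseq_formers.dataset import LanguageModelingDataset, text_to_tokens
from longseq_formers.task import LanguageModeling

# Tokenize the raw text into pre-batched chunks: (texts, tokenizer, batch size, max length)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
data = load_wikitext103_data()
test_tokens = text_to_tokens(data["test"], tokenizer, 1, 512)
test_dataset = LanguageModelingDataset(test_tokens)

# DataLoader batch size is 1 because the dataset is already batched by text_to_tokens
test_dataloader = DataLoader(test_dataset, batch_size=1, collate_fn=LanguageModelingDataset.collate_fn)

# Load the trained Lightning checkpoint (placeholder path) and evaluate
language_modeling = LanguageModeling.load_from_checkpoint("trained-model/checkpoint/last.ckpt")
tester = pl.Trainer(accelerator="gpu" if torch.cuda.device_count() else None, devices=1)
print(tester.test(language_modeling, test_dataloader)[0])
```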
176 | -------------------------------------------------------------------------------- /experiment/configs/compressive-former-enwik8.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "compressive_transformer", 3 | "attn_dropout": 0.1, 4 | "attn_layer_dropout": 0.1, 5 | "cmem_len": 256, 6 | "cmem_ratio": 4, 7 | "depth": 12, 8 | "dim": 512, 9 | "emb_dim": null, 10 | "enhanced_recurrence": true, 11 | "ff_dropout": 0.1, 12 | "ff_glu": false, 13 | "gru_gated_residual": false, 14 | "heads": 8, 15 | "mem_len": 256, 16 | "memory_layers": null, 17 | "mogrify_gru": false, 18 | "vocab_size": 204, 19 | "reconstruction_attn_dropout": 0.0, 20 | "reconstruction_loss_weight": 1.0, 21 | "seq_len": 512, 22 | "transformers_version": "4.25.1" 23 | } 24 | -------------------------------------------------------------------------------- /experiment/configs/compressive-former-sort.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "compressive_transformer", 3 | "attn_dropout": 0.1, 4 | "attn_layer_dropout": 0.1, 5 | "cmem_len": 512, 6 | "cmem_ratio": 4, 7 | "depth": 4, 8 | "dim": 512, 9 | "emb_dim": null, 10 | "enhanced_recurrence": true, 11 | "ff_dropout": 0.1, 12 | "ff_glu": false, 13 | "gru_gated_residual": false, 14 | "heads": 4, 15 | "mem_len": 512, 16 | "memory_layers": null, 17 | "mogrify_gru": false, 18 | "num_tokens": 21, 19 | "reconstruction_attn_dropout": 0.0, 20 | "reconstruction_loss_weight": 1.0, 21 | "seq_len": 1024, 22 | "transformers_version": "4.25.1" 23 | } 24 | -------------------------------------------------------------------------------- /experiment/configs/compressive-former.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "compressive_transformer", 3 | "attn_dropout": 0.1, 4 | "attn_layer_dropout": 0.1, 5 | "cmem_len": 75, 6 | "cmem_ratio": 4, 7 | "depth": 12, 8 | "dim": 768, 9 | "emb_dim": null, 10 | "enhanced_recurrence": true, 11 | "ff_dropout": 0.1, 12 | "ff_glu": false, 13 | "gru_gated_residual": false, 14 | "heads": 12, 15 | "mem_len": 75, 16 | "memory_layers": null, 17 | "mogrify_gru": false, 18 | "vocab_size": 50257, 19 | "reconstruction_attn_dropout": 0.0, 20 | "reconstruction_loss_weight": 1.0, 21 | "seq_len": 512, 22 | "transformers_version": "4.25.1" 23 | } 24 | -------------------------------------------------------------------------------- /experiment/configs/gpt2-enwik8.json: -------------------------------------------------------------------------------- 1 | { 2 | "activation_function": "gelu_new", 3 | "attn_pdrop": 0.1, 4 | "bos_token_id": 50256, 5 | "embd_pdrop": 0.1, 6 | "eos_token_id": 50256, 7 | "initializer_range": 0.02, 8 | "layer_norm_epsilon": 1e-05, 9 | "model_type": "gpt2", 10 | "n_embd": 512, 11 | "n_head": 8, 12 | "n_inner": null, 13 | "n_layer": 12, 14 | "n_positions": 1024, 15 | "reorder_and_upcast_attn": false, 16 | "resid_pdrop": 0.1, 17 | "scale_attn_by_inverse_layer_idx": false, 18 | "scale_attn_weights": true, 19 | "summary_activation": null, 20 | "summary_first_dropout": 0.1, 21 | "summary_proj_to_labels": true, 22 | "summary_type": "cls_index", 23 | "summary_use_proj": true, 24 | "transformers_version": "4.25.1", 25 | "use_cache": true, 26 | "vocab_size": 204 27 | } 28 | -------------------------------------------------------------------------------- /experiment/configs/gpt2-synthetic.json: -------------------------------------------------------------------------------- 1 
| { 2 | "model_type": "gpt2_with_memoria", 3 | "activation_function": "gelu_new", 4 | "attn_pdrop": 0.1, 5 | "bos_token_id": 50256, 6 | "embd_pdrop": 0.1, 7 | "eos_token_id": 50256, 8 | "initializer_range": 0.02, 9 | "layer_norm_epsilon": 1e-05, 10 | "n_ctx": 1024, 11 | "n_embd": 512, 12 | "n_head": 4, 13 | "n_inner": null, 14 | "n_layer": 4, 15 | "n_positions": 1024, 16 | "reorder_and_upcast_attn": false, 17 | "resid_pdrop": 0.1, 18 | "scale_attn_by_inverse_layer_idx": false, 19 | "scale_attn_weights": true, 20 | "summary_activation": null, 21 | "summary_first_dropout": 0.1, 22 | "summary_proj_to_labels": true, 23 | "summary_type": "cls_index", 24 | "summary_use_proj": true, 25 | "task_specific_params": { 26 | "text-generation": { 27 | "do_sample": true, 28 | "max_length": 50 29 | } 30 | }, 31 | "transformers_version": "4.25.1", 32 | "use_cache": true, 33 | "vocab_size": 50257 34 | } 35 | -------------------------------------------------------------------------------- /experiment/configs/gpt2-xl.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "gpt2-xl", 3 | "activation_function": "gelu_new", 4 | "architectures": [ 5 | "GPT2LMHeadModel" 6 | ], 7 | "attn_pdrop": 0.1, 8 | "bos_token_id": 50256, 9 | "embd_pdrop": 0.1, 10 | "eos_token_id": 50256, 11 | "initializer_range": 0.02, 12 | "layer_norm_epsilon": 1e-05, 13 | "model_type": "gpt2", 14 | "n_ctx": 1024, 15 | "n_embd": 1600, 16 | "n_head": 25, 17 | "n_inner": null, 18 | "n_layer": 48, 19 | "n_positions": 1024, 20 | "output_past": true, 21 | "reorder_and_upcast_attn": false, 22 | "resid_pdrop": 0.1, 23 | "scale_attn_by_inverse_layer_idx": false, 24 | "scale_attn_weights": true, 25 | "summary_activation": null, 26 | "summary_first_dropout": 0.1, 27 | "summary_proj_to_labels": true, 28 | "summary_type": "cls_index", 29 | "summary_use_proj": true, 30 | "task_specific_params": { 31 | "text-generation": { 32 | "do_sample": true, 33 | "max_length": 50 34 | } 35 | }, 36 | "transformers_version": "4.25.1", 37 | "use_cache": true, 38 | "vocab_size": 50257 39 | } 40 | -------------------------------------------------------------------------------- /experiment/configs/gpt2.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "gpt2_with_memoria", 3 | "activation_function": "gelu_new", 4 | "attn_pdrop": 0.1, 5 | "bos_token_id": 50256, 6 | "embd_pdrop": 0.1, 7 | "eos_token_id": 50256, 8 | "initializer_range": 0.02, 9 | "layer_norm_epsilon": 1e-05, 10 | "n_ctx": 1024, 11 | "n_embd": 512, 12 | "n_head": 4, 13 | "n_inner": null, 14 | "n_layer": 4, 15 | "n_positions": 1024, 16 | "reorder_and_upcast_attn": false, 17 | "resid_pdrop": 0.1, 18 | "scale_attn_by_inverse_layer_idx": false, 19 | "scale_attn_weights": true, 20 | "summary_activation": null, 21 | "summary_first_dropout": 0.1, 22 | "summary_proj_to_labels": true, 23 | "summary_type": "cls_index", 24 | "summary_use_proj": true, 25 | "task_specific_params": { 26 | "text-generation": { 27 | "do_sample": true, 28 | "max_length": 50 29 | } 30 | }, 31 | "transformers_version": "4.25.1", 32 | "use_cache": true, 33 | "vocab_size": 50257 34 | } 35 | -------------------------------------------------------------------------------- /experiment/configs/infinity-gpt2-enwik8.json: -------------------------------------------------------------------------------- 1 | { 2 | "activation_function": "gelu_new", 3 | "attn_drop": 0.1, 4 | "attn_pdrop": 0.1, 5 | "bos_token_id": 50256, 6 | 
"detach_recursive_outputs": true, 7 | "embd_pdrop": 0.1, 8 | "eos_token_id": 50256, 9 | "gradient_checkpointing": false, 10 | "initializer_range": 0.02, 11 | "kl_lambda": 1e-06, 12 | "layer_norm_epsilon": 1e-05, 13 | "longterm_attention_dropout": 0.1, 14 | "mask_dropout": 0.1, 15 | "mask_type": "cnn", 16 | "memory_length": 512, 17 | "model_type": "infinity_gpt2", 18 | "mu_0": -1.0, 19 | "n_ctx": 1024, 20 | "n_embd": 512, 21 | "n_head": 8, 22 | "n_inner": null, 23 | "n_layer": 12, 24 | "n_positions": 1024, 25 | "normalize_function": "softmax", 26 | "num_basis": 512, 27 | "num_samples": 512, 28 | "resid_pdrop": 0.1, 29 | "sigma_0": 0.05, 30 | "summary_activation": null, 31 | "summary_first_dropout": 0.1, 32 | "summary_proj_to_labels": true, 33 | "summary_type": "cls_index", 34 | "summary_use_proj": true, 35 | "task_specific_params": { 36 | "text-generation": { 37 | "do_sample": true, 38 | "max_length": 50 39 | } 40 | }, 41 | "tau": 0.5, 42 | "transformers_version": "4.25.1", 43 | "use_affines": true, 44 | "use_cache": true, 45 | "use_kl_regularizer": true, 46 | "use_sticky_memories": true, 47 | "vocab_size": 204 48 | } 49 | -------------------------------------------------------------------------------- /experiment/configs/infinity-gpt2-sort.json: -------------------------------------------------------------------------------- 1 | { 2 | "activation_function": "gelu_new", 3 | "attn_drop": 0.1, 4 | "attn_pdrop": 0.1, 5 | "bos_token_id": 50256, 6 | "detach_recursive_outputs": true, 7 | "embd_pdrop": 0.1, 8 | "eos_token_id": 50256, 9 | "gradient_checkpointing": false, 10 | "initializer_range": 0.02, 11 | "kl_lambda": 1e-06, 12 | "layer_norm_epsilon": 1e-05, 13 | "longterm_attention_dropout": 0.1, 14 | "mask_dropout": 0.1, 15 | "mask_type": "cnn", 16 | "memory_length": 1024, 17 | "model_type": "infinity_gpt2", 18 | "mu_0": -1.0, 19 | "n_ctx": 1024, 20 | "n_embd": 512, 21 | "n_head": 4, 22 | "n_inner": 2048, 23 | "n_layer": 4, 24 | "n_positions": 1024, 25 | "normalize_function": "softmax", 26 | "num_basis": 1024, 27 | "num_samples": 1024, 28 | "resid_pdrop": 0.1, 29 | "sigma_0": 0.05, 30 | "summary_activation": null, 31 | "summary_first_dropout": 0.1, 32 | "summary_proj_to_labels": true, 33 | "summary_type": "cls_index", 34 | "summary_use_proj": true, 35 | "task_specific_params": { 36 | "text-generation": { 37 | "do_sample": true, 38 | "max_length": 50 39 | } 40 | }, 41 | "tau": 0.5, 42 | "transformers_version": "4.25.1", 43 | "use_affines": true, 44 | "use_cache": true, 45 | "use_kl_regularizer": true, 46 | "use_sticky_memories": true, 47 | "vocab_size": 50257 48 | } 49 | -------------------------------------------------------------------------------- /experiment/configs/infinity-gpt2-synthetic.json: -------------------------------------------------------------------------------- 1 | { 2 | "activation_function": "gelu_new", 3 | "attn_drop": 0.1, 4 | "attn_pdrop": 0.1, 5 | "bos_token_id": 50256, 6 | "detach_recursive_outputs": true, 7 | "embd_pdrop": 0.1, 8 | "eos_token_id": 50256, 9 | "gradient_checkpointing": false, 10 | "initializer_range": 0.02, 11 | "kl_lambda": 1e-06, 12 | "layer_norm_epsilon": 1e-05, 13 | "longterm_attention_dropout": 0.1, 14 | "mask_dropout": 0.1, 15 | "mask_type": "cnn", 16 | "memory_length": 18, 17 | "model_type": "infinity_gpt2", 18 | "mu_0": -1.0, 19 | "n_ctx": 1024, 20 | "n_embd": 512, 21 | "n_head": 4, 22 | "n_inner": null, 23 | "n_layer": 4, 24 | "n_positions": 1024, 25 | "normalize_function": "softmax", 26 | "num_basis": 18, 27 | "num_samples": 18, 28 | 
"resid_pdrop": 0.1, 29 | "sigma_0": 0.05, 30 | "summary_activation": null, 31 | "summary_first_dropout": 0.1, 32 | "summary_proj_to_labels": true, 33 | "summary_type": "cls_index", 34 | "summary_use_proj": true, 35 | "task_specific_params": { 36 | "text-generation": { 37 | "do_sample": true, 38 | "max_length": 50 39 | } 40 | }, 41 | "tau": 0.5, 42 | "transformers_version": "4.25.1", 43 | "use_affines": true, 44 | "use_cache": true, 45 | "use_kl_regularizer": true, 46 | "use_sticky_memories": true, 47 | "vocab_size": 50257 48 | } 49 | -------------------------------------------------------------------------------- /experiment/configs/infinity-gpt2.json: -------------------------------------------------------------------------------- 1 | { 2 | "activation_function": "gelu_new", 3 | "attn_drop": 0.1, 4 | "attn_pdrop": 0.1, 5 | "bos_token_id": 50256, 6 | "detach_recursive_outputs": true, 7 | "embd_pdrop": 0.1, 8 | "eos_token_id": 50256, 9 | "gradient_checkpointing": false, 10 | "initializer_range": 0.02, 11 | "kl_lambda": 1e-06, 12 | "layer_norm_epsilon": 1e-05, 13 | "longterm_attention_dropout": 0.1, 14 | "mask_dropout": 0.1, 15 | "mask_type": "cnn", 16 | "memory_length": 150, 17 | "model_type": "infinity_gpt2", 18 | "mu_0": -1.0, 19 | "n_ctx": 1024, 20 | "n_embd": 768, 21 | "n_head": 12, 22 | "n_inner": null, 23 | "n_layer": 12, 24 | "n_positions": 1024, 25 | "normalize_function": "softmax", 26 | "num_basis": 150, 27 | "num_samples": 150, 28 | "resid_pdrop": 0.1, 29 | "sigma_0": 0.05, 30 | "summary_activation": null, 31 | "summary_first_dropout": 0.1, 32 | "summary_proj_to_labels": true, 33 | "summary_type": "cls_index", 34 | "summary_use_proj": true, 35 | "task_specific_params": { 36 | "text-generation": { 37 | "do_sample": true, 38 | "max_length": 50 39 | } 40 | }, 41 | "tau": 0.5, 42 | "transformers_version": "4.25.1", 43 | "use_affines": true, 44 | "use_cache": true, 45 | "use_kl_regularizer": true, 46 | "use_sticky_memories": true, 47 | "vocab_size": 50257 48 | } 49 | -------------------------------------------------------------------------------- /experiment/configs/memoria-enwik8.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "gpt2_with_memoria", 3 | "activation_function": "gelu_new", 4 | "attn_pdrop": 0.1, 5 | "bos_token_id": 50256, 6 | "embd_pdrop": 0.1, 7 | "eos_token_id": 50256, 8 | "initializer_range": 0.02, 9 | "layer_norm_epsilon": 1e-05, 10 | "memoria_stm_capacity": 1536, 11 | "memoria_num_memories": 128, 12 | "memoria_initial_lifespan": 9, 13 | "memoria_lifespan_extend_scale": 8.0, 14 | "memoria_ltm_search_depth": 10, 15 | "memoria_reset_period": 500, 16 | "memoria_num_reminded_stm": 192, 17 | "memoria_num_reminded_ltm": 192, 18 | "memoria_device": null, 19 | "n_ctx": 1024, 20 | "n_embd": 512, 21 | "n_head": 8, 22 | "n_inner": null, 23 | "n_layer": 12, 24 | "n_positions": 1024, 25 | "reorder_and_upcast_attn": false, 26 | "resid_pdrop": 0.1, 27 | "scale_attn_by_inverse_layer_idx": false, 28 | "scale_attn_weights": true, 29 | "summary_activation": null, 30 | "summary_first_dropout": 0.1, 31 | "summary_proj_to_labels": true, 32 | "summary_type": "cls_index", 33 | "summary_use_proj": true, 34 | "task_specific_params": { 35 | "text-generation": { 36 | "do_sample": true, 37 | "max_length": 50 38 | } 39 | }, 40 | "transformers_version": "4.25.1", 41 | "use_cache": true, 42 | "vocab_size": 204 43 | } 44 | -------------------------------------------------------------------------------- 
/experiment/configs/memoria-gpt2-large.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "gpt2_with_memoria", 3 | "activation_function": "gelu_new", 4 | "attn_pdrop": 0.1, 5 | "bos_token_id": 50256, 6 | "embd_pdrop": 0.1, 7 | "eos_token_id": 50256, 8 | "initializer_range": 0.02, 9 | "layer_norm_epsilon": 1e-05, 10 | "memoria_stm_capacity": 400, 11 | "memoria_num_memories": 50, 12 | "memoria_initial_lifespan": 9, 13 | "memoria_lifespan_extend_scale": 8.0, 14 | "memoria_ltm_search_depth": 10, 15 | "memoria_reset_period": 500, 16 | "memoria_num_reminded_stm": 50, 17 | "memoria_num_reminded_ltm": 50, 18 | "memoria_device": null, 19 | "n_ctx": 1024, 20 | "n_embd": 1280, 21 | "n_head": 20, 22 | "n_inner": null, 23 | "n_layer": 36, 24 | "n_positions": 1024, 25 | "reorder_and_upcast_attn": false, 26 | "resid_pdrop": 0.1, 27 | "scale_attn_by_inverse_layer_idx": false, 28 | "scale_attn_weights": true, 29 | "summary_activation": null, 30 | "summary_first_dropout": 0.1, 31 | "summary_proj_to_labels": true, 32 | "summary_type": "cls_index", 33 | "summary_use_proj": true, 34 | "task_specific_params": { 35 | "text-generation": { 36 | "do_sample": true, 37 | "max_length": 50 38 | } 39 | }, 40 | "transformers_version": "4.25.1", 41 | "use_cache": true, 42 | "vocab_size": 50257 43 | } 44 | -------------------------------------------------------------------------------- /experiment/configs/memoria-gpt2-medium.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "gpt2_with_memoria", 3 | "activation_function": "gelu_new", 4 | "attn_pdrop": 0.1, 5 | "bos_token_id": 50256, 6 | "embd_pdrop": 0.1, 7 | "eos_token_id": 50256, 8 | "initializer_range": 0.02, 9 | "layer_norm_epsilon": 1e-05, 10 | "memoria_stm_capacity": 400, 11 | "memoria_num_memories": 50, 12 | "memoria_initial_lifespan": 9, 13 | "memoria_lifespan_extend_scale": 8.0, 14 | "memoria_ltm_search_depth": 10, 15 | "memoria_reset_period": 500, 16 | "memoria_num_reminded_stm": 50, 17 | "memoria_num_reminded_ltm": 50, 18 | "memoria_device": null, 19 | "n_ctx": 1024, 20 | "n_embd": 1024, 21 | "n_head": 16, 22 | "n_inner": null, 23 | "n_layer": 24, 24 | "n_positions": 1024, 25 | "reorder_and_upcast_attn": false, 26 | "resid_pdrop": 0.1, 27 | "scale_attn_by_inverse_layer_idx": false, 28 | "scale_attn_weights": true, 29 | "summary_activation": null, 30 | "summary_first_dropout": 0.1, 31 | "summary_proj_to_labels": true, 32 | "summary_type": "cls_index", 33 | "summary_use_proj": true, 34 | "task_specific_params": { 35 | "text-generation": { 36 | "do_sample": true, 37 | "max_length": 50 38 | } 39 | }, 40 | "transformers_version": "4.25.1", 41 | "use_cache": true, 42 | "vocab_size": 50257 43 | } 44 | -------------------------------------------------------------------------------- /experiment/configs/memoria-gpt2-sort.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "gpt2_with_memoria", 3 | "activation_function": "gelu_new", 4 | "attn_pdrop": 0.1, 5 | "bos_token_id": 50256, 6 | "embd_pdrop": 0.1, 7 | "eos_token_id": 50256, 8 | "initializer_range": 0.02, 9 | "layer_norm_epsilon": 1e-05, 10 | "memoria_stm_capacity": 512, 11 | "memoria_num_memories": 128, 12 | "memoria_initial_lifespan": 5, 13 | "memoria_lifespan_extend_scale": 8.0, 14 | "memoria_ltm_search_depth": 10, 15 | "memoria_reset_period": 500, 16 | "memoria_num_reminded_stm": 256, 17 | "memoria_num_reminded_ltm": 640, 18 | 
"memoria_device": null, 19 | "n_ctx": 10000, 20 | "n_embd": 512, 21 | "n_head": 4, 22 | "n_inner": 2048, 23 | "n_layer": 4, 24 | "n_positions": 10000, 25 | "reorder_and_upcast_attn": false, 26 | "resid_pdrop": 0.1, 27 | "scale_attn_by_inverse_layer_idx": false, 28 | "scale_attn_weights": true, 29 | "summary_activation": null, 30 | "summary_first_dropout": 0.1, 31 | "summary_proj_to_labels": true, 32 | "summary_type": "cls_index", 33 | "summary_use_proj": true, 34 | "task_specific_params": { 35 | "text-generation": { 36 | "do_sample": true, 37 | "max_length": 50 38 | } 39 | }, 40 | "transformers_version": "4.25.1", 41 | "use_cache": true, 42 | "vocab_size": 50257 43 | } 44 | -------------------------------------------------------------------------------- /experiment/configs/memoria-gpt2-synthetic.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "gpt2_with_memoria", 3 | "activation_function": "gelu_new", 4 | "attn_pdrop": 0.1, 5 | "bos_token_id": 50256, 6 | "embd_pdrop": 0.1, 7 | "eos_token_id": 50256, 8 | "initializer_range": 0.02, 9 | "layer_norm_epsilon": 1e-05, 10 | "memoria_stm_capacity": 32, 11 | "memoria_num_memories": null, 12 | "memoria_initial_lifespan": 5, 13 | "memoria_lifespan_extend_scale": 6.0, 14 | "memoria_ltm_search_depth": 30, 15 | "memoria_reset_period": 500, 16 | "memoria_num_reminded_stm": 8, 17 | "memoria_num_reminded_ltm": 8, 18 | "memoria_device": null, 19 | "n_ctx": 1024, 20 | "n_embd": 512, 21 | "n_head": 4, 22 | "n_inner": null, 23 | "n_layer": 4, 24 | "n_positions": 1024, 25 | "reorder_and_upcast_attn": false, 26 | "resid_pdrop": 0.1, 27 | "scale_attn_by_inverse_layer_idx": false, 28 | "scale_attn_weights": true, 29 | "summary_activation": null, 30 | "summary_first_dropout": 0.1, 31 | "summary_proj_to_labels": true, 32 | "summary_type": "cls_index", 33 | "summary_use_proj": true, 34 | "task_specific_params": { 35 | "text-generation": { 36 | "do_sample": true, 37 | "max_length": 50 38 | } 39 | }, 40 | "transformers_version": "4.25.1", 41 | "use_cache": true, 42 | "vocab_size": 50257 43 | } 44 | -------------------------------------------------------------------------------- /experiment/configs/memoria-gpt2-xl.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "gpt2_with_memoria", 3 | "activation_function": "gelu_new", 4 | "attn_pdrop": 0.1, 5 | "bos_token_id": 50256, 6 | "embd_pdrop": 0.1, 7 | "eos_token_id": 50256, 8 | "initializer_range": 0.02, 9 | "layer_norm_epsilon": 1e-05, 10 | "memoria_stm_capacity": 400, 11 | "memoria_num_memories": 50, 12 | "memoria_initial_lifespan": 9, 13 | "memoria_lifespan_extend_scale": 8.0, 14 | "memoria_ltm_search_depth": 10, 15 | "memoria_reset_period": 500, 16 | "memoria_num_reminded_stm": 50, 17 | "memoria_num_reminded_ltm": 50, 18 | "memoria_device": null, 19 | "n_ctx": 1024, 20 | "n_embd": 1600, 21 | "n_head": 25, 22 | "n_inner": null, 23 | "n_layer": 48, 24 | "n_positions": 1024, 25 | "reorder_and_upcast_attn": false, 26 | "resid_pdrop": 0.1, 27 | "scale_attn_by_inverse_layer_idx": false, 28 | "scale_attn_weights": true, 29 | "summary_activation": null, 30 | "summary_first_dropout": 0.1, 31 | "summary_proj_to_labels": true, 32 | "summary_type": "cls_index", 33 | "summary_use_proj": true, 34 | "task_specific_params": { 35 | "text-generation": { 36 | "do_sample": true, 37 | "max_length": 50 38 | } 39 | }, 40 | "transformers_version": "4.25.1", 41 | "use_cache": true, 42 | "vocab_size": 50257 43 | } 44 | 
-------------------------------------------------------------------------------- /experiment/configs/memoria-gpt2.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "gpt2_with_memoria", 3 | "activation_function": "gelu_new", 4 | "attn_pdrop": 0.1, 5 | "bos_token_id": 50256, 6 | "embd_pdrop": 0.1, 7 | "eos_token_id": 50256, 8 | "initializer_range": 0.02, 9 | "layer_norm_epsilon": 1e-05, 10 | "memoria_stm_capacity": 400, 11 | "memoria_num_memories": 50, 12 | "memoria_initial_lifespan": 9, 13 | "memoria_lifespan_extend_scale": 8.0, 14 | "memoria_ltm_search_depth": 10, 15 | "memoria_reset_period": 500, 16 | "memoria_num_reminded_stm": 50, 17 | "memoria_num_reminded_ltm": 50, 18 | "memoria_device": null, 19 | "n_ctx": 1024, 20 | "n_embd": 768, 21 | "n_head": 12, 22 | "n_inner": null, 23 | "n_layer": 12, 24 | "n_positions": 1024, 25 | "reorder_and_upcast_attn": false, 26 | "resid_pdrop": 0.1, 27 | "scale_attn_by_inverse_layer_idx": false, 28 | "scale_attn_weights": true, 29 | "summary_activation": null, 30 | "summary_first_dropout": 0.1, 31 | "summary_proj_to_labels": true, 32 | "summary_type": "cls_index", 33 | "summary_use_proj": true, 34 | "task_specific_params": { 35 | "text-generation": { 36 | "do_sample": true, 37 | "max_length": 50 38 | } 39 | }, 40 | "transformers_version": "4.25.1", 41 | "use_cache": true, 42 | "vocab_size": 50257 43 | } 44 | -------------------------------------------------------------------------------- /experiment/configs/transfo-xl-enwik8.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "transfo-xl", 3 | "adaptive": true, 4 | "attn_type": 0, 5 | "clamp_len": 1000, 6 | "cutoffs": [ 7 | 204 8 | ], 9 | "d_embed": 512, 10 | "d_head": 64, 11 | "d_inner": 2048, 12 | "d_model": 512, 13 | "div_val": 4, 14 | "dropatt": 0.0, 15 | "dropout": 0.1, 16 | "eos_token_id": 0, 17 | "init": "normal", 18 | "init_range": 0.01, 19 | "init_std": 0.02, 20 | "layer_norm_epsilon": 1e-05, 21 | "mem_len": 512, 22 | "n_head": 8, 23 | "n_layer": 12, 24 | "pre_lnorm": false, 25 | "proj_init_std": 0.01, 26 | "same_length": true, 27 | "sample_softmax": -1, 28 | "untie_r": true, 29 | "vocab_size": 204 30 | } 31 | -------------------------------------------------------------------------------- /experiment/configs/transfo-xl-sort.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "transfo-xl", 3 | "adaptive": true, 4 | "attn_type": 0, 5 | "clamp_len": 1000, 6 | "cutoffs": [ 7 | 21 8 | ], 9 | "d_embed": 512, 10 | "d_head": 128, 11 | "d_inner": 2048, 12 | "d_model": 512, 13 | "div_val": 4, 14 | "dropatt": 0.0, 15 | "dropout": 0.1, 16 | "eos_token_id": 0, 17 | "init": "normal", 18 | "init_range": 0.01, 19 | "init_std": 0.02, 20 | "layer_norm_epsilon": 1e-05, 21 | "mem_len": 1024, 22 | "n_head": 4, 23 | "n_layer": 4, 24 | "pre_lnorm": false, 25 | "proj_init_std": 0.01, 26 | "same_length": true, 27 | "sample_softmax": -1, 28 | "untie_r": true, 29 | "vocab_size": 50257 30 | } 31 | -------------------------------------------------------------------------------- /experiment/configs/transfo-xl-synthetic.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "transfo-xl", 3 | "adaptive": true, 4 | "attn_type": 0, 5 | "clamp_len": 1000, 6 | "cutoffs": [ 7 | 10 8 | ], 9 | "d_embed": 512, 10 | "d_head": 64, 11 | "d_inner": 2048, 12 | "d_model": 512, 13 | "div_val": 4, 14 | 
"dropatt": 0.0, 15 | "dropout": 0.1, 16 | "eos_token_id": 0, 17 | "init": "normal", 18 | "init_range": 0.01, 19 | "init_std": 0.02, 20 | "layer_norm_epsilon": 1e-05, 21 | "mem_len": 100, 22 | "n_head": 4, 23 | "n_layer": 4, 24 | "pre_lnorm": false, 25 | "proj_init_std": 0.01, 26 | "same_length": true, 27 | "sample_softmax": -1, 28 | "untie_r": true, 29 | "vocab_size": 50257 30 | } 31 | -------------------------------------------------------------------------------- /experiment/configs/transfo-xl.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "transfo-xl", 3 | "adaptive": true, 4 | "attn_type": 0, 5 | "clamp_len": 1000, 6 | "cutoffs": [ 7 | 50257 8 | ], 9 | "d_embed": 768, 10 | "d_head": 64, 11 | "d_inner": 3072, 12 | "d_model": 768, 13 | "div_val": 4, 14 | "dropatt": 0.0, 15 | "dropout": 0.1, 16 | "eos_token_id": 0, 17 | "init": "normal", 18 | "init_range": 0.01, 19 | "init_std": 0.02, 20 | "layer_norm_epsilon": 1e-05, 21 | "mem_len": 150, 22 | "n_head": 12, 23 | "n_layer": 12, 24 | "pre_lnorm": false, 25 | "proj_init_std": 0.01, 26 | "same_length": true, 27 | "sample_softmax": -1, 28 | "untie_r": true, 29 | "vocab_size": 50257 30 | } 31 | -------------------------------------------------------------------------------- /experiment/eval_classification.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import Dict 3 | 4 | import pytorch_lightning as pl 5 | import torch 6 | from longseq_formers.data import CLASSIFICATION_DATASETS, load_hyperpartisan_data 7 | from longseq_formers.dataset import ClassificationDataset 8 | from longseq_formers.task import Classification 9 | from longseq_formers.utils import get_logger 10 | from torch.utils.data import DataLoader 11 | from transformers import AutoTokenizer 12 | 13 | # fmt: off 14 | parser = argparse.ArgumentParser(prog="train_classification", description="Train & Test Long Sequence Classification") 15 | 16 | g = parser.add_argument_group("Train Parameter") 17 | g.add_argument("--model", type=str, required=True, help="lightning checkpoint") 18 | g.add_argument("--tokenizer", type=str, required=True, help="huggingface tokenizer") 19 | g.add_argument("--dataset", type=str, default="hyperpartisan", choices=CLASSIFICATION_DATASETS, help="dataset name") 20 | g.add_argument("--valid-batch-size", type=int, default=1, help="validation batch size") 21 | g.add_argument("--max-length", type=int, default=512, help="max sequence length") 22 | g.add_argument("--memory-length", type=int, default=512, help="max sequence length for bert one inference on infinity former") 23 | g.add_argument("--seed", type=int, default=42, help="random seed") 24 | g.add_argument("--not-truncate", action="store_false", dest="truncation", help="not truncate sequence") 25 | g.add_argument("--segment-size", type=int, help="segment size for infinity former") 26 | # fmt: on 27 | 28 | 29 | def main(args: argparse.Namespace) -> dict[str, float]: 30 | logger = get_logger("evaluate_classification") 31 | 32 | logger.info(f"[+] Set Random Seed to {args.seed}") 33 | pl.seed_everything(args.seed, workers=True) 34 | 35 | logger.info(f'[+] Load Tokenizer: "{args.tokenizer}"') 36 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer) 37 | 38 | logger.info(f'[+] Use Dataset: "{args.dataset}"') 39 | if args.dataset == "hyperpartisan": 40 | datasets = load_hyperpartisan_data() 41 | 42 | valid_dataset = ClassificationDataset(datasets["dev"]) 43 | test_dataset = 
ClassificationDataset(datasets["test"]) 44 | 45 | logger.info(f"[+] # of valid examples: {len(valid_dataset)}") 46 | logger.info(f"[+] # of test examples: {len(test_dataset)}") 47 | 48 | logger.info(f'[+] Load Model: "{args.model}"') 49 | classification = Classification.load_from_checkpoint( 50 | args.model, tokenizer=tokenizer, max_length=args.max_length, truncation=args.truncation 51 | ) 52 | 53 | collate_fn = ClassificationDataset.pad_collate_fn if not args.truncation else None 54 | valid_dataloader = DataLoader(valid_dataset, batch_size=args.valid_batch_size, collate_fn=collate_fn) 55 | test_dataloader = DataLoader(test_dataset, batch_size=args.valid_batch_size, collate_fn=collate_fn) 56 | 57 | tester = pl.Trainer(accelerator="gpu" if torch.cuda.device_count() else None, devices=1) 58 | 59 | pl.seed_everything(args.seed, workers=True) 60 | result1 = tester.test(classification, valid_dataloader)[0] 61 | 62 | pl.seed_everything(args.seed, workers=True) 63 | result2 = tester.test(classification, test_dataloader)[0] 64 | 65 | print(result1) 66 | print(result2) 67 | 68 | 69 | if __name__ == "__main__": 70 | main(parser.parse_args()) 71 | exit(0) 72 | -------------------------------------------------------------------------------- /experiment/eval_language_modeling.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import Dict 3 | 4 | import pytorch_lightning as pl 5 | import torch 6 | from longseq_formers.data import ( 7 | LANGUAGE_MODELING_DATASETS, 8 | enwik8_tokenize, 9 | load_enwik8_data, 10 | load_pg19_data, 11 | load_wikitext103_data, 12 | ) 13 | from longseq_formers.dataset import LanguageModelingDataset, text_to_tokens 14 | from longseq_formers.task import LanguageModeling 15 | from longseq_formers.utils import get_logger 16 | from torch.utils.data import DataLoader 17 | from transformers import AutoTokenizer 18 | 19 | # fmt: off 20 | parser = argparse.ArgumentParser(prog="evaluate", description="Evaluate Language Modeling") 21 | 22 | g = parser.add_argument_group("Eval Parameter") 23 | g.add_argument("--model", type=str, required=True, help="huggingface model") 24 | g.add_argument("--tokenizer", type=str, default="gpt2", help="huggingface tokenizer") 25 | g.add_argument("--dataset", type=str, default="wikitext103", choices=LANGUAGE_MODELING_DATASETS, help="dataset name") 26 | g.add_argument("--valid-batch-size", type=int, default=1, help="validation batch size") 27 | g.add_argument("--max-length", type=int, default=512, help="max sequence length") 28 | g.add_argument("--seed", type=int, default=42, help="random seed") 29 | # fmt: on 30 | 31 | 32 | def main(args: argparse.Namespace) -> dict[str, float]: 33 | logger = get_logger("test_language_modeling") 34 | 35 | logger.info(" ====== Arguements ======") 36 | for k, v in vars(args).items(): 37 | logger.info(f"{k:25}: {v}") 38 | 39 | logger.info(f"[+] Set Random Seed to {args.seed}") 40 | pl.seed_everything(args.seed, workers=True) 41 | 42 | gpus = torch.cuda.device_count() 43 | logger.info(f"[+] GPU: {gpus}") 44 | 45 | if args.tokenizer is None: 46 | logger.info(f"[+] Use tokenizer same as model: {args.model}") 47 | args.tokenizer = args.model 48 | if args.dataset == "enwik8": 49 | logger.info(f"[+] Use character tokenizer for enwik8 dataset") 50 | tokenizer = enwik8_tokenize 51 | else: 52 | logger.info(f'[+] Load Tokenizer: "{args.tokenizer}"') 53 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer) 54 | 55 | logger.info(f'[+] Use Dataset: 
"{args.dataset}"') 56 | if args.dataset == "wikitext103": 57 | data = load_wikitext103_data() 58 | elif args.dataset == "pg19": 59 | data = load_pg19_data() 60 | elif args.dataset == "enwik8": 61 | data = load_enwik8_data() 62 | else: 63 | raise ValueError(f"dataset `{args.dataset}` is not valid!") 64 | 65 | dev_tokens = text_to_tokens(data["dev"], tokenizer, args.valid_batch_size, args.max_length) 66 | test_tokens = text_to_tokens(data["test"], tokenizer, args.valid_batch_size, args.max_length) 67 | 68 | valid_dataset = LanguageModelingDataset(dev_tokens) 69 | test_dataset = LanguageModelingDataset(test_tokens) 70 | 71 | logger.info(f"[+] # of batched valid examples: {len(valid_dataset)}") 72 | logger.info(f"[+] # of batched test examples: {len(test_dataset)}") 73 | 74 | language_modeling = LanguageModeling.load_from_checkpoint(args.model) 75 | 76 | # Use batch size as 1 because already batched 77 | # train_dataloader = DataLoader(train_dataset, batch_size=1, collate_fn=LanguageModelingDataset.collate_fn) 78 | valid_dataloader = DataLoader(valid_dataset, batch_size=1, collate_fn=LanguageModelingDataset.collate_fn) 79 | test_dataloader = DataLoader(test_dataset, batch_size=1, collate_fn=LanguageModelingDataset.collate_fn) 80 | tester = pl.Trainer(accelerator="gpu" if gpus else None, devices=1) 81 | 82 | pl.seed_everything(args.seed, workers=True) 83 | result1 = tester.test(language_modeling, valid_dataloader)[0] 84 | 85 | pl.seed_everything(args.seed, workers=True) 86 | result2 = tester.test(language_modeling, test_dataloader)[0] 87 | 88 | print(result1) 89 | print(result2) 90 | 91 | 92 | if __name__ == "__main__": 93 | main(parser.parse_args()) 94 | exit(0) 95 | -------------------------------------------------------------------------------- /experiment/eval_synthetic.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import Dict 3 | 4 | import pytorch_lightning as pl 5 | import torch 6 | from longseq_formers.dataset.synthetic import SyntheticDataset, parse_syntetic_data 7 | from longseq_formers.task import Synthetic 8 | from longseq_formers.utils import get_logger 9 | from torch.utils.data import DataLoader 10 | 11 | # fmt: off 12 | parser = argparse.ArgumentParser(prog="train_synthetic", description="Train & Test Synthetic Task") 13 | 14 | g = parser.add_argument_group("Train Parameter") 15 | g.add_argument("--model", type=str, required=True, help="model checkpoint") 16 | g.add_argument("--dataset", type=str, required=True, help="dataset name") 17 | g.add_argument("--valid-batch-size", type=int, default=1, help="validation batch size") 18 | g.add_argument("--max-length", type=int, default=150, help="max sequence length") 19 | g.add_argument("--seed", type=int, default=42, help="random seed") 20 | g.add_argument("--shuffle", action="store_true", help="shuffle data order") 21 | # fmt: on 22 | 23 | 24 | def main(args: argparse.Namespace) -> dict[str, float]: 25 | logger = get_logger("eval_synthetic_task") 26 | 27 | logger.info(" ====== Arguements ======") 28 | for k, v in vars(args).items(): 29 | logger.info(f"{k:25}: {v}") 30 | 31 | logger.info(f"[+] Set Random Seed to {args.seed}") 32 | pl.seed_everything(args.seed, workers=True) 33 | 34 | logger.info(f'[+] Use Dataset: "{args.dataset}"') 35 | _, vocab_size, _, dev_examples, test_examples = parse_syntetic_data(args.dataset) 36 | 37 | valid_dataset = SyntheticDataset(dev_examples) 38 | test_dataset = SyntheticDataset(test_examples) 39 | 40 | logger.info(f"[+] # of 
batched valid examples: {len(valid_dataset)}") 41 | logger.info(f"[+] # of batched test examples: {len(test_dataset)}") 42 | 43 | valid_dataloader = DataLoader(valid_dataset, batch_size=args.valid_batch_size) 44 | test_dataloader = DataLoader(test_dataset, batch_size=args.valid_batch_size) 45 | 46 | synthetic_task = Synthetic.load_from_checkpoint(args.model, vocab_size=vocab_size) 47 | 48 | logger.info(f"[+] Start Evaluation") 49 | 50 | tester = pl.Trainer(accelerator="gpu" if torch.cuda.device_count() else None, devices=1) 51 | 52 | pl.seed_everything(args.seed, workers=True) 53 | result1 = tester.test(synthetic_task, valid_dataloader)[0] 54 | 55 | pl.seed_everything(args.seed, workers=True) 56 | result2 = tester.test(synthetic_task, test_dataloader)[0] 57 | 58 | print(result1) 59 | print(result2) 60 | 61 | 62 | if __name__ == "__main__": 63 | main(parser.parse_args()) 64 | exit(0) 65 | -------------------------------------------------------------------------------- /experiment/longseq_formers/__init__.py: -------------------------------------------------------------------------------- 1 | from . import data, dataset, model, task 2 | 3 | __all__ = ["data", "dataset", "model", "task"] 4 | -------------------------------------------------------------------------------- /experiment/longseq_formers/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .enwik8 import enwik8_tokenize, load_enwik8_data 2 | from .hyperpartisan import load_hyperpartisan_data 3 | from .pg19 import load_pg19_data 4 | from .wikitext103 import load_wikitext103_data 5 | 6 | CLASSIFICATION_DATASETS = ["hyperpartisan"] 7 | LANGUAGE_MODELING_DATASETS = ["wikitext103", "pg19", "enwik8"] 8 | DATASETS = CLASSIFICATION_DATASETS + LANGUAGE_MODELING_DATASETS 9 | 10 | __all__ = [ 11 | "enwik8_tokenize", 12 | "load_enwik8_data", 13 | "load_hyperpartisan_data", 14 | "load_pg19_data", 15 | "load_wikitext103_data", 16 | "DATASETS", 17 | "CLASSIFICATION_DATASETS", 18 | "LANGUAGE_MODELING_DATASETS", 19 | ] 20 | -------------------------------------------------------------------------------- /experiment/longseq_formers/data/enwik8.py: -------------------------------------------------------------------------------- 1 | """ 2 | Refered https://github.com/salesforce/awd-lstm-lm/blob/master/data/enwik8/prep_enwik8.py 3 | """ 4 | 5 | from typing import Dict 6 | 7 | from datasets import Dataset, DatasetDict, load_dataset 8 | 9 | # fmt: off 10 | CHAR_INDICES = ['9', '32', '33', '34', '35', '36', '37', '38', '39', '40', '41', '42', '43', '44', '45', '46', '47', '48', '49', '50', '51', '52', '53', '54', '55', '56', '57', '58', '59', '60', '61', '62', '63', '64', '65', '66', '67', '68', '69', '70', '71', '72', '73', '74', '75', '76', '77', '78', '79', '80', '81', '82', '83', '84', '85', '86', '87', '88', '89', '90', '91', '92', '93', '94', '95', '96', '97', '98', '99', '100', '101', '102', '103', '104', '105', '106', '107', '108', '109', '110', '111', '112', '113', '114', '115', '116', '117', '118', '119', '120', '121', '122', '123', '124', '125', '126', '128', '129', '130', '131', '132', '133', '134', '135', '136', '137', '138', '139', '140', '141', '142', '143', '144', '145', '146', '147', '148', '149', '150', '151', '152', '153', '154', '155', '156', '157', '158', '159', '160', '161', '162', '163', '164', '165', '166', '167', '168', '169', '170', '171', '172', '173', '174', '175', '176', '177', '178', '179', '180', '181', '182', '183', '184', '185', '186', '187', '188', '189', '190', 
'191', '194', '195', '196', '197', '198', '199', '200', '201', '202', '203', '204', '205', '206', '207', '208', '209', '210', '211', '212', '213', '214', '215', '216', '217', '218', '219', '220', '222', '224', '225', '226', '227', '228', '229', '230', '231', '232', '233', '234', '235', '236', '237', '239', '240'] 11 | CHAR_TO_INDEX = {c: i for i, c in enumerate(CHAR_INDICES)} 12 | VOCAB_SIZE = len(CHAR_INDICES) 13 | assert VOCAB_SIZE == 204 14 | # fmt: on 15 | 16 | 17 | def enwik8_tokenize(text: str) -> Dict: 18 | input_ids = [CHAR_TO_INDEX[c] for c in text.split()] 19 | return {"input_ids": input_ids, "attention_mask": [1.0] * len(input_ids)} 20 | 21 | 22 | def load_enwik8_data() -> Dataset: 23 | dataset = load_dataset("enwik8", "enwik8-raw", revision="a3d620ecedec0d39511d1dfdc3a27a69e648be84")["train"] 24 | 25 | num_test_chars = 5000000 26 | 27 | def _preprocess(data): 28 | whole_text = data["text"] 29 | whole_bytes = whole_text.encode() 30 | 31 | train_data = whole_bytes[: -2 * num_test_chars] 32 | valid_data = whole_bytes[-2 * num_test_chars : -num_test_chars] 33 | test_data = whole_bytes[-num_test_chars:] 34 | 35 | train, dev, test = ( 36 | " ".join([str(c) if c != ord("\n") else "\n" for c in part]) for part in (train_data, valid_data, test_data) 37 | ) 38 | 39 | return {"train": train, "dev": dev, "test": test} 40 | 41 | dataset = dataset.map(_preprocess, remove_columns=dataset.column_names, load_from_cache_file=True) 42 | 43 | train = dataset["train"][0] 44 | dev = dataset["dev"][0] 45 | test = dataset["test"][0] 46 | 47 | def _gen(source): 48 | yield {"text": source} 49 | 50 | train_dataset = Dataset.from_generator(_gen, gen_kwargs={"source": train}) 51 | dev_dataset = Dataset.from_generator(_gen, gen_kwargs={"source": dev}) 52 | test_dataset = Dataset.from_generator(_gen, gen_kwargs={"source": test}) 53 | dataset = DatasetDict({"train": train_dataset, "dev": dev_dataset, "test": test_dataset}) 54 | return dataset 55 | -------------------------------------------------------------------------------- /experiment/longseq_formers/data/hyperpartisan.py: -------------------------------------------------------------------------------- 1 | import re 2 | from collections import defaultdict 3 | from typing import Dict, List 4 | 5 | import datasets 6 | from bs4 import BeautifulSoup 7 | 8 | from ..dataset.classification import ClassificationDatum 9 | 10 | # fmt: off 11 | # hp-splits from longformer (https://github.com/allenai/longformer/blob/master/scripts/hp-splits.json) 12 | HYPERPARTISAN_SPLITS = { 13 | "train": [239, 342, 401, 424, 518, 374, 457, 81, 208, 216, 112, 77, 448, 596, 388, 505, 362, 180, 587, 398, 636, 297, 363, 389, 148, 567, 163, 549, 472, 26, 427, 227, 213, 470, 346, 383, 585, 352, 22, 20, 390, 3, 97, 439, 637, 197, 392, 480, 225, 414, 333, 561, 615, 359, 598, 107, 12, 195, 54, 459, 23, 455, 624, 233, 17, 499, 307, 416, 578, 568, 220, 334, 65, 73, 170, 215, 447, 446, 606, 276, 502, 534, 582, 241, 425, 356, 192, 301, 514, 589, 466, 207, 82, 201, 391, 366, 476, 594, 477, 126, 393, 508, 158, 483, 604, 206, 15, 353, 372, 512, 543, 330, 290, 539, 444, 399, 410, 169, 125, 487, 74, 381, 479, 556, 292, 576, 224, 173, 441, 205, 29, 559, 509, 552, 317, 231, 296, 643, 524, 209, 433, 397, 488, 18, 553, 149, 380, 168, 484, 234, 586, 486, 555, 232, 246, 373, 139, 458, 157, 644, 257, 91, 53, 59, 341, 159, 36, 109, 2, 106, 485, 258, 422, 404, 313, 402, 183, 419, 283, 87, 351, 75, 187, 310, 320, 19, 304, 38, 471, 129, 66, 151, 266, 268, 548, 328, 405, 371, 580, 51, 492, 474, 510, 468, 396, 
308, 408, 526, 622, 511, 63, 274, 531, 128, 368, 599, 426, 43, 360, 541, 454, 263, 407, 138, 76, 530, 517, 165, 641, 436, 493, 326, 194, 202, 546, 238, 382, 92, 52, 120, 437, 71, 504, 532, 237, 314, 625, 617, 605, 171, 331, 456, 607, 542, 55, 475, 584, 251, 611, 40, 122, 100, 570, 338, 137, 597, 101, 324, 95, 577, 31, 116, 176, 145, 211, 236, 627, 143, 638, 620, 219, 10, 60, 198, 7, 293, 452, 590, 579, 141, 558, 160, 214, 166, 593, 538, 33, 364, 635, 119, 250, 223, 319, 619, 339, 616, 618, 284, 533, 603, 302, 49, 588, 572, 575, 515, 21, 1, 103, 150, 529, 506, 69, 343, 323, 482, 222, 535, 188, 14, 299, 489, 108, 140, 39, 420, 285, 86, 554, 259, 564, 400, 269, 281, 248, 272, 24, 629, 130, 226, 525, 80, 117, 115, 305, 370, 465, 186, 93, 113, 46, 461, 378, 184, 336, 50, 309, 48, 72, 495, 131, 507, 325, 298, 412, 406, 240, 278, 212, 279, 5, 90, 181, 8, 288, 61, 300, 174, 608, 58, 520, 449, 218, 294, 354, 494, 417, 99, 154, 89, 527, 273, 11, 162, 610, 179, 56, 613, 329, 377, 335, 253, 501, 442, 252, 614, 327, 98, 88, 631, 609, 547, 376, 581, 621, 152, 228, 4, 565, 540, 132, 110, 191, 30, 6, 189, 303, 270, 255, 415, 172, 64, 267, 503, 78, 118, 235, 435, 167, 453, 282, 573, 291, 642, 123, 395, 551, 94, 450, 478, 311, 289, 153, 102, 421, 277, 583, 164, 244, 229, 178, 217, 523, 96, 280, 68, 497, 430, 190, 516, 445, 428, 633, 536, 434, 387, 355, 528, 287, 144, 210, 295, 385, 185, 467, 256, 44, 83, 67, 175, 204, 602, 42, 358, 384, 28, 45, 569, 127, 47, 491, 265, 463, 121, 135, 460], 14 | "dev": [182, 438, 545, 286, 142, 27, 394, 261, 411, 0, 79, 550, 640, 254, 560, 386, 62, 440, 104, 473, 155, 432, 124, 133, 136, 519, 322, 318, 245, 249, 612, 349, 623, 591, 429, 306, 592, 375, 203, 544, 312, 114, 41, 344, 571, 134, 462, 347, 464, 566, 350, 199, 562, 357, 361, 521, 574, 315, 243, 601, 260, 409, 337, 177], 15 | "test": [537, 517, 23, 459, 593, 258, 227, 16, 204, 367, 159, 142, 214, 82, 182, 564, 411, 600, 610, 306, 21, 434, 625, 197, 202, 489, 404, 400, 551, 320, 36, 435, 344, 183, 134, 19, 253, 231, 383, 572, 201, 528, 15, 116, 265, 221, 462, 342, 465, 436, 490, 442, 547, 282, 535, 256, 160, 140, 555, 51, 540, 165, 504, 181, 147], 16 | } 17 | # fmt: on 18 | 19 | 20 | def load_hyperpartisan_data() -> dict[str, list[ClassificationDatum]]: 21 | """Load Hyperpartisan dataset 22 | 23 | Returns: 24 | datasets like below 25 | { 26 | "train": [{ 27 | "text": "...", 28 | "label": 1 29 | }, {...}, ...], 30 | "dev": ..., 31 | "test": ... 
32 | } 33 | """ 34 | data = datasets.load_dataset( 35 | "hyperpartisan_news_detection", "byarticle", revision="c315cc4a12a27cde08fd55c0beda41ced8b75923" 36 | )["train"] 37 | 38 | split_datasets = defaultdict(list) 39 | for split, indices in HYPERPARTISAN_SPLITS.items(): 40 | for index in indices: 41 | datum = data[index] 42 | normalized_text = BeautifulSoup(datum["text"], "html.parser").get_text() 43 | text = re.sub(r"\s+", " ", normalized_text) 44 | label = int(datum["hyperpartisan"]) 45 | split_datasets[split].append({"text": text, "label": label}) 46 | return split_datasets 47 | -------------------------------------------------------------------------------- /experiment/longseq_formers/data/pg19.py: -------------------------------------------------------------------------------- 1 | import datasets 2 | 3 | 4 | def load_pg19_data(train_dataset_percent: int = 7) -> datasets.Dataset: 5 | dataset = datasets.load_dataset( 6 | "pg19", 7 | revision="dd75f494ab94328d0ce92c05390ab91a96920a9d", 8 | split={ 9 | "train": f"train[:{train_dataset_percent}%]", 10 | "dev": "validation", 11 | "test": "test", 12 | }, 13 | ) 14 | return dataset 15 | -------------------------------------------------------------------------------- /experiment/longseq_formers/data/wikitext103.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import datasets 4 | 5 | 6 | def load_wikitext103_data() -> datasets.Dataset: 7 | dataset = datasets.load_dataset( 8 | "wikitext", 9 | "wikitext-103-raw-v1", 10 | revision="dfd72879b14bf51e8f831b4b092c4f58f356a70f", 11 | split={"train": f"train", "dev": "validation", "test": "test"}, 12 | ) 13 | 14 | def _join_segment_text(example): 15 | whole_text = "".join(example["text"]) 16 | start_idxs = [m.start() - 1 for m in re.finditer(r"\n\s*= [^=]+ =\s*\n", whole_text)] 17 | all_idxs = [0] + start_idxs + [len(whole_text)] 18 | segments = [whole_text[all_idxs[i] : all_idxs[i + 1]].strip() for i in range(len(all_idxs) - 1)] 19 | return {"text": segments} 20 | 21 | dataset = dataset.map( 22 | _join_segment_text, 23 | load_from_cache_file=True, 24 | batched=True, 25 | batch_size=len(dataset["train"]), 26 | drop_last_batch=False, 27 | remove_columns=["text"], 28 | ) 29 | return dataset 30 | -------------------------------------------------------------------------------- /experiment/longseq_formers/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .classification import ClassificationDataset, ClassificationDatum 2 | from .language_modeling import LanguageModelingDataset, text_to_tokens 3 | from .synthetic import SyntheticDataset 4 | 5 | __all__ = [ 6 | "ClassificationDataset", 7 | "ClassificationDatum", 8 | "LanguageModelingDataset", 9 | "text_to_tokens", 10 | "SyntheticDataset", 11 | ] 12 | -------------------------------------------------------------------------------- /experiment/longseq_formers/dataset/classification.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, TypedDict 2 | 3 | import torch 4 | from torch.nn.utils.rnn import pad_sequence 5 | from transformers import AutoTokenizer 6 | 7 | 8 | class ClassificationDatum(TypedDict): 9 | text: str 10 | label: int 11 | 12 | 13 | class ClassificationDataset(torch.utils.data.Dataset): 14 | """ClassificationDataset 15 | 16 | Attributes: 17 | data: data for text classification 18 | tokenizer: huggingface tokenizer 19 | max_length: token max length 20 | """ 21 | 22 | 
def __init__( 23 | self, data: list[ClassificationDatum], tokenizer: AutoTokenizer, max_length: int, truncation: bool = True 24 | ) -> None: 25 | super().__init__() 26 | 27 | self.data = data 28 | self.tokenizer = tokenizer 29 | self.max_length = max_length 30 | self.truncation = truncation 31 | 32 | def __len__(self) -> int: 33 | return len(self.data) 34 | 35 | def __getitem__(self, index: int) -> dict[str, torch.Tensor]: 36 | item = self.data[index] 37 | text = item["text"] 38 | label = item["label"] 39 | 40 | inputs = self.tokenizer( 41 | text, 42 | add_special_tokens=True, 43 | max_length=self.max_length, 44 | truncation=self.truncation, 45 | padding="max_length", 46 | return_token_type_ids=True, 47 | return_tensors="pt", 48 | ) 49 | inputs = {k: v.squeeze(dim=0) for k, v in inputs.items()} 50 | inputs["labels"] = torch.tensor(label) 51 | 52 | return inputs 53 | 54 | @staticmethod 55 | def pad_collate_fn(batch: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]: 56 | # [NumTimeSteps, BatchSize, MaxSequenceLength] 57 | padded_batch = {k: [item[k] for item in batch] for k in batch[0].keys()} 58 | for k in padded_batch: 59 | if k == "labels": 60 | padded_batch[k] = torch.stack(padded_batch[k], dim=0) 61 | else: 62 | padded_batch[k] = pad_sequence(padded_batch[k], batch_first=True) 63 | 64 | return padded_batch 65 | -------------------------------------------------------------------------------- /experiment/longseq_formers/dataset/language_modeling.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional 2 | 3 | import datasets 4 | import torch 5 | from transformers import AutoTokenizer 6 | 7 | 8 | def text_to_tokens( 9 | dataset: datasets.Dataset, 10 | tokenizer: AutoTokenizer, 11 | batch_size: int, 12 | max_length: int, 13 | batch_size_per_device: Optional[int] = None, 14 | ) -> datasets.Dataset: 15 | """Tokenize a series of texts into tokens and chunk them into fixed-length segments. 16 | The processed datasets will be cached automatically. 17 | 18 | Args: 19 | dataset: huggingface dataset containing "text" field 20 | tokenizer: huggingface tokenizer 21 | batch_size: batch size; within the same batch, the data is sequential. 22 | max_length: max length of each example. The remainder will be dropped. 23 | batch_size_per_device: batch size per device when using DDP. 24 | Return: 25 | huggingface input dictionary. 
26 | the values shaped [NumExamples, BatchSize, MaxLength] 27 | """ 28 | 29 | def _tokenize(example): 30 | return tokenizer(example["text"]) 31 | 32 | token_dataset = dataset.map(_tokenize, remove_columns=dataset.column_names, load_from_cache_file=True) 33 | 34 | def _segment(example): 35 | num_segments = len(example["input_ids"]) // max_length 36 | return { 37 | "data": [ 38 | {k: v[i * max_length : (i + 1) * max_length] for k, v in example.items()} for i in range(num_segments) 39 | ], 40 | "is_end": [False] * (num_segments - 1) + [True] if num_segments else [], 41 | } 42 | 43 | segment_dataset = token_dataset.map(_segment, remove_columns=token_dataset.column_names, load_from_cache_file=True) 44 | 45 | def _merge(examples): 46 | data = examples["data"] 47 | is_ends = examples["is_end"] 48 | merged = {k: [example[k] for datum in data for example in datum] for k in data[0][0].keys()} 49 | merged["is_end"] = [v for is_end in is_ends for v in is_end] 50 | return merged 51 | 52 | merge_dataset = segment_dataset.map( 53 | _merge, 54 | remove_columns=segment_dataset.column_names, 55 | load_from_cache_file=True, 56 | batched=True, 57 | batch_size=len(segment_dataset), 58 | ) 59 | 60 | num_examples = len(merge_dataset) // batch_size 61 | 62 | def _batching(example): 63 | return { 64 | k: [v[i : num_examples * batch_size : num_examples] for i in range(num_examples)] 65 | for k, v in example.items() 66 | } 67 | 68 | batch_dataset = merge_dataset.map(_batching, load_from_cache_file=True, batched=True, batch_size=len(merge_dataset)) 69 | 70 | def _rebatching_for_multi_device(example): 71 | return { 72 | k: [v[0][i : i + batch_size_per_device] for i in range(0, batch_size, batch_size_per_device)] 73 | for k, v in example.items() 74 | } 75 | 76 | if batch_size_per_device is not None and batch_size != batch_size_per_device: 77 | batch_dataset = batch_dataset.map( 78 | _rebatching_for_multi_device, load_from_cache_file=True, batched=True, batch_size=1 79 | ) 80 | batch_dataset.set_format(type="torch", columns=batch_dataset.column_names) 81 | return batch_dataset 82 | 83 | 84 | class LanguageModelingDataset(torch.utils.data.Dataset): 85 | def __init__(self, data: datasets.Dataset) -> None: 86 | super().__init__() 87 | 88 | self.data = {k: data[k] for k in data.column_names} 89 | 90 | def __len__(self) -> int: 91 | return len(self.data["input_ids"]) 92 | 93 | def __getitem__(self, index: int) -> dict[str, torch.Tensor]: 94 | # [BatchSize, MaxLength] 95 | inputs = {k: v[index] for k, v in self.data.items()} 96 | inputs["labels"] = inputs["input_ids"] 97 | return inputs 98 | 99 | @staticmethod 100 | def collate_fn(batches: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]: 101 | """Select first item becuase batch size is 1""" 102 | assert len(batches) == 1 103 | return batches[0] 104 | -------------------------------------------------------------------------------- /experiment/longseq_formers/dataset/synthetic.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Dict, List, Tuple 3 | 4 | import torch 5 | 6 | 7 | def parse_syntetic_data( 8 | path: str, 9 | ) -> Tuple[int, int, list[dict[str, list[int]]], list[dict[str, list[int]]], list[dict[str, list[int]]]]: 10 | with open(path, "r") as f: 11 | data = json.load(f) 12 | 13 | prompt_length = data["prompt_length"] 14 | vocab_size = data["vocab_size"] 15 | train_examples = data["train"] 16 | dev_examples = data["dev"] 17 | test_examples = data["test"] 18 | 19 | return prompt_length, 
vocab_size, train_examples, dev_examples, test_examples 20 | 21 | 22 | class SyntheticDataset(torch.utils.data.Dataset): 23 | def __init__(self, examples: list[dict[str, list[int]]]) -> None: 24 | super().__init__() 25 | 26 | self.data = examples 27 | 28 | def __len__(self) -> int: 29 | return len(self.data) 30 | 31 | def __getitem__(self, index: int) -> dict[str, torch.Tensor]: 32 | example = self.data[index] 33 | input_ids = example["prompt_ids"] + example["target_ids"][:-1] 34 | labels = [-100] * (len(example["prompt_ids"]) - 1) + example["target_ids"] 35 | return {"input_ids": torch.tensor(input_ids), "labels": torch.tensor(labels)} 36 | -------------------------------------------------------------------------------- /experiment/longseq_formers/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .compressive_former import CompressiveFormerConfig, CompressiveFormerLMHeadModel, CompressiveFormerModel 2 | from .gpt2_with_memoria import GPT2WithMemoriaConfig, GPT2WithMemoriaLMHeadModel, GPT2WithMemoriaModel 3 | from .infinity_gpt2 import InfinityGPT2Config, InfinityGPT2LMHeadModel, InfinityGPT2Model 4 | from .memoria_bert import MemoriaBertConfig, MemoriaBertForSequenceClassification, MemoriaBertModel 5 | from .memoria_roberta import MemoriaRobertaConfig, MemoriaRobertaForSequenceClassification, MemoriaRobertaModel 6 | 7 | __all__ = [ 8 | "CompressiveFormerConfig", 9 | "CompressiveFormerLMHeadModel", 10 | "CompressiveFormerModel", 11 | "GPT2WithMemoriaConfig", 12 | "GPT2WithMemoriaLMHeadModel", 13 | "GPT2WithMemoriaModel", 14 | "InfinityGPT2Config", 15 | "InfinityGPT2LMHeadModel", 16 | "InfinityGPT2Model", 17 | "MemoriaBertConfig", 18 | "MemoriaBertForSequenceClassification", 19 | "MemoriaBertModel", 20 | "MemoriaRobertaConfig", 21 | "MemoriaRobertaForSequenceClassification", 22 | "MemoriaRobertaModel", 23 | ] 24 | -------------------------------------------------------------------------------- /experiment/longseq_formers/model/compressive_former/__init__.py: -------------------------------------------------------------------------------- 1 | from .modeling_compressive_transformer import ( 2 | CompressiveFormerConfig, 3 | CompressiveFormerLMHeadModel, 4 | CompressiveFormerModel, 5 | ) 6 | 7 | __all__ = [ 8 | "CompressiveFormerConfig", 9 | "CompressiveFormerLMHeadModel", 10 | "CompressiveFormerModel", 11 | ] 12 | -------------------------------------------------------------------------------- /experiment/longseq_formers/model/gpt2_with_memoria/__init__.py: -------------------------------------------------------------------------------- 1 | from .modeling_gpt2_with_memoria import GPT2WithMemoriaConfig, GPT2WithMemoriaLMHeadModel, GPT2WithMemoriaModel 2 | 3 | __all__ = [ 4 | "GPT2WithMemoriaConfig", 5 | "GPT2WithMemoriaLMHeadModel", 6 | "GPT2WithMemoriaModel", 7 | ] 8 | -------------------------------------------------------------------------------- /experiment/longseq_formers/model/infinity_gpt2/__init__.py: -------------------------------------------------------------------------------- 1 | from .configuration_infinity_gpt2 import InfinityGPT2Config 2 | from .modeling_infinity_gpt2 import InfinityGPT2LMHeadModel, InfinityGPT2Model 3 | 4 | __all__ = [ 5 | "InfinityGPT2Config", 6 | "InfinityGPT2LMHeadModel", 7 | "InfinityGPT2Model", 8 | ] 9 | -------------------------------------------------------------------------------- /experiment/longseq_formers/model/infinity_gpt2/basis_functions.py: 
-------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | 5 | 6 | class BasisFunctions(object): 7 | def __init__(self): 8 | pass 9 | 10 | def __len__(self): 11 | """Number of basis functions.""" 12 | pass 13 | 14 | def evaluate(self, t): 15 | pass 16 | 17 | def integrate_t2_times_psi(self, a, b): 18 | """Compute integral int_a^b (t**2) * psi(t).""" 19 | pass 20 | 21 | def integrate_t_times_psi(self, a, b): 22 | """Compute integral int_a^b t * psi(t).""" 23 | pass 24 | 25 | def integrate_psi(self, a, b): 26 | """Compute integral int_a^b psi(t).""" 27 | pass 28 | 29 | 30 | class PowerBasisFunctions(BasisFunctions): 31 | """Function phi(t) = t**degree.""" 32 | 33 | def __init__(self, degree): 34 | self.degree = degree.unsqueeze(0) 35 | 36 | def __len__(self): 37 | """Number of basis functions.""" 38 | return self.degree.size(1) 39 | 40 | def evaluate(self, t): 41 | return t**self.degree 42 | 43 | def integrate_t2_times_psi(self, a, b): 44 | """Compute integral int_a^b (t**2) * psi(t).""" 45 | return (b ** (self.degree + 3) - a ** (self.degree + 3)) / (self.degree + 3) 46 | 47 | def integrate_t_times_psi(self, a, b): 48 | """Compute integral int_a^b t * psi(t).""" 49 | return (b ** (self.degree + 2) - a ** (self.degree + 2)) / (self.degree + 2) 50 | 51 | def integrate_psi(self, a, b): 52 | """Compute integral int_a^b psi(t).""" 53 | return (b ** (self.degree + 1) - a ** (self.degree + 1)) / (self.degree + 1) 54 | 55 | def __repr__(self): 56 | return f"PowerBasisFunction(degree={self.degree})" 57 | 58 | 59 | class SineBasisFunctions(BasisFunctions): 60 | """Function phi(t) = sin(omega*t).""" 61 | 62 | def __init__(self, omega): 63 | self.omega = omega.unsqueeze(0) 64 | 65 | def __repr__(self): 66 | return f"SineBasisFunction(omega={self.omega})" 67 | 68 | def __len__(self): 69 | """Number of basis functions.""" 70 | return self.omega.size(1) 71 | 72 | def evaluate(self, t): 73 | return torch.sin(self.omega * t) 74 | 75 | def integrate_t2_times_psi(self, a, b): 76 | """Compute integral int_a^b (t**2) * psi(t).""" 77 | # The antiderivative of (t**2)*sin(omega*t) is 78 | # ((2-(t**2)*(omega**2))*cos(omega*t) + 2*omega*t*sin(omega*t)) / omega**3. # noqa 79 | return ( 80 | (2 - (b**2) * (self.omega**2)) * torch.cos(self.omega * b) 81 | + 2 * self.omega * b * torch.sin(self.omega * b) 82 | - (2 - (a**2) * (self.omega**2)) * torch.cos(self.omega * a) 83 | - 2 * self.omega * a * torch.sin(self.omega * a) 84 | ) / (self.omega**3) 85 | 86 | def integrate_t_times_psi(self, a, b): 87 | """Compute integral int_a^b t * psi(t).""" 88 | # The antiderivative of t*sin(omega*t) is 89 | # (sin(omega*t) - omega*t*cos(omega*t)) / omega**2. 90 | return ( 91 | torch.sin(self.omega * b) 92 | - self.omega * b * torch.cos(self.omega * b) 93 | - torch.sin(self.omega * a) 94 | + self.omega * a * torch.cos(self.omega * a) 95 | ) / (self.omega**2) 96 | 97 | def integrate_psi(self, a, b): 98 | """Compute integral int_a^b psi(t).""" 99 | # The antiderivative of sin(omega*t) is -cos(omega*t)/omega. 
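        # Evaluated at the endpoints, this gives
        # int_a^b sin(omega*t) dt = (cos(omega*a) - cos(omega*b)) / omega,
        # which is what the expression below computes elementwise for each basis frequency.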
100 | return (-torch.cos(self.omega * b) + torch.cos(self.omega * a)) / self.omega 101 | 102 | 103 | class CosineBasisFunctions(BasisFunctions): 104 | """Function phi(t) = cos(omega*t).""" 105 | 106 | def __init__(self, omega): 107 | self.omega = omega.unsqueeze(0) 108 | 109 | def __repr__(self): 110 | return f"CosineBasisFunction(omega={self.omega})" 111 | 112 | def __len__(self): 113 | """Number of basis functions.""" 114 | return self.omega.size(1) 115 | 116 | def evaluate(self, t): 117 | return torch.cos(self.omega * t) 118 | 119 | def integrate_t2_times_psi(self, a, b): 120 | """Compute integral int_a^b (t**2) * psi(t).""" 121 | # The antiderivative of (t**2)*cos(omega*t) is 122 | # (((t**2)*(omega**2)-2)*cos(omega*t) + 2*omega*t*sin(omega*t)) / omega**3. # noqa 123 | return ( 124 | ((b**2) * (self.omega**2) - 2) * torch.sin(self.omega * b) 125 | + 2 * self.omega * b * torch.cos(self.omega * b) 126 | - ((a**2) * (self.omega**2) - 2) * torch.sin(self.omega * a) 127 | - 2 * self.omega * a * torch.cos(self.omega * a) 128 | ) / (self.omega**3) 129 | 130 | def integrate_t_times_psi(self, a, b): 131 | """Compute integral int_a^b t * psi(t).""" 132 | # The antiderivative of t*cos(omega*t) is 133 | # (cos(omega*t) + omega*t*sin(omega*t)) / omega**2. 134 | return ( 135 | torch.cos(self.omega * b) 136 | + self.omega * b * torch.sin(self.omega * b) 137 | - torch.cos(self.omega * a) 138 | - self.omega * a * torch.sin(self.omega * a) 139 | ) / (self.omega**2) 140 | 141 | def integrate_psi(self, a, b): 142 | """Compute integral int_a^b psi(t).""" 143 | # The antiderivative of cos(omega*t) is sin(omega*t)/omega. 144 | return (torch.sin(self.omega * b) - torch.sin(self.omega * a)) / self.omega 145 | 146 | 147 | class GaussianBasisFunctions(BasisFunctions): 148 | """Function phi(t) = Gaussian(t; mu, sigma_sq) 149 | 150 | Attributes: 151 | mu: mu shaped [1, NumBasis] 152 | sigma: sigma shaped [1, NumBasis] 153 | """ 154 | 155 | def __init__(self, mu, sigma): 156 | self.mu = mu.unsqueeze(0) 157 | self.sigma = sigma.unsqueeze(0) 158 | 159 | def __repr__(self): 160 | return f"GaussianBasisFunction(mu={self.mu}, sigma={self.sigma})" 161 | 162 | def __len__(self): 163 | """Number of basis functions.""" 164 | return self.mu.size(1) 165 | 166 | def _phi(self, t): 167 | return 1.0 / math.sqrt(2 * math.pi) * torch.exp(-0.5 * t**2) 168 | 169 | def _Phi(self, t): 170 | return 0.5 * (1 + torch.erf(t / math.sqrt(2))) 171 | 172 | def _integrate_product_of_gaussians(self, mu, sigma_sq): 173 | sigma = torch.sqrt(self.sigma**2 + sigma_sq) 174 | return self._phi((mu - self.mu) / sigma) / sigma 175 | 176 | def evaluate(self, t): 177 | """Return Gaussian Function value 178 | 179 | Args: 180 | t: [BatchSize, NumBasis] or [BatchSize, 1] or scalar 181 | considered same value for all basis if NumBasis shape is none. 
182 | Return: 183 | Gaussian function value shaped [BatchSize, NumBasis] 184 | """ 185 | return self._phi((t - self.mu) / self.sigma) / self.sigma 186 | 187 | def integrate_t2_times_psi(self, a, b): 188 | """Compute integral int_a^b (t**2) * psi(t).""" 189 | return ( 190 | (self.mu**2 + self.sigma**2) 191 | * (self._Phi((b - self.mu) / self.sigma) - self._Phi((a - self.mu) / self.sigma)) 192 | - (self.sigma * (b + self.mu) * self._phi((b - self.mu) / self.sigma)) 193 | + (self.sigma * (a + self.mu) * self._phi((a - self.mu) / self.sigma)) 194 | ) 195 | 196 | def integrate_t_times_psi(self, a, b): 197 | """Compute integral int_a^b t * psi(t).""" 198 | return self.mu * ( 199 | self._Phi((b - self.mu) / self.sigma) - self._Phi((a - self.mu) / self.sigma) 200 | ) - self.sigma * (self._phi((b - self.mu) / self.sigma) - self._phi((a - self.mu) / self.sigma)) 201 | 202 | def integrate_psi(self, a, b): 203 | """Compute integral int_a^b psi(t).""" 204 | return self._Phi((b - self.mu) / self.sigma) - self._Phi((a - self.mu) / self.sigma) 205 | 206 | def integrate_t2_times_psi_gaussian(self, mu, sigma_sq): 207 | """Compute integral int N(t; mu, sigma_sq) * t**2 * psi(t).""" 208 | S_tilde = self._integrate_product_of_gaussians(mu, sigma_sq) 209 | mu_tilde = (self.mu * sigma_sq + mu * self.sigma**2) / (self.sigma**2 + sigma_sq) 210 | sigma_sq_tilde = ((self.sigma**2) * sigma_sq) / (self.sigma**2 + sigma_sq) 211 | return S_tilde * (mu_tilde**2 + sigma_sq_tilde) 212 | 213 | def integrate_t_times_psi_gaussian(self, mu, sigma_sq): 214 | """Compute integral int N(t; mu, sigma_sq) * t * psi(t).""" 215 | S_tilde = self._integrate_product_of_gaussians(mu, sigma_sq) 216 | mu_tilde = (self.mu * sigma_sq + mu * self.sigma**2) / (self.sigma**2 + sigma_sq) 217 | return S_tilde * mu_tilde 218 | 219 | def integrate_psi_gaussian(self, mu, sigma_sq): 220 | """Compute integral int N(t; mu, sigma_sq) * psi(t).""" 221 | return self._integrate_product_of_gaussians(mu, sigma_sq) 222 | -------------------------------------------------------------------------------- /experiment/longseq_formers/model/infinity_gpt2/configuration_infinity_gpt2.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ OpenAI GPT-2 configuration """ 17 | 18 | from transformers import AutoConfig 19 | from transformers.configuration_utils import PretrainedConfig 20 | from transformers.utils import logging 21 | 22 | logger = logging.get_logger(__name__) 23 | 24 | 25 | class InfinityGPT2Config(PretrainedConfig): 26 | """ 27 | This is the configuration class to store the configuration of a :class:`~transformers.GPT2Model` or a 28 | :class:`~transformers.TFGPT2Model`. 
It is used to instantiate a GPT-2 model according to the specified arguments, 29 | defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration 30 | to that of the GPT-2 `small `__ architecture. 31 | 32 | Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model 33 | outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information. 34 | 35 | 36 | Args: 37 | vocab_size (:obj:`int`, `optional`, defaults to 50257): 38 | Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the 39 | :obj:`inputs_ids` passed when calling :class:`~transformers.GPT2Model` or 40 | :class:`~transformers.TFGPT2Model`. 41 | n_positions (:obj:`int`, `optional`, defaults to 1024): 42 | The maximum sequence length that this model might ever be used with. Typically set this to something large 43 | just in case (e.g., 512 or 1024 or 2048). 44 | n_ctx (:obj:`int`, `optional`, defaults to 1024): 45 | Dimensionality of the causal mask (usually same as n_positions). 46 | n_embd (:obj:`int`, `optional`, defaults to 768): 47 | Dimensionality of the embeddings and hidden states. 48 | n_layer (:obj:`int`, `optional`, defaults to 12): 49 | Number of hidden layers in the Transformer encoder. 50 | n_head (:obj:`int`, `optional`, defaults to 12): 51 | Number of attention heads for each attention layer in the Transformer encoder. 52 | n_inner (:obj:`int`, `optional`, defaults to None): 53 | Dimensionality of the inner feed-forward layers. :obj:`None` will set it to 4 times n_embd 54 | activation_function (:obj:`str`, `optional`, defaults to :obj:`"gelu"`): 55 | Activation function, to be selected in the list :obj:`["relu", "silu", "gelu", "tanh", "gelu_new"]`. 56 | resid_pdrop (:obj:`float`, `optional`, defaults to 0.1): 57 | The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. 58 | embd_pdrop (:obj:`int`, `optional`, defaults to 0.1): 59 | The dropout ratio for the embeddings. 60 | attn_pdrop (:obj:`float`, `optional`, defaults to 0.1): 61 | The dropout ratio for the attention. 62 | layer_norm_epsilon (:obj:`float`, `optional`, defaults to 1e-5): 63 | The epsilon to use in the layer normalization layers 64 | initializer_range (:obj:`float`, `optional`, defaults to 0.02): 65 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 66 | summary_type (:obj:`string`, `optional`, defaults to :obj:`"cls_index"`): 67 | Argument used when doing sequence summary, used in the models :class:`~transformers.GPT2DoubleHeadsModel` 68 | and :class:`~transformers.TFGPT2DoubleHeadsModel`. 69 | 70 | Has to be one of the following options: 71 | 72 | - :obj:`"last"`: Take the last token hidden state (like XLNet). 73 | - :obj:`"first"`: Take the first token hidden state (like BERT). 74 | - :obj:`"mean"`: Take the mean of all tokens hidden states. 75 | - :obj:`"cls_index"`: Supply a Tensor of classification token position (like GPT/GPT-2). 76 | - :obj:`"attn"`: Not implemented now, use multi-head attention. 77 | summary_use_proj (:obj:`bool`, `optional`, defaults to :obj:`True`): 78 | Argument used when doing sequence summary, used in the models :class:`~transformers.GPT2DoubleHeadsModel` 79 | and :class:`~transformers.TFGPT2DoubleHeadsModel`. 80 | 81 | Whether or not to add a projection after the vector extraction. 
82 | summary_activation (:obj:`str`, `optional`): 83 | Argument used when doing sequence summary. Used in for the multiple choice head in 84 | :class:`~transformers.GPT2DoubleHeadsModel`. 85 | 86 | Pass :obj:`"tanh"` for a tanh activation to the output, any other value will result in no activation. 87 | summary_proj_to_labels (:obj:`bool`, `optional`, defaults to :obj:`True`): 88 | Argument used when doing sequence summary, used in the models :class:`~transformers.GPT2DoubleHeadsModel` 89 | and :class:`~transformers.TFGPT2DoubleHeadsModel`. 90 | 91 | Whether the projection outputs should have :obj:`config.num_labels` or :obj:`config.hidden_size` classes. 92 | summary_first_dropout (:obj:`float`, `optional`, defaults to 0.1): 93 | Argument used when doing sequence summary, used in the models :class:`~transformers.GPT2DoubleHeadsModel` 94 | and :class:`~transformers.TFGPT2DoubleHeadsModel`. 95 | 96 | The dropout ratio to be used after the projection and activation. 97 | gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): 98 | Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass. 99 | use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): 100 | Whether or not the model should return the last key/values attentions (not used by all models). 101 | 102 | Example:: 103 | 104 | >>> from transformers import GPT2Model, GPT2Config 105 | 106 | >>> # Initializing a GPT2 configuration 107 | >>> configuration = GPT2Config() 108 | 109 | >>> # Initializing a model from the configuration 110 | >>> model = GPT2Model(configuration) 111 | 112 | >>> # Accessing the model configuration 113 | >>> configuration = model.config 114 | """ 115 | 116 | model_type = "infinity_gpt2" 117 | keys_to_ignore_at_inference = ["past_key_values"] 118 | 119 | def __init__( 120 | self, 121 | vocab_size=50257, 122 | n_positions=512, 123 | n_ctx=512, 124 | n_embd=1024, 125 | n_layer=24, 126 | n_head=16, 127 | n_inner=None, 128 | activation_function="gelu_new", 129 | resid_pdrop=0.1, 130 | embd_pdrop=0.1, 131 | attn_drop=0.1, 132 | attn_pdrop=0.1, 133 | layer_norm_epsilon=1e-5, 134 | initializer_range=0.02, 135 | summary_type="cls_index", 136 | summary_use_proj=True, 137 | summary_activation=None, 138 | summary_proj_to_labels=True, 139 | summary_first_dropout=0.1, 140 | gradient_checkpointing=False, 141 | use_cache=True, 142 | bos_token_id=50256, 143 | eos_token_id=50256, 144 | memory_length=150, 145 | num_basis=150, 146 | num_samples=150, 147 | tau=0.5, 148 | normalize_function="softmax", 149 | mask_type="cnn", 150 | mask_dropout=0.1, 151 | longterm_attention_dropout=0.1, 152 | use_affines=True, 153 | use_kl_regularizer=True, 154 | use_sticky_memories=True, 155 | mu_0=-1.0, 156 | sigma_0=0.05, 157 | kl_lambda=1e-6, 158 | detach_recursive_outputs=True, 159 | **kwargs 160 | ): 161 | super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) 162 | 163 | self.vocab_size = vocab_size 164 | self.n_ctx = n_ctx 165 | self.n_positions = n_positions 166 | self.n_embd = n_embd 167 | self.n_layer = n_layer 168 | self.n_head = n_head 169 | self.n_inner = n_inner 170 | self.activation_function = activation_function 171 | self.resid_pdrop = resid_pdrop 172 | self.embd_pdrop = embd_pdrop 173 | self.attn_drop = attn_drop 174 | self.attn_pdrop = attn_pdrop 175 | self.layer_norm_epsilon = layer_norm_epsilon 176 | self.initializer_range = initializer_range 177 | self.summary_type = summary_type 178 | self.summary_use_proj = summary_use_proj 179 | 
self.summary_activation = summary_activation 180 | self.summary_first_dropout = summary_first_dropout 181 | self.summary_proj_to_labels = summary_proj_to_labels 182 | self.gradient_checkpointing = gradient_checkpointing 183 | self.use_cache = use_cache 184 | 185 | self.bos_token_id = bos_token_id 186 | self.eos_token_id = eos_token_id 187 | 188 | self.memory_length = memory_length 189 | self.num_basis = num_basis 190 | self.num_samples = num_samples 191 | self.tau = tau 192 | self.normalize_function = normalize_function 193 | self.mask_type = mask_type 194 | self.mask_dropout = mask_dropout 195 | self.longterm_attention_dropout = longterm_attention_dropout 196 | self.use_affines = use_affines 197 | self.use_kl_regularizer = use_kl_regularizer 198 | self.use_sticky_memories = use_sticky_memories 199 | self.mu_0 = mu_0 200 | self.sigma_0 = sigma_0 201 | self.kl_lambda = kl_lambda 202 | self.detach_recursive_outputs = detach_recursive_outputs 203 | 204 | @property 205 | def max_position_embeddings(self): 206 | return self.n_positions 207 | 208 | @property 209 | def hidden_size(self): 210 | return self.n_embd 211 | 212 | @property 213 | def num_attention_heads(self): 214 | return self.n_head 215 | 216 | @property 217 | def num_hidden_layers(self): 218 | return self.n_layer 219 | 220 | 221 | InfinityGPT2Config.register_for_auto_class() 222 | AutoConfig.register("infinity_gpt2", InfinityGPT2Config) 223 | -------------------------------------------------------------------------------- /experiment/longseq_formers/model/infinity_gpt2/continuous_softmax.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.autograd.function import FunctionCtx 6 | 7 | from .basis_functions import GaussianBasisFunctions 8 | 9 | 10 | class ContinuousSoftmaxFunction(torch.autograd.Function): 11 | @classmethod 12 | def _expectation_phi_psi(cls, ctx: FunctionCtx, mu: torch.FloatTensor, sigma_sq: torch.FloatTensor): 13 | """Compute expectation of phi(t) * psi(t).T under N(mu, sigma_sq).""" 14 | num_basis = [len(basis_functions) for basis_functions in ctx.psi] 15 | total_basis = sum(num_basis) 16 | V = torch.zeros((mu.shape[0], 2, total_basis), dtype=ctx.dtype, device=ctx.device) 17 | offsets = torch.cumsum(torch.tensor(num_basis, dtype=torch.int, device=ctx.device), dim=0) 18 | start = 0 19 | for j, basis_functions in enumerate(ctx.psi): 20 | V[:, 0, start : offsets[j]] = basis_functions.integrate_t_times_psi_gaussian(mu, sigma_sq) 21 | V[:, 1, start : offsets[j]] = basis_functions.integrate_t2_times_psi_gaussian(mu, sigma_sq) 22 | start = offsets[j] 23 | return V 24 | 25 | @classmethod 26 | def _expectation_psi( 27 | cls, ctx: FunctionCtx, mu: torch.FloatTensor, sigma_sq: torch.FloatTensor 28 | ) -> torch.FloatTensor: 29 | """Compute expectation of psi under N(mu, sigma_sq). 
30 | 31 | Args: 32 | mu: mu of distribution shaped [BatchSize, 1] 33 | sigma_sq: sigma_sq of distribution shaped [BatchSize, 1] 34 | Return: 35 | integraded result shaped [BatchSize, TotalBasis] 36 | """ 37 | psi: list[GaussianBasisFunctions] = ctx.psi 38 | num_basis = [len(basis_functions) for basis_functions in psi] 39 | total_basis = sum(num_basis) 40 | r = torch.zeros(mu.shape[0], total_basis, dtype=ctx.dtype, device=ctx.device) 41 | offsets = torch.cumsum(torch.tensor(num_basis, dtype=torch.int, device=ctx.device), dim=0) 42 | start = 0 43 | for j, basis_functions in enumerate(psi): 44 | r[:, start : offsets[j]] = basis_functions.integrate_psi_gaussian(mu, sigma_sq) 45 | start = offsets[j] 46 | return r 47 | 48 | @classmethod 49 | def _expectation_phi(cls, ctx: FunctionCtx, mu: torch.FloatTensor, sigma_sq: torch.FloatTensor): 50 | """Compute expectation of phi under N(mu, sigma_sq).""" 51 | v = torch.zeros(mu.shape[0], 2, dtype=ctx.dtype, device=ctx.device) 52 | v[:, 0] = mu.squeeze(1) 53 | v[:, 1] = (mu**2 + sigma_sq).squeeze(1) 54 | return v 55 | 56 | @classmethod 57 | def forward( 58 | cls, ctx: FunctionCtx, theta: torch.FloatTensor, psi: list[GaussianBasisFunctions] 59 | ) -> torch.FloatTensor: 60 | """ 61 | We assume a Gaussian. 62 | We have: 63 | theta = [mu/sigma**2, -1/(2*sigma**2)], 64 | phi(t) = [t, t**2], 65 | p(t) = Gaussian(t; mu, sigma**2). 66 | 67 | Args: 68 | theta: shaped [BatchSize, 2] 69 | psi: list of basis functions 70 | """ 71 | ctx.dtype = theta.dtype 72 | ctx.device = theta.device 73 | ctx.psi = psi 74 | # sigma_sq, mu: [BatchSize, 1] 75 | sigma_sq = (-0.5 / theta[:, 1]).unsqueeze(1) 76 | mu = theta[:, 0].unsqueeze(1) * sigma_sq 77 | 78 | r = cls._expectation_psi(ctx, mu, sigma_sq) 79 | ctx.save_for_backward(mu, sigma_sq, r) 80 | return r 81 | 82 | @classmethod 83 | def backward(cls, ctx: FunctionCtx, grad_output): 84 | mu, sigma_sq, r = ctx.saved_tensors 85 | J = cls._expectation_phi_psi(ctx, mu, sigma_sq) 86 | e_phi = cls._expectation_phi(ctx, mu, sigma_sq) 87 | e_psi = cls._expectation_psi(ctx, mu, sigma_sq) 88 | J -= torch.bmm(e_phi.unsqueeze(2), e_psi.unsqueeze(1)) 89 | grad_input = torch.matmul(J, grad_output.unsqueeze(2)).squeeze(2) 90 | return grad_input, None 91 | 92 | 93 | class ContinuousSoftmax(nn.Module): 94 | def __init__(self, psi: Optional[list[GaussianBasisFunctions]] = None): 95 | super(ContinuousSoftmax, self).__init__() 96 | self.psi = psi 97 | 98 | def forward(self, theta): 99 | return ContinuousSoftmaxFunction.apply(theta, self.psi) 100 | -------------------------------------------------------------------------------- /experiment/longseq_formers/model/infinity_gpt2/continuous_sparsemax.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class ContinuousSparsemaxFunction(torch.autograd.Function): 6 | @classmethod 7 | def _integrate_phi_times_psi(cls, ctx, a, b): 8 | """Compute integral int_a^b phi(t) * psi(t).T.""" 9 | num_basis = [len(basis_functions) for basis_functions in ctx.psi] 10 | total_basis = sum(num_basis) 11 | V = torch.zeros((a.shape[0], 2, total_basis), dtype=ctx.dtype, device=ctx.device) 12 | offsets = torch.cumsum(torch.tensor(num_basis, dtype=torch.int, device=ctx.device), dim=0) 13 | start = 0 14 | for j, basis_functions in enumerate(ctx.psi): 15 | V[:, 0, start : offsets[j]] = basis_functions.integrate_t_times_psi(a, b) 16 | V[:, 1, start : offsets[j]] = basis_functions.integrate_t2_times_psi(a, b) 17 | start = offsets[j] 18 | return V 
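    # Note: V stacks, per basis function, int_a^b t * psi(t) dt (row 0) and
    # int_a^b t**2 * psi(t) dt (row 1); together with _integrate_psi below, these
    # closed-form integrals are the building blocks used by forward/backward for the
    # truncated-parabola density assumed by continuous sparsemax.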
19 | 20 | @classmethod 21 | def _integrate_psi(cls, ctx, a, b): 22 | """Compute integral int_a^b psi(t).""" 23 | num_basis = [len(basis_functions) for basis_functions in ctx.psi] 24 | total_basis = sum(num_basis) 25 | v = torch.zeros(a.shape[0], total_basis, dtype=ctx.dtype, device=ctx.device) 26 | offsets = torch.cumsum(torch.tensor(num_basis, dtype=torch.int, device=ctx.device), dim=0) 27 | start = 0 28 | for j, basis_functions in enumerate(ctx.psi): 29 | v[:, start : offsets[j]] = basis_functions.integrate_psi(a, b) 30 | start = offsets[j] 31 | return v 32 | 33 | @classmethod 34 | def _integrate_phi(cls, ctx, a, b): 35 | """Compute integral int_a^b phi(t).""" 36 | v = torch.zeros(a.shape[0], 2, dtype=ctx.dtype, device=ctx.device) 37 | v[:, 0] = ((b**2 - a**2) / 2).squeeze(1) 38 | v[:, 1] = ((b**3 - a**3) / 3).squeeze(1) 39 | return v 40 | 41 | @classmethod 42 | def forward(cls, ctx, theta, psi): 43 | # We assume a truncated parabola. 44 | # We have: 45 | # theta = [mu/sigma**2, -1/(2*sigma**2)], 46 | # phi(t) = [t, t**2], 47 | # p(t) = [theta.dot(phi(t)) - A]_+, 48 | # supported on [mu - a, mu + a]. 49 | ctx.dtype = theta.dtype 50 | ctx.device = theta.device 51 | ctx.psi = psi 52 | sigma = torch.sqrt(-0.5 / theta[:, 1]) 53 | mu = theta[:, 0] * sigma**2 54 | A = -0.5 * (3.0 / (2 * sigma)) ** (2.0 / 3) 55 | a = torch.sqrt(-2 * A) * sigma 56 | A += mu**2 / (2 * sigma**2) 57 | left = (mu - a).unsqueeze(1) 58 | right = (mu + a).unsqueeze(1) 59 | V = cls._integrate_phi_times_psi(ctx, left, right) 60 | u = cls._integrate_psi(ctx, left, right) 61 | r = torch.matmul(theta.unsqueeze(1), V).squeeze(1) - A.unsqueeze(1) * u 62 | ctx.save_for_backward(mu, a, V, u) 63 | return r 64 | 65 | @classmethod 66 | def backward(cls, ctx, grad_output): 67 | mu, a, V, u = ctx.saved_tensors 68 | # J.T = int_{-a}^{+a} phi(t+mu)*psi(t+mu).T 69 | # - (int_{-a}^{+a} phi(t+mu)) * (int_{-a}^{+a} psi(t+mu).T) / (2*a) 70 | left = (mu - a).unsqueeze(1) 71 | right = (mu + a).unsqueeze(1) 72 | i_phi = cls._integrate_phi(ctx, left, right) 73 | ger = torch.bmm(i_phi.unsqueeze(2), u.unsqueeze(1)) 74 | # ger = torch.einsum('bi,bj->bij', (i_phi, u)) 75 | J = V - ger / (2 * a.unsqueeze(1).unsqueeze(2)) 76 | grad_input = torch.matmul(J, grad_output.unsqueeze(2)).squeeze(2) 77 | return grad_input, None 78 | 79 | 80 | class ContinuousSparsemax(nn.Module): 81 | def __init__(self, psi=None): 82 | super(ContinuousSparsemax, self).__init__() 83 | self.psi = psi 84 | 85 | def forward(self, theta): 86 | return ContinuousSparsemaxFunction.apply(theta, self.psi) 87 | -------------------------------------------------------------------------------- /experiment/longseq_formers/model/memoria_bert/__init__.py: -------------------------------------------------------------------------------- 1 | from .configuration_memoria_bert import MemoriaBertConfig 2 | from .modeling_memoria_bert import MemoriaBertForSequenceClassification, MemoriaBertModel 3 | 4 | __all__ = [ 5 | "MemoriaBertConfig", 6 | "MemoriaBertForSequenceClassification", 7 | "MemoriaBertModel", 8 | ] 9 | -------------------------------------------------------------------------------- /experiment/longseq_formers/model/memoria_bert/configuration_memoria_bert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 
4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ BERT model configuration""" 17 | 18 | from typing import Optional 19 | 20 | from transformers import AutoConfig 21 | from transformers.configuration_utils import PretrainedConfig 22 | from transformers.utils import logging 23 | 24 | logger = logging.get_logger(__name__) 25 | 26 | 27 | class MemoriaBertConfig(PretrainedConfig): 28 | r""" 29 | Args: 30 | vocab_size (`int`, *optional*, defaults to 30522): 31 | Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the 32 | `inputs_ids` passed when calling [`BertModel`] or [`TFBertModel`]. 33 | hidden_size (`int`, *optional*, defaults to 768): 34 | Dimensionality of the encoder layers and the pooler layer. 35 | num_hidden_layers (`int`, *optional*, defaults to 12): 36 | Number of hidden layers in the Transformer encoder. 37 | num_attention_heads (`int`, *optional*, defaults to 12): 38 | Number of attention heads for each attention layer in the Transformer encoder. 39 | intermediate_size (`int`, *optional*, defaults to 3072): 40 | Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. 41 | hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`): 42 | The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, 43 | `"relu"`, `"silu"` and `"gelu_new"` are supported. 44 | hidden_dropout_prob (`float`, *optional*, defaults to 0.1): 45 | The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. 46 | attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): 47 | The dropout ratio for the attention probabilities. 48 | max_position_embeddings (`int`, *optional*, defaults to 512): 49 | The maximum sequence length that this model might ever be used with. Typically set this to something large 50 | just in case (e.g., 512 or 1024 or 2048). 51 | type_vocab_size (`int`, *optional*, defaults to 2): 52 | The vocabulary size of the `token_type_ids` passed when calling [`BertModel`] or [`TFBertModel`]. 53 | initializer_range (`float`, *optional*, defaults to 0.02): 54 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 55 | layer_norm_eps (`float`, *optional*, defaults to 1e-12): 56 | The epsilon used by the layer normalization layers. 57 | position_embedding_type (`str`, *optional*, defaults to `"absolute"`): 58 | Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For 59 | positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to 60 | [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155). 61 | For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models 62 | with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658). 
63 | use_cache (`bool`, *optional*, defaults to `True`): 64 | Whether or not the model should return the last key/values attentions (not used by all models). Only 65 | relevant if `config.is_decoder=True`. 66 | classifier_dropout (`float`, *optional*): 67 | The dropout ratio for the classification head. 68 | ```""" 69 | model_type = "memoria_bert" 70 | 71 | def __init__( 72 | self, 73 | vocab_size=30522, 74 | hidden_size=768, 75 | num_hidden_layers=12, 76 | num_attention_heads=12, 77 | intermediate_size=3072, 78 | hidden_act="gelu", 79 | hidden_dropout_prob=0.1, 80 | attention_probs_dropout_prob=0.1, 81 | max_position_embeddings=512, 82 | type_vocab_size=2, 83 | initializer_range=0.02, 84 | layer_norm_eps=1e-12, 85 | pad_token_id=0, 86 | position_embedding_type="absolute", 87 | use_cache=True, 88 | classifier_dropout=None, 89 | memory_layer_index: int = 9, 90 | memoria_num_memories: float = 64, 91 | memoria_lifespan_extend_scale: float = 8.0, 92 | memoria_num_reminded_stm: float = 64, 93 | memoria_num_reminded_ltm: float = 64, 94 | memoria_stm_capacity: int = 128, 95 | memoria_ltm_search_depth: int = 10, 96 | memoria_initial_lifespan: int = 12, 97 | memoria_reset_period: int = 500, 98 | memoria_device: Optional[str] = None, 99 | **kwargs 100 | ): 101 | super().__init__(pad_token_id=pad_token_id, **kwargs) 102 | 103 | self.vocab_size = vocab_size 104 | self.hidden_size = hidden_size 105 | self.num_hidden_layers = num_hidden_layers 106 | self.num_attention_heads = num_attention_heads 107 | self.hidden_act = hidden_act 108 | self.intermediate_size = intermediate_size 109 | self.hidden_dropout_prob = hidden_dropout_prob 110 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 111 | self.max_position_embeddings = max_position_embeddings 112 | self.type_vocab_size = type_vocab_size 113 | self.initializer_range = initializer_range 114 | self.layer_norm_eps = layer_norm_eps 115 | self.position_embedding_type = position_embedding_type 116 | self.use_cache = use_cache 117 | self.classifier_dropout = classifier_dropout 118 | 119 | self.memory_layer_index: int = memory_layer_index 120 | self.memoria_num_memories: int = memoria_num_memories 121 | self.memoria_lifespan_extend_scale: float = memoria_lifespan_extend_scale 122 | self.memoria_num_reminded_stm: int = memoria_num_reminded_stm 123 | self.memoria_num_reminded_ltm: int = memoria_num_reminded_ltm 124 | self.memoria_stm_capacity: int = memoria_stm_capacity 125 | self.memoria_ltm_search_depth: int = memoria_ltm_search_depth 126 | self.memoria_initial_lifespan: int = memoria_initial_lifespan 127 | self.memoria_reset_period: int = memoria_reset_period 128 | self.memoria_device: Optional[str] = memoria_device 129 | 130 | 131 | MemoriaBertConfig.register_for_auto_class() 132 | AutoConfig.register("memoria_bert", MemoriaBertConfig) 133 | -------------------------------------------------------------------------------- /experiment/longseq_formers/model/memoria_roberta/__init__.py: -------------------------------------------------------------------------------- 1 | from .configuration_memoria_roberta import MemoriaRobertaConfig 2 | from .modeling_memoria_roberta import MemoriaRobertaForSequenceClassification, MemoriaRobertaModel 3 | 4 | __all__ = [ 5 | "MemoriaRobertaConfig", 6 | "MemoriaRobertaForSequenceClassification", 7 | "MemoriaRobertaModel", 8 | ] 9 | -------------------------------------------------------------------------------- /experiment/longseq_formers/model/memoria_roberta/configuration_memoria_roberta.py: 
-------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ RoBERTa configuration""" 17 | 18 | from typing import Optional 19 | 20 | from transformers import AutoConfig 21 | from transformers.configuration_utils import PretrainedConfig 22 | from transformers.utils import logging 23 | 24 | logger = logging.get_logger(__name__) 25 | 26 | 27 | class MemoriaRobertaConfig(PretrainedConfig): 28 | r""" 29 | This is the configuration class to store the configuration of a [`RobertaModel`] or a [`TFRobertaModel`]. It is 30 | used to instantiate a RoBERTa model according to the specified arguments, defining the model architecture. 31 | Instantiating a configuration with the defaults will yield a similar configuration to that of the RoBERTa 32 | [roberta-base](https://huggingface.co/roberta-base) architecture. 33 | 34 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 35 | documentation from [`PretrainedConfig`] for more information. 36 | 37 | 38 | Args: 39 | vocab_size (`int`, *optional*, defaults to 30522): 40 | Vocabulary size of the RoBERTa model. Defines the number of different tokens that can be represented by the 41 | `inputs_ids` passed when calling [`RobertaModel`] or [`TFRobertaModel`]. 42 | hidden_size (`int`, *optional*, defaults to 768): 43 | Dimensionality of the encoder layers and the pooler layer. 44 | num_hidden_layers (`int`, *optional*, defaults to 12): 45 | Number of hidden layers in the Transformer encoder. 46 | num_attention_heads (`int`, *optional*, defaults to 12): 47 | Number of attention heads for each attention layer in the Transformer encoder. 48 | intermediate_size (`int`, *optional*, defaults to 3072): 49 | Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. 50 | hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`): 51 | The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, 52 | `"relu"`, `"silu"` and `"gelu_new"` are supported. 53 | hidden_dropout_prob (`float`, *optional*, defaults to 0.1): 54 | The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. 55 | attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): 56 | The dropout ratio for the attention probabilities. 57 | max_position_embeddings (`int`, *optional*, defaults to 512): 58 | The maximum sequence length that this model might ever be used with. Typically set this to something large 59 | just in case (e.g., 512 or 1024 or 2048). 60 | type_vocab_size (`int`, *optional*, defaults to 2): 61 | The vocabulary size of the `token_type_ids` passed when calling [`RobertaModel`] or [`TFRobertaModel`]. 
62 | initializer_range (`float`, *optional*, defaults to 0.02): 63 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 64 | layer_norm_eps (`float`, *optional*, defaults to 1e-12): 65 | The epsilon used by the layer normalization layers. 66 | position_embedding_type (`str`, *optional*, defaults to `"absolute"`): 67 | Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For 68 | positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to 69 | [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155). 70 | For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models 71 | with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658). 72 | use_cache (`bool`, *optional*, defaults to `True`): 73 | Whether or not the model should return the last key/values attentions (not used by all models). Only 74 | relevant if `config.is_decoder=True`. 75 | classifier_dropout (`float`, *optional*): 76 | The dropout ratio for the classification head. 77 | 78 | Examples: 79 | 80 | ```python 81 | >>> from transformers import RobertaConfig, RobertaModel 82 | 83 | >>> # Initializing a RoBERTa configuration 84 | >>> configuration = RobertaConfig() 85 | 86 | >>> # Initializing a model (with random weights) from the configuration 87 | >>> model = RobertaModel(configuration) 88 | 89 | >>> # Accessing the model configuration 90 | >>> configuration = model.config 91 | ```""" 92 | model_type = "memoria_roberta" 93 | 94 | def __init__( 95 | self, 96 | vocab_size=30522, 97 | hidden_size=768, 98 | num_hidden_layers=12, 99 | num_attention_heads=12, 100 | intermediate_size=3072, 101 | hidden_act="gelu", 102 | hidden_dropout_prob=0.1, 103 | attention_probs_dropout_prob=0.1, 104 | max_position_embeddings=512, 105 | type_vocab_size=2, 106 | initializer_range=0.02, 107 | layer_norm_eps=1e-12, 108 | pad_token_id=1, 109 | bos_token_id=0, 110 | eos_token_id=2, 111 | position_embedding_type="absolute", 112 | use_cache=True, 113 | classifier_dropout=None, 114 | memory_layer_index: int = 9, 115 | memoria_num_memories: int = 64, 116 | memoria_lifespan_extend_scale: float = 8.0, 117 | memoria_num_reminded_stm: int = 64, 118 | memoria_num_reminded_ltm: int = 64, 119 | memoria_stm_capacity: int = 128, 120 | memoria_ltm_search_depth: int = 10, 121 | memoria_initial_lifespan: int = 12, 122 | memoria_reset_period: int = 500, 123 | memoria_device: Optional[str] = None, 124 | **kwargs 125 | ): 126 | super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) 127 | 128 | self.vocab_size = vocab_size 129 | self.hidden_size = hidden_size 130 | self.num_hidden_layers = num_hidden_layers 131 | self.num_attention_heads = num_attention_heads 132 | self.hidden_act = hidden_act 133 | self.intermediate_size = intermediate_size 134 | self.hidden_dropout_prob = hidden_dropout_prob 135 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 136 | self.max_position_embeddings = max_position_embeddings 137 | self.type_vocab_size = type_vocab_size 138 | self.initializer_range = initializer_range 139 | self.layer_norm_eps = layer_norm_eps 140 | self.position_embedding_type = position_embedding_type 141 | self.use_cache = use_cache 142 | self.classifier_dropout = classifier_dropout 143 | 144 | self.memory_layer_index: int = memory_layer_index 145 | 
self.memoria_num_memories: int = memoria_num_memories 146 | self.memoria_lifespan_extend_scale: float = memoria_lifespan_extend_scale 147 | self.memoria_num_reminded_stm: int = memoria_num_reminded_stm 148 | self.memoria_num_reminded_ltm: int = memoria_num_reminded_ltm 149 | self.memoria_stm_capacity: int = memoria_stm_capacity 150 | self.memoria_ltm_search_depth: int = memoria_ltm_search_depth 151 | self.memoria_initial_lifespan: int = memoria_initial_lifespan 152 | self.memoria_reset_period: int = memoria_reset_period 153 | self.memoria_device: Optional[str] = memoria_device 154 | 155 | 156 | MemoriaRobertaConfig.register_for_auto_class() 157 | AutoConfig.register("memoria_roberta", MemoriaRobertaConfig) 158 | -------------------------------------------------------------------------------- /experiment/longseq_formers/task/__init__.py: -------------------------------------------------------------------------------- 1 | from .classification import Classification 2 | from .language_modeling import LanguageModeling 3 | from .synthetic import Synthetic 4 | 5 | __all__ = ["Classification", "LanguageModeling", "Synthetic"] 6 | -------------------------------------------------------------------------------- /experiment/longseq_formers/task/classification.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Literal, Optional 2 | 3 | import pytorch_lightning as pl 4 | import torch 5 | import torch.nn.functional as F 6 | from torchmetrics.classification import Accuracy, MulticlassF1Score 7 | from torchmetrics.collections import MetricCollection 8 | from transformers import AutoConfig, AutoModelForSequenceClassification, get_linear_schedule_with_warmup 9 | 10 | 11 | class Classification(pl.LightningModule): 12 | """Classification 13 | 14 | Attributes: 15 | model: model for classification 16 | num_classes: the number of classes 17 | total_steps: total training steps for lr scheduling 18 | learning_rate: Max LR 19 | warmup_rate: warmup step rate 20 | """ 21 | 22 | def __init__( 23 | self, 24 | model: AutoModelForSequenceClassification, 25 | num_classes: int, 26 | total_steps: int, 27 | learning_rate: float, 28 | warmup_rate: float, 29 | segment_size: Optional[int] = None, 30 | aggregate: Literal["mean", "last"] = "mean", 31 | eval_aggregate: Literal["mean", "last"] = "last", 32 | ): 33 | super().__init__() 34 | 35 | self.model = model 36 | self.num_classes = num_classes 37 | self.total_steps = total_steps 38 | self.learning_rate = learning_rate 39 | self.warmup_rate = warmup_rate 40 | self.segment_size = segment_size 41 | self.aggregate = aggregate 42 | self.eval_aggregate = eval_aggregate 43 | self.automatic_optimization = False 44 | 45 | metric_collection = MetricCollection( 46 | { 47 | "acc": Accuracy(task="multiclass", top_k=1, num_classes=self.num_classes), 48 | "f1": MulticlassF1Score(task="multiclass", num_classes=self.num_classes, average="macro"), 49 | } 50 | ) 51 | self.train_metrics = metric_collection.clone(prefix="train/") 52 | self.val_metrics = metric_collection.clone(prefix="val/") 53 | self.test_metrics = metric_collection.clone(prefix="test/") 54 | self.metrics = {"train/": self.train_metrics, "val/": self.val_metrics, "test/": self.test_metrics} 55 | 56 | self.save_hyperparameters( 57 | { 58 | "model": None, 59 | "model_config": model.config.to_dict() if model is not None else None, 60 | "num_classes": num_classes, 61 | "total_steps": total_steps, 62 | "learning_rate": learning_rate, 63 | "warmup_rate": warmup_rate, 
64 | "segment_size": segment_size, 65 | "aggregate": aggregate, 66 | "eval_aggregate": eval_aggregate, 67 | } 68 | ) 69 | 70 | def _single_step(self, batch: dict[str, torch.Tensor], batch_idx: int, prefix="") -> dict[str, float]: 71 | """Common step function 72 | 73 | Args: 74 | batch: training batch input/label 75 | Returns: 76 | metrics dictionary of this train step 77 | """ 78 | labels = batch.pop("labels") 79 | 80 | outputs = self.model(**batch) 81 | logits = outputs.logits 82 | 83 | ce_loss = F.cross_entropy(logits, labels, reduction="none") 84 | loss = ce_loss 85 | other_metrics = {"ce_loss": ce_loss.mean()} 86 | if self.model.config.model_type == "memoria_bert": 87 | ltm_mask = self.model.bert.encoder.memoria.engrams.longterm_memory_mask 88 | other_metrics["num_ltms_per_batch"] = ( 89 | ltm_mask.sum(dim=1).float().mean(dim=0) 90 | if ltm_mask.numel() > 0 91 | else torch.tensor(0.0, device=loss.device) 92 | ) 93 | if self.model.config.model_type == "memoria_roberta": 94 | ltm_mask = self.model.roberta.encoder.memoria.engrams.longterm_memory_mask 95 | other_metrics["num_ltms_per_batch"] = ( 96 | ltm_mask.sum(dim=1).float().mean(dim=0) 97 | if ltm_mask.numel() > 0 98 | else torch.tensor(0.0, device=loss.device) 99 | ) 100 | other_metrics["loss"] = loss 101 | 102 | other_metrics = {prefix + k: v for k, v in other_metrics.items()} 103 | return other_metrics, logits.detach(), labels.detach() 104 | 105 | def _segment_step( 106 | self, 107 | batch: dict[str, torch.Tensor], 108 | batch_idx: int, 109 | aggregate: Literal["mean", "last"], 110 | prefix="", 111 | ) -> dict[str, float]: 112 | batch_size, length = batch["input_ids"].shape 113 | num_valid_segments = batch["attention_mask"][:, :: self.segment_size].sum(dim=1) 114 | all_metrics = [] 115 | all_probs = [] 116 | indices = list(range(0, length, self.segment_size)) 117 | prev_indices = [None] + indices[:-1] 118 | post_indices = indices[1:] + [None] 119 | final_loss = 0.0 120 | for pre_i, i, post_i in zip(prev_indices, indices, post_indices): 121 | segment_batch = {k: v[:, i : i + self.segment_size] if k != "labels" else v for k, v in batch.items()} 122 | pre_batch = ( 123 | {k: v[:, pre_i : pre_i + self.segment_size] if k != "labels" else v for k, v in batch.items()} 124 | if pre_i is not None 125 | else None 126 | ) 127 | post_batch = ( 128 | {k: v[:, post_i : post_i + self.segment_size] if k != "labels" else v for k, v in batch.items()} 129 | if post_i is not None 130 | else None 131 | ) 132 | 133 | current_valid = segment_batch["attention_mask"].bool().any(dim=1) 134 | is_last = current_valid 135 | if pre_batch is not None: 136 | pre_valid = pre_batch["attention_mask"].bool().any(dim=1) 137 | is_last &= pre_valid 138 | if post_batch is not None: 139 | post_valid = post_batch["attention_mask"].bool().any(dim=1) 140 | is_last &= ~post_valid 141 | 142 | segment_metrics, logits, labels = self._single_step(segment_batch, batch_idx, prefix) 143 | if aggregate == "last": 144 | loss = segment_metrics[f"{prefix}loss"] / batch_size 145 | loss = loss[is_last].sum() 146 | final_loss += loss.item() 147 | 148 | if logits[is_last].numel(): 149 | self.metrics[prefix].update(logits[is_last], labels[is_last]) 150 | segment_metrics[f"{prefix}loss"] = loss 151 | elif aggregate == "mean": 152 | loss = segment_metrics[f"{prefix}loss"].mean() / len(indices) 153 | final_loss += loss.item() 154 | 155 | probs = logits.softmax(dim=-1) 156 | probs[~current_valid] = 0.0 157 | all_probs.append(probs) 158 | 159 | segment_metrics[f"{prefix}loss"] = loss 160 | else: 
161 | raise ValueError(f"Unknown aggregate method: {aggregate}") 162 | 163 | if prefix == "train/": 164 | self.manual_backward(loss) 165 | 166 | all_metrics.append(segment_metrics) 167 | if aggregate == "mean": 168 | all_metrics = { 169 | k: torch.stack([m[k] for m in all_metrics], dim=0).mean(dim=0) for k in all_metrics[0].keys() 170 | } 171 | mean_logits = torch.stack(all_probs, dim=-1).mean(dim=-1) 172 | self.metrics[prefix].update(mean_logits, labels) 173 | segment_metrics = all_metrics 174 | 175 | segment_metrics.update(self.metrics[prefix].compute()) 176 | return segment_metrics 177 | 178 | def training_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> dict[str, float]: 179 | """Train step function""" 180 | opt = self.optimizers() 181 | sch = self.lr_schedulers() 182 | opt.zero_grad() 183 | 184 | if self.segment_size: 185 | metrics = self._segment_step(batch=batch, batch_idx=batch_idx, aggregate=self.aggregate, prefix="train/") 186 | else: 187 | metrics, logits, labels = self._single_step(batch=batch, batch_idx=batch_idx, prefix="train/") 188 | metrics = {k: v.mean() for k, v in metrics.items()} 189 | self.manual_backward(metrics["train/loss"]) 190 | metrics.update(self.metrics["train/"](logits, labels)) 191 | 192 | opt.step() 193 | if sch is not None: 194 | sch.step() 195 | 196 | self.log_dict(metrics, prog_bar=True, logger=True, on_step=True, sync_dist=True) 197 | return metrics 198 | 199 | def validation_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> dict[str, float]: 200 | """Validation step function""" 201 | if self.segment_size: 202 | metrics = self._segment_step(batch=batch, batch_idx=batch_idx, aggregate=self.eval_aggregate, prefix="val/") 203 | else: 204 | metrics, logits, labels = self._single_step(batch=batch, batch_idx=batch_idx, prefix="val/") 205 | metrics = {k: v.mean() for k, v in metrics.items()} 206 | metrics.update(self.metrics["val/"](logits, labels)) 207 | self.log_dict(metrics, prog_bar=True, logger=True, on_step=True, sync_dist=True) 208 | return metrics 209 | 210 | def test_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> dict[str, float]: 211 | """Test step function""" 212 | if self.segment_size: 213 | metrics = self._segment_step( 214 | batch=batch, batch_idx=batch_idx, aggregate=self.eval_aggregate, prefix="test/" 215 | ) 216 | else: 217 | metrics, logits, labels = self._single_step(batch=batch, batch_idx=batch_idx, prefix="test/") 218 | metrics = {k: v.mean() for k, v in metrics.items()} 219 | metrics.update(self.metrics["test/"](logits, labels)) 220 | self.log_dict(metrics, prog_bar=True, logger=True, on_step=True, sync_dist=True) 221 | return metrics 222 | 223 | def configure_optimizers(self) -> Dict: 224 | optimizer = torch.optim.Adam(params=self.model.parameters(), lr=self.learning_rate) 225 | optimizers = {"optimizer": optimizer} 226 | 227 | if self.warmup_rate is not None: 228 | scheduler = get_linear_schedule_with_warmup( 229 | optimizer, 230 | num_warmup_steps=int(self.total_steps * self.warmup_rate), 231 | num_training_steps=self.total_steps, 232 | ) 233 | optimizers["lr_scheduler"] = {"scheduler": scheduler, "interval": "step", "name": "Learning Rate"} 234 | 235 | return optimizers 236 | 237 | def on_save_checkpoint(self, checkpoint: dict[str, Any]): 238 | checkpoint["model_config"] = self.model.config.to_dict() 239 | checkpoint["model_type"] = self.model.config.model_type 240 | 241 | def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None: 242 | config_dict = checkpoint["model_config"] 243 | 
config_cls = AutoConfig.for_model(checkpoint["model_type"]) 244 | config = config_cls.from_dict(config_dict) 245 | self.model = AutoModelForSequenceClassification.from_config(config) 246 | return super().on_load_checkpoint(checkpoint) 247 | 248 | def on_train_batch_start(self, batch: Any, batch_idx: int) -> None: 249 | self.metrics["train/"].reset() 250 | if self.model.config.model_type == "memoria_bert": 251 | self.model.bert.encoder.memoria.reset_memory() 252 | if self.model.config.model_type == "memoria_roberta": 253 | self.model.roberta.encoder.memoria.reset_memory() 254 | 255 | def on_validation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None: 256 | if self.model.config.model_type == "memoria_bert": 257 | self.model.bert.encoder.memoria.reset_memory() 258 | if self.model.config.model_type == "memoria_roberta": 259 | self.model.roberta.encoder.memoria.reset_memory() 260 | 261 | def on_test_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None: 262 | if self.model.config.model_type == "memoria_bert": 263 | self.model.bert.encoder.memoria.reset_memory() 264 | if self.model.config.model_type == "memoria_roberta": 265 | self.model.roberta.encoder.memoria.reset_memory() 266 | 267 | def _epoch_end(self, outputs, prefix: str = "") -> None: 268 | results = self.metrics[prefix].compute() 269 | results = {k + "_final": v for k, v in results.items()} 270 | self.metrics[prefix].reset() 271 | self.log_dict(results, logger=True, sync_dist=True) 272 | 273 | def validation_epoch_end(self, outputs) -> None: 274 | return self._epoch_end(outputs, prefix="val/") 275 | 276 | def test_epoch_end(self, outputs) -> None: 277 | return self._epoch_end(outputs, prefix="test/") 278 | -------------------------------------------------------------------------------- /experiment/longseq_formers/task/language_modeling.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional 2 | 3 | import pytorch_lightning as pl 4 | import torch 5 | from transformers import AutoConfig, AutoModelForCausalLM, get_linear_schedule_with_warmup 6 | 7 | 8 | class LanguageModeling(pl.LightningModule): 9 | """LanguageModeling 10 | 11 | Attributes: 12 | model: model for language modeling 13 | total_steps: total training steps for lr scheduling 14 | learning_rate: Max LR 15 | warmup_rate: warmup step rate 16 | """ 17 | 18 | def __init__( 19 | self, model: Optional[AutoModelForCausalLM], total_steps: int, learning_rate: float, warmup_rate: float 20 | ): 21 | super().__init__() 22 | 23 | self.model = model 24 | self.total_steps = total_steps 25 | self.learning_rate = learning_rate 26 | self.warmup_rate = warmup_rate 27 | 28 | self.save_hyperparameters( 29 | { 30 | "model": None, 31 | "total_steps": total_steps, 32 | "learning_rate": learning_rate, 33 | "warmup_rate": warmup_rate, 34 | "model_config": model.config.to_dict() if model is not None else None, 35 | } 36 | ) 37 | 38 | def _step(self, batch: dict[str, torch.Tensor], batch_idx: int, prefix="") -> dict[str, float]: 39 | """Common step function 40 | 41 | Args: 42 | batch: training batch input/label 43 | Returns: 44 | metrics dictionary of this train step 45 | """ 46 | is_end = batch.pop("is_end", None) 47 | 48 | if self.model.config.model_type in ["transfo-xl", "memoria-xl"]: 49 | del batch["attention_mask"] 50 | if hasattr(self, "_mems"): 51 | batch["mems"] = self._mems 52 | if hasattr(self, "_cmems"): 53 | batch["cmems"] = self._cmems 54 | outputs = self.model(**batch) 55 | 
lm_loss = outputs.loss 56 | 57 | if self.model.config.model_type in ["compressive_transformer"]: 58 | lm_loss = outputs.lm_loss 59 | if hasattr(outputs, "mems"): 60 | self._mems = outputs.mems 61 | if hasattr(outputs, "cmems"): 62 | self._cmems = outputs.cmems 63 | 64 | loss = outputs.loss 65 | ppl = lm_loss.detach().exp() 66 | metrics = {"loss": loss, "lm_loss": lm_loss, "ppl": ppl} 67 | if self.model.config.model_type in ["gpt2_with_memoria"]: 68 | ltm_mask = self.model.transformer.memoria.engrams.longterm_memory_mask 69 | metrics["num_ltms_per_batch"] = ltm_mask.sum(dim=1).float().mean(dim=0) if ltm_mask.numel() > 0 else 0.0 70 | metrics = {prefix + k: v for k, v in metrics.items()} 71 | return metrics 72 | 73 | def training_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> dict[str, float]: 74 | """Train step function""" 75 | metrics = self._step(batch=batch, batch_idx=batch_idx, prefix="") 76 | self.log_dict(metrics, prog_bar=True, logger=True, on_step=True, sync_dist=True) 77 | return metrics 78 | 79 | def validation_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> dict[str, float]: 80 | """Validation step function""" 81 | metrics = self._step(batch=batch, batch_idx=batch_idx, prefix="val/") 82 | self.log_dict(metrics, prog_bar=True, logger=True, on_step=True, sync_dist=True) 83 | return metrics 84 | 85 | def test_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> dict[str, float]: 86 | """Test step function""" 87 | metrics = self._step(batch=batch, batch_idx=batch_idx, prefix="test/") 88 | self.log_dict(metrics, prog_bar=True, logger=True, on_step=True, sync_dist=True) 89 | return metrics 90 | 91 | def configure_optimizers(self) -> Dict: 92 | optimizer = torch.optim.Adam(params=self.model.parameters(), lr=self.learning_rate) 93 | optimizers = {"optimizer": optimizer} 94 | 95 | scheduler = get_linear_schedule_with_warmup( 96 | optimizer, 97 | num_warmup_steps=int(self.total_steps * self.warmup_rate) if self.warmup_rate else 0, 98 | num_training_steps=self.total_steps, 99 | ) 100 | optimizers["lr_scheduler"] = {"scheduler": scheduler, "interval": "step", "name": "Learning Rate"} 101 | 102 | return optimizers 103 | 104 | def reset_memories(self) -> None: 105 | if self.model.config.model_type in ["gpt2_with_memoria"]: 106 | self.model.transformer.memoria.reset_memory() 107 | self.model.transformer.prev_hidden = None 108 | if self.model.config.model_type in ["transfo-xl"] and hasattr(self, "_mems"): 109 | del self._mems 110 | if self.model.config.model_type == "compressive_transformer": 111 | if hasattr(self, "_mems"): 112 | del self._mems 113 | if hasattr(self, "_cmems"): 114 | del self._cmems 115 | 116 | def on_train_start(self) -> None: 117 | self.reset_memories() 118 | 119 | def on_train_end(self) -> None: 120 | self.reset_memories() 121 | 122 | def on_validation_start(self) -> None: 123 | self.reset_memories() 124 | 125 | def on_validation_end(self) -> None: 126 | self.reset_memories() 127 | 128 | def on_test_start(self) -> None: 129 | self.reset_memories() 130 | 131 | def on_test_end(self) -> None: 132 | self.reset_memories() 133 | 134 | def on_train_batch_start(self, batch: Any, batch_idx: int) -> None: 135 | if self.model.config.model_type in ["gpt2_with_memoria"]: 136 | if batch_idx % self.model.config.memoria_reset_period == 0: 137 | self.model.transformer.memoria.reset_memory() 138 | self.model.transformer.prev_hidden = None 139 | 140 | def on_save_checkpoint(self, checkpoint: dict[str, Any]): 141 | checkpoint["model_config"] = 
self.model.config.to_dict() 142 | checkpoint["model_type"] = self.model.config.model_type 143 | 144 | def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None: 145 | config_dict = checkpoint["model_config"] 146 | config_cls = AutoConfig.for_model(checkpoint["model_type"]) 147 | config = config_cls.from_dict(config_dict) 148 | self.model = AutoModelForCausalLM.from_config(config) 149 | return super().on_load_checkpoint(checkpoint) 150 | -------------------------------------------------------------------------------- /experiment/longseq_formers/task/synthetic.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional 2 | 3 | import pytorch_lightning as pl 4 | import torch 5 | import torch.nn.functional as F 6 | from torchmetrics.classification import Accuracy 7 | from transformers import AutoConfig, AutoModelForCausalLM, get_linear_schedule_with_warmup 8 | 9 | 10 | class Synthetic(pl.LightningModule): 11 | """Synthetic Task 12 | 13 | Attributes: 14 | model: model for classification 15 | num_classes: the number of classes 16 | total_steps: total training steps for lr scheduling 17 | learning_rate: Max LR 18 | warmup_rate: warmup step rate 19 | segment_size: segment size 20 | vocab_size: vocab size 21 | """ 22 | 23 | def __init__( 24 | self, 25 | model: AutoModelForCausalLM, 26 | total_steps: int, 27 | learning_rate: float, 28 | warmup_rate: float, 29 | segment_size: int, 30 | vocab_size: int, 31 | max_grad_norm: Optional[float] = None, 32 | ): 33 | super().__init__() 34 | 35 | self.model = model 36 | self.total_steps = total_steps 37 | self.learning_rate = learning_rate 38 | self.warmup_rate = warmup_rate 39 | self.segment_size = segment_size 40 | self.vocab_size = vocab_size 41 | self.max_grad_norm = max_grad_norm 42 | self.automatic_optimization = False 43 | 44 | self.train_acc = Accuracy(task="multiclass", top_k=1, num_classes=vocab_size, ignore_index=-100) 45 | self.valid_acc = Accuracy(task="multiclass", top_k=1, num_classes=vocab_size, ignore_index=-100) 46 | self.test_acc = Accuracy(task="multiclass", top_k=1, num_classes=vocab_size, ignore_index=-100) 47 | self.accs = {"train": self.train_acc, "val": self.valid_acc, "test": self.test_acc} 48 | 49 | self.save_hyperparameters( 50 | { 51 | "model": None, 52 | "model_config": model.config.to_dict() if model is not None else None, 53 | "total_steps": total_steps, 54 | "learning_rate": learning_rate, 55 | "warmup_rate": warmup_rate, 56 | "segment_size": segment_size, 57 | "vocab_size": vocab_size, 58 | "max_grad_norm": max_grad_norm, 59 | } 60 | ) 61 | 62 | def _step(self, batch: dict[str, torch.Tensor], batch_idx: int, prefix: str) -> dict[str, float]: 63 | """Train step function""" 64 | batch_size, length = batch["input_ids"].size() 65 | num_valid_labels = (batch["labels"] != -100).sum(dim=1) 66 | indices = range(0, length, self.segment_size) 67 | loss_mean = 0.0 68 | acc = self.accs[prefix] 69 | for i in indices: 70 | segment_batch = {k: v[:, i : i + self.segment_size] for k, v in batch.items()} 71 | labels = segment_batch.pop("labels") 72 | if hasattr(self, "_mems"): 73 | segment_batch["mems"] = self._mems 74 | if hasattr(self, "_cmems"): 75 | segment_batch["cmems"] = self._cmems 76 | 77 | use_grad = prefix == "train" and (labels != -100).any().item() 78 | with torch.set_grad_enabled(use_grad): 79 | outputs = self.model(**segment_batch) 80 | if self.model.config.model_type in ["transfo-xl", "memoria-xl"]: 81 | self._mems = outputs.mems 82 | 83 | loss = ( 84 | 
F.cross_entropy( 85 | outputs.logits.view(-1, outputs.logits.size(-1)), 86 | labels.reshape(-1), 87 | ignore_index=-100, 88 | reduction="none", 89 | ) 90 | .view(batch_size, -1) 91 | .sum(dim=1) 92 | / num_valid_labels 93 | ).mean() 94 | 95 | if use_grad: 96 | self.manual_backward(loss) 97 | 98 | loss_mean += loss.item() 99 | preds = outputs.logits.argmax(dim=-1) 100 | acc.update(preds=preds, target=labels) 101 | 102 | metrics = {"loss": loss_mean, "acc": acc.compute()} 103 | metrics = {f"{prefix}/{k}": v for k, v in metrics.items()} 104 | return metrics 105 | 106 | def training_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> dict[str, float]: 107 | """Train step function""" 108 | opt = self.optimizers() 109 | sch = self.lr_schedulers() 110 | 111 | metrics = self._step(batch=batch, batch_idx=batch_idx, prefix="train") 112 | 113 | if self.max_grad_norm is not None: 114 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm) 115 | opt.step() 116 | if sch is not None: 117 | sch.step() 118 | opt.zero_grad() 119 | 120 | self.train_acc.reset() 121 | self.log_dict(metrics, prog_bar=True, logger=True, on_step=True, sync_dist=True) 122 | return metrics 123 | 124 | def validation_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> dict[str, float]: 125 | """Validation step function""" 126 | metrics = self._step(batch=batch, batch_idx=batch_idx, prefix="val") 127 | self.log_dict(metrics, prog_bar=True, logger=True, on_step=True, on_epoch=False, sync_dist=True) 128 | return metrics 129 | 130 | def test_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> dict[str, float]: 131 | """Validation step function""" 132 | metrics = self._step(batch=batch, batch_idx=batch_idx, prefix="test") 133 | self.log_dict(metrics, prog_bar=True, logger=True, on_step=True, on_epoch=False, sync_dist=True) 134 | return metrics 135 | 136 | def validation_epoch_end(self, outputs): 137 | val_acc = self.valid_acc.compute() 138 | self.valid_acc.reset() 139 | self.log("val/acc-final", val_acc, logger=True, on_step=False, sync_dist=True) 140 | 141 | def test_epoch_end(self, outputs): 142 | test_acc = self.test_acc.compute() 143 | self.test_acc.reset() 144 | self.log("test/acc-final", test_acc, logger=True, on_step=False, sync_dist=True) 145 | 146 | def configure_optimizers(self) -> Dict: 147 | optimizer = torch.optim.Adam(params=self.model.parameters(), lr=self.learning_rate) 148 | optimizers = {"optimizer": optimizer} 149 | 150 | if self.warmup_rate is not None: 151 | scheduler = get_linear_schedule_with_warmup( 152 | optimizer, 153 | num_warmup_steps=int(self.total_steps * self.warmup_rate), 154 | num_training_steps=self.total_steps, 155 | ) 156 | optimizers["lr_scheduler"] = {"scheduler": scheduler, "interval": "step", "name": "Learning Rate"} 157 | 158 | return optimizers 159 | 160 | def on_save_checkpoint(self, checkpoint: dict[str, Any]): 161 | checkpoint["model_config"] = self.model.config.to_dict() 162 | checkpoint["model_type"] = self.model.config.model_type 163 | 164 | def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None: 165 | config_dict = checkpoint["model_config"] 166 | config_cls = AutoConfig.for_model(checkpoint["model_type"]) 167 | config = config_cls.from_dict(config_dict) 168 | self.model = AutoModelForCausalLM.from_config(config) 169 | return super().on_load_checkpoint(checkpoint) 170 | 171 | def reset_memories(self) -> None: 172 | if self.model.config.model_type in ["gpt2_with_memoria", "memoria-xl"]: 173 | 
self.model.transformer.memoria.reset_memory() 174 | self.model.transformer.prev_hidden = None 175 | if self.model.config.model_type in ["transfo-xl", "memoria-xl"] and hasattr(self, "_mems"): 176 | del self._mems 177 | if self.model.config.model_type == "compressive_transformer": 178 | if hasattr(self, "_mems"): 179 | del self._mems 180 | if hasattr(self, "_cmems"): 181 | del self._cmems 182 | if self.model.config.model_type == "infinity_gpt2": 183 | self.model.reset_memories() 184 | 185 | def on_train_batch_start(self, batch, batch_idx) -> None: 186 | self.reset_memories() 187 | 188 | def on_validation_batch_start(self, batch, batch_idx, dataloader_idx) -> None: 189 | self.reset_memories() 190 | 191 | def on_test_batch_start(self, batch, batch_idx, dataloader_idx) -> None: 192 | self.reset_memories() 193 | -------------------------------------------------------------------------------- /experiment/longseq_formers/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | 4 | import pytorch_lightning as pl 5 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler 6 | from torch.utils.data.distributed import DistributedSampler 7 | 8 | from .dataset.language_modeling import LanguageModelingDataset 9 | 10 | 11 | def get_logger(name: str) -> logging.Logger: 12 | """Return logger for logging 13 | 14 | Args: 15 | name: logger name 16 | """ 17 | logger = logging.getLogger(name) 18 | logger.propagate = False 19 | logger.setLevel(logging.DEBUG) 20 | if not logger.handlers: 21 | handler = logging.StreamHandler(sys.stdout) 22 | handler.setFormatter(logging.Formatter("[%(asctime)s] %(message)s")) 23 | logger.addHandler(handler) 24 | return logger 25 | 26 | 27 | class BatchedDataModule(pl.LightningDataModule): 28 | def __init__( 29 | self, 30 | train_dataset: LanguageModelingDataset, 31 | valid_dataset: LanguageModelingDataset, 32 | shuffle: bool, 33 | distributed: bool = True, 34 | ) -> None: 35 | super().__init__() 36 | 37 | self.train_dataset = train_dataset 38 | self.valid_dataset = valid_dataset 39 | self.shuffle = shuffle 40 | self.distributed = distributed 41 | 42 | def train_dataloader(self): 43 | # Use batch size as 1 because already batched 44 | if self.distributed: 45 | sampler = DistributedSampler(self.train_dataset, shuffle=self.shuffle) 46 | elif self.shuffle: 47 | sampler = RandomSampler(self.train_dataset) 48 | else: 49 | sampler = SequentialSampler(self.train_dataset) 50 | return DataLoader(self.train_dataset, batch_size=1, sampler=sampler, collate_fn=self.train_dataset.collate_fn) 51 | 52 | def val_dataloader(self): 53 | if self.distributed: 54 | sampler = DistributedSampler(self.valid_dataset, shuffle=False) 55 | else: 56 | sampler = SequentialSampler(self.valid_dataset) 57 | return DataLoader(self.valid_dataset, batch_size=1, sampler=sampler, collate_fn=self.valid_dataset.collate_fn) 58 | -------------------------------------------------------------------------------- /experiment/requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.13.1 2 | pytorch-lightning==1.8.6 3 | transformers==4.25.1 4 | mogrifier # for compressive transformer 5 | 6 | datasets 7 | scikit-learn 8 | bs4 9 | nltk 10 | 11 | wandb 12 | 13 | memoria-pytorch==1.0.0 14 | -------------------------------------------------------------------------------- /experiment/train_classification.py: -------------------------------------------------------------------------------- 1 | 
import argparse 2 | import os 3 | import tempfile 4 | from typing import Dict 5 | 6 | import pytorch_lightning as pl 7 | import torch 8 | import wandb 9 | from longseq_formers.data import CLASSIFICATION_DATASETS, load_hyperpartisan_data 10 | from longseq_formers.dataset import ClassificationDataset 11 | from longseq_formers.task import Classification 12 | from longseq_formers.utils import get_logger 13 | from pytorch_lightning.callbacks import LearningRateMonitor 14 | from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger 15 | from torch.utils.data import DataLoader 16 | from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer 17 | 18 | # fmt: off 19 | parser = argparse.ArgumentParser(prog="train_classification", description="Train & Test Long Sequence Classification") 20 | 21 | g = parser.add_argument_group("Train Parameter") 22 | g.add_argument("--model", type=str, required=True, help="huggingface model") 23 | g.add_argument("--model-type", type=str, help="specific model type") 24 | g.add_argument("--tokenizer", type=str, help="huggingface tokenizer") 25 | g.add_argument("--dataset", type=str, default="hyperpartisan", choices=CLASSIFICATION_DATASETS, help="dataset name") 26 | g.add_argument("--batch-size", type=int, default=8, help="global training batch size") 27 | g.add_argument("--valid-batch-size", type=int, default=32, help="validation batch size") 28 | g.add_argument("--accumulate-grad-batches", type=int, default=1, help="the number of gradient accumulation steps") 29 | g.add_argument("--max-length", type=int, default=512, help="max sequence length") 30 | g.add_argument("--memory-length", type=int, default=512, help="max sequence length for one BERT inference on Infinity-Former") 31 | g.add_argument("--epochs", type=int, default=20, help="the number of training epochs") 32 | g.add_argument("--learning-rate", type=float, default=3e-5, help="learning rate") 33 | g.add_argument("--warmup-rate", type=float, help="warmup step rate") 34 | g.add_argument("--seed", type=int, default=42, help="random seed") 35 | g.add_argument("--test-ckpt", type=str, default="last", choices=["best", "last"], help="checkpoint type for testing") 36 | g.add_argument("--not-truncate", action="store_false", dest="truncation", help="do not truncate sequences") 37 | g.add_argument("--segment-size", type=int, help="segment size for Infinity-Former") 38 | 39 | g = parser.add_argument_group("Personal Options") 40 | g.add_argument("--output-dir", type=str, help="output directory path to save artifacts") 41 | g.add_argument("--gpus", type=int, help="the number of gpus, use all devices by default") 42 | g.add_argument("--logging-interval", type=int, default=10, help="logging interval") 43 | 44 | g = parser.add_argument_group("Wandb Options") 45 | g.add_argument("--wandb-run-name", type=str, help="WandB run name") 46 | g.add_argument("--wandb-entity", type=str, help="WandB entity name") 47 | g.add_argument("--wandb-project", type=str, help="WandB project name") 48 | # fmt: on 49 | 50 | 51 | def main(args: argparse.Namespace) -> dict[str, float]: 52 | logger = get_logger("train_classification") 53 | 54 | if args.output_dir: 55 | os.makedirs(args.output_dir) 56 | logger.info(f'[+] Save output to "{args.output_dir}"') 57 | 58 | logger.info(" ====== Arguments ======") 59 | for k, v in vars(args).items(): 60 | logger.info(f"{k:25}: {v}") 61 | 62 | logger.info(f"[+] Set Random Seed to {args.seed}") 63 | pl.seed_everything(args.seed, workers=True) 64 | 65 | logger.info(f"[+] GPU: 
{args.gpus}") 66 | 67 | if args.tokenizer is None: 68 | if args.model: 69 | logger.info(f"[+] Use tokenizer same as model: {args.model}") 70 | args.tokenizer = args.model 71 | else: 72 | raise ValueError("you should set `--tokenizer` when use `--model-config`!") 73 | logger.info(f'[+] Load Tokenizer: "{args.tokenizer}"') 74 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer) 75 | 76 | logger.info(f'[+] Use Dataset: "{args.dataset}"') 77 | if args.dataset == "hyperpartisan": 78 | datasets = load_hyperpartisan_data() 79 | num_classes = 2 80 | 81 | train_dataset = ClassificationDataset( 82 | datasets["train"], tokenizer=tokenizer, max_length=args.max_length, truncation=args.truncation 83 | ) 84 | valid_dataset = ClassificationDataset( 85 | datasets["dev"], tokenizer=tokenizer, max_length=args.max_length, truncation=args.truncation 86 | ) 87 | test_dataset = ClassificationDataset( 88 | datasets["test"], tokenizer=tokenizer, max_length=args.max_length, truncation=args.truncation 89 | ) 90 | 91 | logger.info(f"[+] # of train examples: {len(train_dataset)}") 92 | logger.info(f"[+] # of valid examples: {len(valid_dataset)}") 93 | logger.info(f"[+] # of test examples: {len(test_dataset)}") 94 | 95 | logger.info(f'[+] Load Model: "{args.model}"') 96 | if args.model_type: 97 | model_cls = type(AutoModelForSequenceClassification.from_config(AutoConfig.for_model(args.model_type))) 98 | else: 99 | model_cls = AutoModelForSequenceClassification 100 | model = model_cls.from_pretrained(args.model, num_labels=num_classes) 101 | 102 | if args.gpus is None: 103 | args.gpus = torch.cuda.device_count() 104 | num_parallels = max(args.gpus, 1) 105 | distributed = num_parallels > 1 106 | batch_size_per_device = max(args.batch_size // num_parallels, 1) 107 | global_batch_size = batch_size_per_device * args.gpus 108 | valid_batch_size_per_device = max(args.valid_batch_size // num_parallels, 1) 109 | global_valid_batch_size = valid_batch_size_per_device * num_parallels 110 | if args.batch_size != global_batch_size: 111 | logger.warning(f"[-] Batch size {args.batch_size} isn't dividable by {args.gpus}!") 112 | logger.warning(f"[-] Use batch size as {batch_size_per_device} per device, {global_batch_size} global") 113 | if args.valid_batch_size != global_valid_batch_size: 114 | logger.warning(f"[-] Valid Batch size {args.valid_batch_size} isn't dividable by {args.gpus}!") 115 | logger.warning( 116 | f"[-] Use batch size as {valid_batch_size_per_device} per device, {global_valid_batch_size} global" 117 | ) 118 | 119 | collate_fn = ClassificationDataset.pad_collate_fn if not args.truncation else None 120 | train_dataloader = DataLoader( 121 | train_dataset, 122 | shuffle=True, 123 | batch_size=batch_size_per_device, 124 | num_workers=os.cpu_count() // 2, 125 | pin_memory=True, 126 | collate_fn=collate_fn, 127 | ) 128 | valid_dataloader = DataLoader( 129 | valid_dataset, 130 | batch_size=args.valid_batch_size // num_parallels, 131 | collate_fn=collate_fn, 132 | ) 133 | test_dataloader = DataLoader( 134 | test_dataset, 135 | batch_size=args.valid_batch_size // num_parallels, 136 | collate_fn=collate_fn, 137 | ) 138 | 139 | total_steps = len(train_dataloader) * args.epochs 140 | 141 | classification = Classification( 142 | model=model, 143 | total_steps=total_steps, 144 | learning_rate=args.learning_rate, 145 | warmup_rate=args.warmup_rate, 146 | segment_size=args.segment_size, 147 | num_classes=num_classes, 148 | ) 149 | 150 | if args.output_dir: 151 | train_loggers = [TensorBoardLogger(args.output_dir, "", "logs")] 
152 | model_dir = os.path.join(args.output_dir, "checkpoint") 153 | else: 154 | train_loggers = [] 155 | tmp_dir = tempfile.TemporaryDirectory() 156 | model_dir = tmp_dir.name 157 | 158 | logger.info(f"[+] Start Training") 159 | if args.wandb_project and (args.wandb_run_name or args.output_dir): 160 | wandb_logger = WandbLogger( 161 | name=args.wandb_run_name or os.path.basename(args.output_dir), 162 | project=args.wandb_project, 163 | entity=args.wandb_entity, 164 | save_dir=args.output_dir if args.output_dir else None, 165 | ) 166 | wandb_logger.log_hyperparams({"train_arguments": vars(args)}) 167 | train_loggers.append(wandb_logger) 168 | 169 | model_checkpoint_callback = pl.callbacks.ModelCheckpoint( 170 | model_dir, mode="max", monitor="val/f1_final", save_last=True, auto_insert_metric_name=True 171 | ) 172 | callbacks = [model_checkpoint_callback] 173 | 174 | if train_loggers: 175 | callbacks.append(LearningRateMonitor(logging_interval="step")) 176 | trainer = pl.Trainer( 177 | logger=train_loggers, 178 | max_epochs=args.epochs, 179 | log_every_n_steps=args.logging_interval, 180 | accumulate_grad_batches=args.accumulate_grad_batches, 181 | callbacks=callbacks, 182 | strategy="ddp_fork" if distributed else None, 183 | accelerator="gpu" if args.gpus else None, 184 | devices=num_parallels, 185 | ) 186 | trainer.fit(classification, train_dataloader, valid_dataloader) 187 | 188 | # Use seperated initialized trainer (https://github.com/Lightning-AI/lightning/issues/8375) 189 | tester = pl.Trainer( 190 | logger=train_loggers, 191 | callbacks=callbacks, 192 | accelerator="gpu" if args.gpus else None, 193 | devices=1, 194 | ) 195 | result = tester.test(classification, test_dataloader, ckpt_path=args.test_ckpt)[0] 196 | 197 | wandb.finish() 198 | 199 | if not args.output_dir: 200 | tmp_dir.cleanup() 201 | 202 | return result 203 | 204 | 205 | if __name__ == "__main__": 206 | main(parser.parse_args()) 207 | exit(0) 208 | -------------------------------------------------------------------------------- /experiment/train_language_modeling.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import tempfile 4 | from typing import Dict 5 | 6 | import pytorch_lightning as pl 7 | import torch 8 | import wandb 9 | from longseq_formers.data import ( 10 | LANGUAGE_MODELING_DATASETS, 11 | enwik8_tokenize, 12 | load_enwik8_data, 13 | load_pg19_data, 14 | load_wikitext103_data, 15 | ) 16 | from longseq_formers.dataset import LanguageModelingDataset, text_to_tokens 17 | from longseq_formers.task import LanguageModeling 18 | from longseq_formers.utils import BatchedDataModule, get_logger 19 | from pytorch_lightning.callbacks import LearningRateMonitor 20 | from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger 21 | from torch.utils.data import DataLoader 22 | from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer 23 | 24 | # fmt: off 25 | parser = argparse.ArgumentParser(prog="train", description="Train & Test Language Modeling") 26 | 27 | g = parser.add_argument_group("Train Parameter") 28 | g.add_argument("--model-config", type=str, help="huggingface model config") 29 | g.add_argument("--model", type=str, help="huggingface model") 30 | g.add_argument("--model-type", type=str, help="specific model type") 31 | g.add_argument("--tokenizer", type=str, help="huggingface tokenizer") 32 | g.add_argument("--dataset", type=str, default="wikitext103", choices=LANGUAGE_MODELING_DATASETS, help="dataset name") 33 | 
g.add_argument("--batch-size", type=int, default=8, help="global training batch size") 34 | g.add_argument("--valid-batch-size", type=int, default=1, help="validation batch size") 35 | g.add_argument("--accumulate-grad-batches", type=int, default=1, help="the number of gradident accumulation steps") 36 | g.add_argument("--max-length", type=int, default=150, help="max sequence length") 37 | g.add_argument("--epochs", type=int, default=6, help="the number of training epochs") 38 | g.add_argument("--learning-rate", type=float, default=2e-4, help="learning rate") 39 | g.add_argument("--warmup-rate", type=float, default=0.06, help="warmup step rate") 40 | g.add_argument("--max-grad-norm", type=float, default=1.0, help="maximum gradient norm") 41 | g.add_argument("--seed", type=int, default=42, help="random seed") 42 | g.add_argument("--shuffle", action="store_true", help="shuffle data order") 43 | g.add_argument("--test-ckpt", type=str, default="last", choices=["best", "last"], help="checkpoint type for testing") 44 | 45 | g = parser.add_argument_group("Personal Options") 46 | g.add_argument("--output-dir", type=str, help="output directory path to save artifacts") 47 | g.add_argument("--gpus", type=int, help="the number of gpus, use all devices by default") 48 | g.add_argument("--logging-interval", type=int, default=100, help="logging interval") 49 | g.add_argument("--valid-interval", type=float, default=1.0, help="validation interval rate") 50 | 51 | g = parser.add_argument_group("Wandb Options") 52 | g.add_argument("--wandb-run-name", type=str, help="wanDB run name") 53 | g.add_argument("--wandb-entity", type=str, help="wanDB entity name") 54 | g.add_argument("--wandb-project", type=str, help="wanDB project name") 55 | # fmt: on 56 | 57 | 58 | def main(args: argparse.Namespace) -> dict[str, float]: 59 | logger = get_logger("train_language_modeling") 60 | 61 | if args.output_dir: 62 | os.makedirs(args.output_dir) 63 | logger.info(f'[+] Save output to "{args.output_dir}"') 64 | 65 | logger.info(" ====== Arguements ======") 66 | for k, v in vars(args).items(): 67 | logger.info(f"{k:25}: {v}") 68 | 69 | logger.info(f"[+] Set Random Seed to {args.seed}") 70 | pl.seed_everything(args.seed, workers=True) 71 | 72 | logger.info(f"[+] GPU: {args.gpus}") 73 | 74 | if args.tokenizer is None and args.dataset != "enwik8": 75 | if args.model: 76 | logger.info(f"[+] Use tokenizer same as model: {args.model}") 77 | args.tokenizer = args.model 78 | else: 79 | raise ValueError("you should set `--tokenizer` when use `--model-config`!") 80 | logger.info(f'[+] Load Tokenizer: "{args.tokenizer}"') 81 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer) if args.tokenizer else enwik8_tokenize 82 | 83 | logger.info(f'[+] Use Dataset: "{args.dataset}"') 84 | if args.dataset == "wikitext103": 85 | data = load_wikitext103_data() 86 | elif args.dataset == "pg19": 87 | data = load_pg19_data() 88 | elif args.dataset == "enwik8": 89 | data = load_enwik8_data() 90 | else: 91 | raise ValueError(f"dataset `{args.dataset}` is not valid!") 92 | 93 | if args.gpus is None: 94 | args.gpus = torch.cuda.device_count() 95 | num_parallels = max(args.gpus, 1) 96 | distributed = num_parallels > 1 97 | batch_size_per_device = max(args.batch_size // num_parallels, 1) 98 | global_batch_size = batch_size_per_device * num_parallels 99 | valid_batch_size_per_device = max(args.valid_batch_size // num_parallels, 1) 100 | global_valid_batch_size = valid_batch_size_per_device * num_parallels 101 | if args.batch_size != global_batch_size: 102 
| logger.warning(f"[-] Batch size {args.batch_size} isn't dividable by {args.gpus}!") 103 | logger.warning(f"[-] Use batch size as {batch_size_per_device} per device, {global_batch_size} global") 104 | if args.valid_batch_size != global_valid_batch_size: 105 | logger.warning(f"[-] Valid Batch size {args.valid_batch_size} isn't dividable by {args.gpus}!") 106 | logger.warning( 107 | f"[-] Use batch size as {valid_batch_size_per_device} per device, {global_valid_batch_size} global" 108 | ) 109 | 110 | train_tokens = text_to_tokens(data["train"], tokenizer, global_batch_size, args.max_length, batch_size_per_device) 111 | dev_tokens = text_to_tokens( 112 | data["dev"], tokenizer, global_valid_batch_size, args.max_length, valid_batch_size_per_device 113 | ) 114 | test_tokens = text_to_tokens(data["test"], tokenizer, args.valid_batch_size, args.max_length) 115 | 116 | train_dataset = LanguageModelingDataset(train_tokens) 117 | valid_dataset = LanguageModelingDataset(dev_tokens) 118 | test_dataset = LanguageModelingDataset(test_tokens) 119 | 120 | logger.info(f"[+] # of batched train examples: {len(train_dataset)}") 121 | logger.info(f"[+] # of batched valid examples: {len(valid_dataset)}") 122 | logger.info(f"[+] # of batched test examples: {len(test_dataset)}") 123 | 124 | if args.model: 125 | logger.info(f'[+] Load Model: "{args.model}"') 126 | if args.model_type: 127 | model_cls = type(AutoModelForCausalLM.from_config(AutoConfig.for_model(args.model_type))) 128 | logger.info(f"[+] Use model type: {args.model_type}") 129 | else: 130 | model_cls = AutoModelForCausalLM 131 | model = model_cls.from_pretrained(args.model) 132 | elif args.model_config: 133 | logger.info(f'[+] Initialize Model with Config: "{args.model_config}"') 134 | config = AutoConfig.from_pretrained(args.model_config, trust_remote_code=True) 135 | model = AutoModelForCausalLM.from_config(config) 136 | else: 137 | raise ValueError("you should set `--model` or `--model-config` argument!") 138 | 139 | total_steps = len(train_tokens["input_ids"]) // num_parallels * args.epochs 140 | 141 | language_modeling = LanguageModeling( 142 | model=model, 143 | total_steps=total_steps, 144 | learning_rate=args.learning_rate, 145 | warmup_rate=args.warmup_rate, 146 | ) 147 | 148 | if args.output_dir: 149 | train_loggers = [TensorBoardLogger(args.output_dir, "", "logs")] 150 | model_dir = os.path.join(args.output_dir, "checkpoint") 151 | else: 152 | train_loggers = [] 153 | tmp_dir = tempfile.TemporaryDirectory() 154 | model_dir = tmp_dir.name 155 | 156 | logger.info(f"[+] Start Training") 157 | if args.wandb_project and (args.wandb_run_name or args.output_dir): 158 | wandb_logger = WandbLogger( 159 | name=args.wandb_run_name or os.path.basename(args.output_dir), 160 | project=args.wandb_project, 161 | entity=args.wandb_entity, 162 | save_dir=args.output_dir if args.output_dir else None, 163 | ) 164 | wandb_logger.log_hyperparams({"train_arguments": vars(args)}) 165 | train_loggers.append(wandb_logger) 166 | 167 | model_checkpoint_callback = pl.callbacks.ModelCheckpoint( 168 | model_dir, mode="min", monitor="val/ppl", save_last=True, auto_insert_metric_name=True 169 | ) 170 | callbacks = [model_checkpoint_callback] 171 | 172 | if train_loggers: 173 | callbacks.append(LearningRateMonitor(logging_interval="step")) 174 | trainer = pl.Trainer( 175 | logger=train_loggers, 176 | max_epochs=args.epochs, 177 | log_every_n_steps=args.logging_interval, 178 | val_check_interval=args.valid_interval, 179 | 
accumulate_grad_batches=args.accumulate_grad_batches, 180 | gradient_clip_val=args.max_grad_norm, 181 | callbacks=callbacks, 182 | strategy="ddp_fork" if distributed else None, 183 | accelerator="gpu" if args.gpus else None, 184 | devices=num_parallels, 185 | replace_sampler_ddp=False, 186 | ) 187 | trainer.fit( 188 | language_modeling, 189 | datamodule=BatchedDataModule(train_dataset, valid_dataset, args.shuffle, distributed), 190 | ) 191 | 192 | # Use seperated initialized trainer (https://github.com/Lightning-AI/lightning/issues/8375) 193 | # Use batch size as 1 because already batched 194 | test_dataloader = DataLoader(test_dataset, batch_size=1, collate_fn=test_dataset.collate_fn) 195 | tester = pl.Trainer( 196 | logger=train_loggers, 197 | callbacks=callbacks, 198 | accelerator="gpu" if args.gpus else None, 199 | devices=1, 200 | ) 201 | result = tester.test(language_modeling, test_dataloader, ckpt_path=args.test_ckpt)[0] 202 | 203 | wandb.finish() 204 | 205 | if not args.output_dir: 206 | tmp_dir.cleanup() 207 | 208 | return result 209 | 210 | 211 | if __name__ == "__main__": 212 | main(parser.parse_args()) 213 | exit(0) 214 | -------------------------------------------------------------------------------- /experiment/train_synthetic.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import tempfile 4 | from typing import Dict 5 | 6 | import pytorch_lightning as pl 7 | import torch 8 | import wandb 9 | from longseq_formers.dataset.synthetic import SyntheticDataset, parse_syntetic_data 10 | from longseq_formers.task import Synthetic 11 | from longseq_formers.utils import get_logger 12 | from pytorch_lightning.callbacks import LearningRateMonitor 13 | from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger 14 | from torch.utils.data import DataLoader 15 | from transformers import AutoConfig, AutoModelForCausalLM 16 | 17 | # fmt: off 18 | parser = argparse.ArgumentParser(prog="train_synthetic", description="Train & Test Synthetic Task") 19 | 20 | g = parser.add_argument_group("Train Parameter") 21 | g.add_argument("--model-config", type=str, required=True, help="huggingface model config") 22 | g.add_argument("--dataset", type=str, required=True, help="dataset name") 23 | g.add_argument("--batch-size", type=int, default=32, help="global training batch size") 24 | g.add_argument("--valid-batch-size", type=int, default=1, help="validation batch size") 25 | g.add_argument("--accumulate-grad-batches", type=int, default=1, help="the number of gradident accumulation steps") 26 | g.add_argument("--epochs", type=int, default=1, help="the number of training epochs") 27 | g.add_argument("--learning-rate", type=float, default=2e-4, help="learning rate") 28 | g.add_argument("--warmup-rate", type=float, default=0.06, help="warmup step rate") 29 | g.add_argument("--max-grad-norm", type=float, default=1.0, help="maximum gradient norm") 30 | g.add_argument("--seed", type=int, default=42, help="random seed") 31 | g.add_argument("--test-ckpt", type=str, default="last", choices=["best", "last"], help="checkpoint type for testing") 32 | g.add_argument("--segment-size", type=int, required=True, help="segment size for infinity former") 33 | 34 | g = parser.add_argument_group("Personal Options") 35 | g.add_argument("--output-dir", type=str, help="output directory path to save artifacts") 36 | g.add_argument("--gpus", type=int, help="the number of gpus, use all devices by default") 37 | g.add_argument("--logging-interval", 
type=int, default=100, help="logging interval") 38 | g.add_argument("--valid-interval", type=float, default=1.0, help="validation interval rate") 39 | 40 | g = parser.add_argument_group("Wandb Options") 41 | g.add_argument("--wandb-run-name", type=str, help="wanDB run name") 42 | g.add_argument("--wandb-entity", type=str, help="wanDB entity name") 43 | g.add_argument("--wandb-project", type=str, help="wanDB project name") 44 | # fmt: on 45 | 46 | 47 | def main(args: argparse.Namespace) -> dict[str, float]: 48 | logger = get_logger("train_synthetic_task") 49 | 50 | if args.output_dir: 51 | os.makedirs(args.output_dir) 52 | logger.info(f'[+] Save output to "{args.output_dir}"') 53 | 54 | logger.info(" ====== Arguements ======") 55 | for k, v in vars(args).items(): 56 | logger.info(f"{k:25}: {v}") 57 | 58 | logger.info(f"[+] Set Random Seed to {args.seed}") 59 | pl.seed_everything(args.seed, workers=True) 60 | 61 | logger.info(f"[+] GPU: {args.gpus}") 62 | 63 | logger.info(f'[+] Use Dataset: "{args.dataset}"') 64 | _, vocab_size, train_examples, dev_examples, test_examples = parse_syntetic_data(args.dataset) 65 | 66 | if args.gpus is None: 67 | args.gpus = torch.cuda.device_count() 68 | num_parallels = max(args.gpus, 1) 69 | distributed = num_parallels > 1 70 | batch_size_per_device = max(args.batch_size // num_parallels, 1) 71 | global_batch_size = batch_size_per_device * num_parallels 72 | valid_batch_size_per_device = max(args.valid_batch_size // num_parallels, 1) 73 | global_valid_batch_size = valid_batch_size_per_device * num_parallels 74 | if args.batch_size != global_batch_size: 75 | logger.warning(f"[-] Batch size {args.batch_size} isn't dividable by {args.gpus}!") 76 | logger.warning(f"[-] Use batch size as {batch_size_per_device} per device, {global_batch_size} global") 77 | if args.valid_batch_size != global_valid_batch_size: 78 | logger.warning(f"[-] Valid Batch size {args.valid_batch_size} isn't dividable by {args.gpus}!") 79 | logger.warning( 80 | f"[-] Use batch size as {valid_batch_size_per_device} per device, {global_valid_batch_size} global" 81 | ) 82 | 83 | train_dataset = SyntheticDataset(train_examples) 84 | valid_dataset = SyntheticDataset(dev_examples) 85 | test_dataset = SyntheticDataset(test_examples) 86 | 87 | logger.info(f"[+] # of batched train examples: {len(train_dataset)}") 88 | logger.info(f"[+] # of batched valid examples: {len(valid_dataset)}") 89 | logger.info(f"[+] # of batched test examples: {len(test_dataset)}") 90 | 91 | logger.info(f'[+] Initialize Model with Config: "{args.model_config}"') 92 | config = AutoConfig.from_pretrained(args.model_config, trust_remote_code=True, vocab_size=vocab_size) 93 | model = AutoModelForCausalLM.from_config(config) 94 | 95 | train_dataloader = DataLoader( 96 | train_dataset, shuffle=True, batch_size=batch_size_per_device, num_workers=os.cpu_count() // 2, pin_memory=True 97 | ) 98 | valid_dataloader = DataLoader(valid_dataset, batch_size=args.valid_batch_size // num_parallels) 99 | test_dataloader = DataLoader(test_dataset, batch_size=args.valid_batch_size // num_parallels) 100 | total_steps = len(train_dataloader) * args.epochs 101 | 102 | synthetic_task = Synthetic( 103 | model=model, 104 | total_steps=total_steps, 105 | learning_rate=args.learning_rate, 106 | warmup_rate=args.warmup_rate, 107 | segment_size=args.segment_size, 108 | vocab_size=vocab_size, 109 | max_grad_norm=args.max_grad_norm, 110 | ) 111 | 112 | if args.output_dir: 113 | train_loggers = [TensorBoardLogger(args.output_dir, "", "logs")] 114 | model_dir 
114 | model_dir = os.path.join(args.output_dir, "checkpoint") 115 | else: 116 | train_loggers = [] 117 | tmp_dir = tempfile.TemporaryDirectory() 118 | model_dir = tmp_dir.name 119 | 120 | logger.info("[+] Start Training") 121 | if args.wandb_project and (args.wandb_run_name or args.output_dir): 122 | wandb_logger = WandbLogger( 123 | name=args.wandb_run_name or os.path.basename(args.output_dir), 124 | project=args.wandb_project, 125 | entity=args.wandb_entity, 126 | save_dir=args.output_dir if args.output_dir else None, 127 | ) 128 | wandb_logger.log_hyperparams({"train_arguments": vars(args)}) 129 | train_loggers.append(wandb_logger) 130 | 131 | model_checkpoint_callback = pl.callbacks.ModelCheckpoint( 132 | model_dir, mode="max", monitor="val/acc-final", save_last=True, auto_insert_metric_name=True 133 | ) 134 | callbacks = [model_checkpoint_callback] 135 | 136 | if train_loggers: 137 | callbacks.append(LearningRateMonitor(logging_interval="step")) 138 | trainer = pl.Trainer( 139 | logger=train_loggers, 140 | max_epochs=args.epochs, 141 | log_every_n_steps=args.logging_interval, 142 | val_check_interval=args.valid_interval, 143 | accumulate_grad_batches=args.accumulate_grad_batches, 144 | callbacks=callbacks, 145 | strategy="ddp_fork" if distributed else None, 146 | accelerator="gpu" if args.gpus else None, 147 | devices=num_parallels, 148 | replace_sampler_ddp=False, 149 | ) 150 | trainer.fit(synthetic_task, train_dataloader, valid_dataloader) 151 | 152 | # Use a separately initialized trainer for testing (https://github.com/Lightning-AI/lightning/issues/8375) 153 | # Use batch size as 1 because already batched 154 | tester = pl.Trainer( 155 | logger=train_loggers, 156 | callbacks=callbacks, 157 | accelerator="gpu" if args.gpus else None, 158 | devices=1, 159 | ) 160 | result = tester.test(synthetic_task, test_dataloader, ckpt_path=args.test_ckpt)[0] 161 | 162 | wandb.finish() 163 | 164 | if not args.output_dir: 165 | tmp_dir.cleanup() 166 | 167 | return result 168 | 169 | 170 | if __name__ == "__main__": 171 | main(parser.parse_args()) 172 | exit(0) 173 | -------------------------------------------------------------------------------- /images/Memoria-Engrams.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cosmoquester/memoria/e4ba6e2e13410e01fba896dd7800e66520e9d716/images/Memoria-Engrams.gif -------------------------------------------------------------------------------- /memoria/__init__.py: --------------------------------------------------------------------------------
1 | from . import utils 2 | from .abstractor import Abstractor 3 | from .engram import Engrams, EngramType 4 | from .history_manager import HistoryManager 5 | from .memoria import Memoria 6 | from .sparse_tensor import SparseTensor 7 | 8 | __all__ = ["utils", "Abstractor", "Engrams", "EngramType", "HistoryManager", "Memoria", "SparseTensor"] 9 | __version__ = "1.0.0" 10 | -------------------------------------------------------------------------------- /memoria/abstractor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Abstractor(nn.Module): 6 | """Abstractor module to summarize data""" 7 | 8 | def __init__(self, num_memories: int, hidden_dim: int, feedforward_dim: int) -> None: 9 | """ 10 | Args: 11 | num_memories (int): Number of memories to be created 12 | hidden_dim (int): Hidden dimension of the model 13 | feedforward_dim (int): Feedforward dimension of the model 14 | """ 15 | super().__init__() 16 | 17 | w = torch.empty(1, num_memories, hidden_dim) 18 | nn.init.normal_(w, std=0.02) 19 | self.query_embeddings = nn.Parameter(w) 20 | self.key_transform = nn.Linear(hidden_dim, hidden_dim) 21 | self.value_transform = nn.Linear(hidden_dim, hidden_dim) 22 | self.feedforward = nn.Linear(hidden_dim, feedforward_dim) 23 | self.output = nn.Linear(feedforward_dim, hidden_dim) 24 | 25 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: 26 | """ 27 | 28 | Args: 29 | hidden_states (torch.Tensor): [Batch, N, HiddenDim] 30 | Returns: 31 | torch.Tensor: [Batch, NumMemories, HiddenDim] 32 | """ 33 | query = self.query_embeddings 34 | key = self.key_transform(hidden_states) 35 | # [Batch, N, HiddenDim] 36 | value = self.value_transform(hidden_states) 37 | # [Batch, NumMemories, HiddenDim] x [Batch, N, HiddenDim] -> [Batch, NumMemories, N] 38 | attn = query @ key.transpose(-2, -1) 39 | attn = attn.softmax(dim=-1) 40 | attn = attn @ value  # [Batch, NumMemories, N] x [Batch, N, HiddenDim] -> [Batch, NumMemories, HiddenDim] 41 | attn = self.feedforward(attn) 42 | attn = self.output(attn) 43 | return attn 44 | -------------------------------------------------------------------------------- /memoria/history_manager.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import pickle 3 | from collections import defaultdict 4 | 5 | from .types import EngramHistory, EngramInfo, EngramsInfo, Firing 6 | 7 | 8 | class HistoryManager: 9 | """Manages the history of engram summaries. 10 | 11 | Attributes: 12 | timestep: Current timestep. 13 | summaries: List of engram summaries. 14 | engram_creation_times: Dictionary of engram creation times. 15 | engram_deletion_times: Dictionary of engram deletion times. 16 | engram_durations: Dictionary of engram durations. 17 | engram_firing_times: Dictionary of engram firing times. 18 | engram_firings: Dictionary of engram firings. 19 | firings_per_time: List of firings per timestep. 20 | engram_fire_counts: Dictionary of engram fire counts. 21 | engram_ids: List of engram ids. 22 | alive_engram_ids: List of alive engram ids. 23 | deleted_engram_ids: List of deleted engram ids.
24 | """ 25 | 26 | def __init__(self): 27 | self.summaries: list[EngramsInfo] = [] 28 | self.engram_creation_times: dict[int, int] = {} 29 | self.engram_deletion_times: dict[int, int] = {} 30 | self.engram_durations: dict[int, int] = {} 31 | self.engram_firing_times: dict[int, list[int]] = defaultdict(list) 32 | self.engram_firings: dict[int, list[Firing]] = defaultdict(list) 33 | self.firings_per_time: list[list[Firing]] = [] 34 | 35 | self.alive_engram_ids: list[int] = [] 36 | self.deleted_engram_ids: list[int] = [] 37 | 38 | def __len__(self) -> int: 39 | return len(self.summaries) 40 | 41 | def __getitem__(self, index: int) -> EngramsInfo: 42 | return self.summaries[index] 43 | 44 | def save(self, path: str) -> None: 45 | """Save the history manager to a compressed data file. 46 | 47 | Args: 48 | path: Path to save the history manager. 49 | """ 50 | with gzip.open(path, "wb") as f: 51 | pickle.dump(self, f) 52 | 53 | @classmethod 54 | def load(cls, path: str) -> "HistoryManager": 55 | """Load the history manager from a compressed data file. 56 | 57 | Args: 58 | path: Path to load the history manager. 59 | Returns: 60 | HistoryManager: Loaded history manager. 61 | """ 62 | with gzip.open(path, "rb") as f: 63 | return pickle.load(f) 64 | 65 | @property 66 | def timestep(self) -> int: 67 | """Get the current timestep.""" 68 | return len(self) 69 | 70 | @property 71 | def engram_ids(self) -> list[int]: 72 | """Get the list of engram IDs.""" 73 | return list(self.engram_creation_times.keys()) 74 | 75 | @property 76 | def engram_fire_counts(self) -> dict[int, int]: 77 | """Get the fire counts of the engrams.""" 78 | return {engram_id: len(firings) for engram_id, firings in self.engram_firings.items()} 79 | 80 | @property 81 | def engram_lastest_alive_timestep(self) -> dict[int, int]: 82 | """Get the latest alive timestep of the engrams.""" 83 | return { 84 | engram_id: creation_time + self.engram_durations[engram_id] - 1 85 | for engram_id, creation_time in self.engram_creation_times.items() 86 | } 87 | 88 | @property 89 | def latest_engram_infos(self) -> dict[int, EngramInfo]: 90 | """Get the latest engram information before dying.""" 91 | last_timestep = self.engram_lastest_alive_timestep 92 | return {engram_id: self.summaries[last_timestep[engram_id]].engrams[engram_id] for engram_id in self.engram_ids} 93 | 94 | def add_summary(self, summary: EngramsInfo) -> None: 95 | firings = [] 96 | for engram_id, engram in summary.engrams.items(): 97 | if engram_id not in self.alive_engram_ids: 98 | self.engram_creation_times[engram_id] = self.timestep 99 | self.alive_engram_ids.append(engram_id) 100 | elif ( 101 | engram_id in self.summaries[-1].engrams 102 | and engram.fire_count > self.summaries[-1].engrams[engram_id].fire_count 103 | ): 104 | self.engram_firing_times[engram_id].append(self.timestep) 105 | firing = Firing( 106 | timestep=self.timestep, 107 | engram_id=engram_id, 108 | lifespan_gain=engram.lifespan - self.summaries[-1].engrams[engram_id].lifespan + 1.0, 109 | ) 110 | self.engram_firings[engram_id].append(firing) 111 | firings.append(firing) 112 | self.firings_per_time.append(firings) 113 | 114 | for engram_id in self.alive_engram_ids: 115 | if engram_id not in summary.engrams: 116 | self.engram_deletion_times[engram_id] = self.timestep 117 | self.deleted_engram_ids.append(engram_id) 118 | self.alive_engram_ids = list(summary.engrams.keys()) 119 | for engram_id in self.alive_engram_ids: 120 | self.engram_durations[engram_id] = self.timestep - self.engram_creation_times[engram_id] + 
1 121 | 122 | self.summaries.append(summary) 123 | 124 | def inspect(self, engram_id: int) -> EngramHistory: 125 | """Inspect the history of an engram. 126 | 127 | Args: 128 | engram_id: Engram ID to inspect. 129 | Returns: 130 | EngramHistory: Historical information of the engram. 131 | """ 132 | creation_time = self.engram_creation_times[engram_id] 133 | deletion_time = self.engram_deletion_times.get(engram_id) 134 | duration = self.engram_durations.get(engram_id) 135 | firing_times = self.engram_firing_times.get(engram_id, []) 136 | firings = self.engram_firings.get(engram_id, []) 137 | related_summaries = self.summaries[creation_time:deletion_time] 138 | 139 | return EngramHistory( 140 | id=engram_id, 141 | creation_timestep=creation_time, 142 | deletion_timestep=deletion_time, 143 | duration=duration, 144 | firing_times=firing_times, 145 | firings=firings, 146 | summaries=related_summaries, 147 | ) 148 | -------------------------------------------------------------------------------- /memoria/types.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Literal, Optional, Tuple 3 | 4 | 5 | @dataclass(slots=True, frozen=True) 6 | class EngramConnection: 7 | """Data structure for engram connections.""" 8 | 9 | #: Source engram ID. 10 | source_id: int 11 | #: Target engram ID. 12 | target_id: int 13 | #: Connection weight (Probability). 14 | weight: float 15 | #: Cofire count. 16 | cofire_count: int 17 | 18 | 19 | @dataclass(slots=True, frozen=True) 20 | class EngramInfo: 21 | """Data structure for engram information.""" 22 | 23 | #: Engram ID. 24 | id: int 25 | #: Engram type. 26 | type: Literal["WORKING", "SHORTTERM", "LONGTERM"] 27 | #: Lifetime of the engram. 28 | lifespan: int 29 | #: The age of the engram. 30 | age: Optional[int] 31 | #: Fire count of the engram. 32 | fire_count: int 33 | #: The outgoing edges of the engram. 34 | outgoings: Tuple[EngramConnection] 35 | #: The incoming edges of the engram. 36 | incoming: Tuple[EngramConnection] 37 | 38 | @property 39 | def cofire_counts(self) -> dict[int, int]: 40 | """Get the cofire counts of the engram.""" 41 | return {edge.target_id: edge.cofire_count for edge in self.outgoings} 42 | 43 | 44 | @dataclass(slots=True, frozen=True) 45 | class EngramsInfo: 46 | """Data structure for engrams information.""" 47 | 48 | #: Engram ID to EngramInfo mapping. 49 | engrams: dict[int, EngramInfo] 50 | #: All engram connections mapping from source and target engram IDs. 51 | edges: dict[Tuple[int, int], EngramConnection] 52 | #: Working memory engram IDs. 53 | working: Tuple[int] 54 | #: Short-term memory engram IDs. 55 | shortterm: Tuple[int] 56 | #: Long-term memory engram IDs. 57 | longterm: Tuple[int] 58 | 59 | 60 | @dataclass(slots=True, frozen=True) 61 | class Firing: 62 | """Data structure for firing information.""" 63 | 64 | #: Firing timestep. 65 | timestep: int 66 | #: Engram ID. 67 | engram_id: int 68 | #: Lifespan Gain. 69 | lifespan_gain: float 70 | 71 | 72 | @dataclass(slots=True, frozen=True) 73 | class EngramHistory: 74 | """Historical information of an engram.""" 75 | 76 | #: Engram ID. 77 | id: int 78 | #: Creation time of the engram. 79 | creation_timestep: int 80 | #: Deletion time of the engram. 81 | deletion_timestep: Optional[int] 82 | #: Duration of the engram. 83 | duration: Optional[int] 84 | #: Firing times of the engram. 85 | firing_times: list[int] 86 | #: Firing information of the engram. 
87 | firings: list[Firing] 88 | #: Summaries of the engram. 89 | summaries: list[EngramsInfo] 90 | -------------------------------------------------------------------------------- /memoria/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def super_unique(t: torch.Tensor, dim: int) -> torch.Tensor: 5 | if t.numel() == 0: 6 | return t 7 | 8 | min_value = t.min() 9 | t = t - min_value 10 | 11 | max_value = t.max() 12 | new_shape = list(t.shape) 13 | new_shape[dim] = max_value + 1 14 | unique_t_mask = torch.zeros(new_shape, dtype=torch.bool, device=t.device) 15 | unique_t_mask.scatter_(dim, t.long(), True) 16 | 17 | k = min(t.size(dim), unique_t_mask.sum(dim).max().item()) 18 | validity, unique_t = unique_t_mask.int().topk(k, dim=dim) 19 | unique_t += min_value 20 | unique_t.masked_fill_(~validity.bool(), -1) 21 | return unique_t 22 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 120 3 | include = '\.pyi?$' 4 | 5 | [tool.isort] 6 | multi_line_output = 3 7 | line_length = 120 8 | 9 | [tool.pyright] 10 | reportUnknownVariableType = false 11 | reportUnknownMemberType = false 12 | reportUnusedImport = true 13 | reportUnusedVariable = true 14 | reportUnusedClass = true 15 | reportUnusedFunction = true 16 | reportImportCycles = true 17 | reportTypeshedErrors = true 18 | reportOptionalMemberAccess = true 19 | reportUntypedBaseClass = true 20 | reportPrivateUsage = true 21 | reportConstantRedefinition = true 22 | reportInvalidStringEscapeSequence = true 23 | reportUnnecessaryIsInstance = true 24 | reportUnnecessaryCast = true 25 | reportAssertAlwaysTrue = true 26 | reportSelfClsParameterName = true 27 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | black 2 | isort 3 | pytest 4 | pytest-cov 5 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | numpy<2 3 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | with open("README.md", "r") as f: 4 | long_description = f.read() 5 | 6 | setup( 7 | name="memoria-pytorch", 8 | version="1.1.0", 9 | description="Memoria is a human-inspired memory architecture for neural networks.", 10 | long_description=long_description, 11 | long_description_content_type="text/markdown", 12 | python_requires=">=3.10", 13 | install_requires=["torch"], 14 | url="https://github.com/cosmoquester/memoria.git", 15 | author="Park Sangjun", 16 | keywords=["memoria", "hebbian", "memory", "transformer"], 17 | classifiers=[ 18 | "Programming Language :: Python :: 3", 19 | "License :: OSI Approved :: MIT License", 20 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 21 | ], 22 | packages=find_packages(exclude=["tests", "experiment"]), 23 | ) 24 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/cosmoquester/memoria/e4ba6e2e13410e01fba896dd7800e66520e9d716/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_abstractor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from memoria.abstractor import Abstractor 4 | 5 | 6 | def test_abstractor(): 7 | abstractor = Abstractor(num_memories=3, hidden_dim=4, feedforward_dim=5) 8 | hidden_states = torch.randn(2, 3, 4) 9 | output = abstractor(hidden_states) 10 | assert output.shape == (2, 3, 4) 11 | -------------------------------------------------------------------------------- /tests/test_history_manager.py: -------------------------------------------------------------------------------- 1 | from memoria.history_manager import HistoryManager 2 | from memoria.types import EngramInfo, EngramsInfo, Firing 3 | 4 | 5 | def test_history_manager(): 6 | history_manager = HistoryManager() 7 | assert len(history_manager) == 0 8 | assert history_manager.timestep == 0 9 | assert history_manager.engram_ids == [] 10 | assert history_manager.alive_engram_ids == [] 11 | assert history_manager.deleted_engram_ids == [] 12 | 13 | engrams = { 14 | 1: EngramInfo(id=1, type="WORKING", lifespan=4, age=0, fire_count=0, outgoings=[], incoming=[]), 15 | 2: EngramInfo(id=2, type="SHORTTERM", lifespan=5, age=0, fire_count=0, outgoings=[], incoming=[]), 16 | 3: EngramInfo(id=3, type="LONGTERM", lifespan=6, age=0, fire_count=0, outgoings=[], incoming=[]), 17 | } 18 | history_manager.add_summary(EngramsInfo(engrams=engrams, edges={}, working=[1], shortterm=[2], longterm=[3])) 19 | assert len(history_manager) == 1 20 | assert history_manager.timestep == 1 21 | assert history_manager.engram_ids == [1, 2, 3] 22 | assert history_manager.alive_engram_ids == [1, 2, 3] 23 | assert history_manager.deleted_engram_ids == [] 24 | assert history_manager.engram_fire_counts == {} 25 | assert history_manager.engram_lastest_alive_timestep == {1: 0, 2: 0, 3: 0} 26 | assert history_manager.latest_engram_infos == {1: engrams[1], 2: engrams[2], 3: engrams[3]} 27 | 28 | engrams2 = { 29 | 2: EngramInfo(id=2, type="SHORTTERM", lifespan=4, age=0, fire_count=0, outgoings=[], incoming=[]), 30 | 3: EngramInfo(id=3, type="LONGTERM", lifespan=8, age=0, fire_count=1, outgoings=[], incoming=[]), 31 | 4: EngramInfo(id=4, type="LONGTERM", lifespan=6, age=0, fire_count=0, outgoings=[], incoming=[]), 32 | } 33 | history_manager.add_summary(EngramsInfo(engrams=engrams2, edges={}, working=[], shortterm=[2], longterm=[3, 4])) 34 | assert len(history_manager) == 2 35 | assert history_manager.timestep == 2 36 | assert history_manager.engram_ids == [1, 2, 3, 4] 37 | assert history_manager.alive_engram_ids == [2, 3, 4] 38 | assert history_manager.deleted_engram_ids == [1] 39 | assert history_manager.engram_firing_times == {3: [1]} 40 | assert history_manager.engram_firings == {3: [Firing(timestep=1, engram_id=3, lifespan_gain=3.0)]} 41 | assert history_manager.engram_fire_counts == {3: 1} 42 | assert history_manager.engram_lastest_alive_timestep == {1: 0, 2: 1, 3: 1, 4: 1} 43 | assert history_manager.latest_engram_infos == {1: engrams[1], 2: engrams2[2], 3: engrams2[3], 4: engrams2[4]} 44 | 45 | engram_history = history_manager.inspect(1) 46 | assert engram_history.id == 1 47 | assert engram_history.creation_timestep == 0 48 | assert engram_history.deletion_timestep == 1 49 | assert engram_history.duration == 1 50 | assert engram_history.firing_times == [] 51 | 
assert engram_history.firings == [] 52 | -------------------------------------------------------------------------------- /tests/test_memoria.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from memoria.engram import Engrams, EngramType 4 | from memoria.memoria import Memoria 5 | 6 | 7 | def test_add_working_memory(): 8 | memoria = Memoria( 9 | num_reminded_stm=10, 10 | ltm_search_depth=3, 11 | stm_capacity=100, 12 | initial_lifespan=100, 13 | num_final_ltms=100, 14 | ) 15 | memoria.add_working_memory(torch.randn(3, 10, 32)) 16 | assert len(memoria.engrams) == 30 17 | 18 | 19 | def test_calculate_wm_stm_weight(): 20 | memoria = Memoria( 21 | num_reminded_stm=10, 22 | ltm_search_depth=3, 23 | stm_capacity=100, 24 | initial_lifespan=100, 25 | num_final_ltms=100, 26 | ) 27 | memoria.add_working_memory(torch.randn(3, 10, 32)) 28 | 29 | wm = Engrams(torch.randn(3, 10, 32)) 30 | stm = Engrams(torch.randn(3, 20, 32), engrams_types=EngramType.SHORTTERM) 31 | weight = memoria._calculate_memory_weight(wm, stm) 32 | assert weight.shape == torch.Size([3, 10, 20]) 33 | 34 | 35 | def test_remind_shortterm_memory(): 36 | memoria = Memoria( 37 | num_reminded_stm=2, 38 | ltm_search_depth=3, 39 | stm_capacity=100, 40 | initial_lifespan=100, 41 | num_final_ltms=100, 42 | ) 43 | 44 | weight = torch.tensor([[[0.51, 0.2, 0.2, 0.8]]]) 45 | shortterm_memory_indices = torch.tensor([[1, 2, 3, 4]]) 46 | reminded = memoria._remind_shortterm_memory(weight, shortterm_memory_indices) 47 | assert (reminded == torch.tensor([[1, -1, -1, 4]])).all() 48 | 49 | 50 | def test_find_initial_ltm(): 51 | num_stm = 5 52 | num_ltm = 4 53 | memoria = Memoria( 54 | num_reminded_stm=10, 55 | ltm_search_depth=3, 56 | stm_capacity=100, 57 | initial_lifespan=100, 58 | num_final_ltms=100, 59 | ) 60 | 61 | stm = Engrams(torch.randn(1, num_stm, 32), engrams_types=EngramType.SHORTTERM) 62 | ltm = Engrams(torch.randn(1, num_ltm, 32), engrams_types=EngramType.LONGTERM) 63 | engrams = stm + ltm 64 | memoria.engrams = engrams 65 | memoria.engrams.induce_counts[:, :num_stm, num_stm:] = torch.tensor( 66 | [ 67 | [ 68 | [1, 1, 2, 1], 69 | [1, 1, 1, 2], 70 | [1, 10, 1, 1], 71 | [1, 1, 2, 1], 72 | [1, 5, 1, 1], 73 | ] 74 | ] 75 | ) 76 | memoria.engrams.induce_counts[:, :num_stm, :num_stm] = 999 77 | nearest_stm_indices = torch.tensor([[0, 2, 3]]) 78 | 79 | initial_ltm_indices = memoria._find_initial_longterm_memory(nearest_stm_indices) 80 | assert (initial_ltm_indices == torch.tensor([[6, 7]])).all() 81 | 82 | 83 | def test_search_longterm_memories_with_initials(): 84 | num_stm = 5 85 | num_ltm = 4 86 | ltm_search_depth = 3 87 | memoria = Memoria( 88 | num_reminded_stm=10, 89 | ltm_search_depth=ltm_search_depth, 90 | stm_capacity=100, 91 | initial_lifespan=100, 92 | num_final_ltms=100, 93 | ) 94 | 95 | stm = Engrams(torch.randn(1, num_stm, 32), engrams_types=EngramType.SHORTTERM) 96 | ltm = Engrams(torch.randn(1, num_ltm, 32), engrams_types=EngramType.LONGTERM) 97 | memoria.engrams = stm + ltm 98 | 99 | initial_ltm_indices = torch.tensor([[5, 7]]) 100 | searched_ltm_indices = memoria._search_longterm_memories_with_initials(initial_ltm_indices, ltm) 101 | 102 | assert (searched_ltm_indices == torch.tensor([[1, 3, 2, -1, 0]])).all() 103 | 104 | 105 | def test_memorize_working_memory_as_shortterm_memory(): 106 | batch_size = 3 107 | num_wm = 5 108 | num_stm = 4 109 | num_ltm = 2 110 | ltm_search_depth = 3 111 | memoria = Memoria( 112 | num_reminded_stm=10, 113 | 
ltm_search_depth=ltm_search_depth, 114 | stm_capacity=100, 115 | initial_lifespan=100, 116 | num_final_ltms=100, 117 | ) 118 | 119 | wm = Engrams(torch.randn(batch_size, num_wm, 32), engrams_types=EngramType.WORKING) 120 | stm = Engrams(torch.randn(batch_size, num_stm, 32), engrams_types=EngramType.SHORTTERM) 121 | ltm = Engrams(torch.randn(batch_size, num_ltm, 32), engrams_types=EngramType.LONGTERM) 122 | memoria.engrams = wm + stm + ltm 123 | 124 | memoria._memorize_working_memory_as_shortterm_memory() 125 | 126 | assert memoria.engrams.get_shortterm_memory()[0].data.shape == torch.Size([batch_size, num_wm + num_stm, 32]) 127 | 128 | 129 | def test_memorize_shortterm_memory_as_longterm_memory_or_drop(): 130 | batch_size = 1 131 | num_stm = 5 132 | num_ltm = 3 133 | ltm_search_depth = 3 134 | memoria = Memoria( 135 | num_reminded_stm=10, 136 | ltm_search_depth=ltm_search_depth, 137 | stm_capacity=2, 138 | initial_lifespan=100, 139 | num_final_ltms=100, 140 | ) 141 | 142 | fire_count = torch.tensor([[0, 1, 2, 3, 0]], dtype=torch.int32) 143 | stm = Engrams(torch.randn(batch_size, num_stm, 32), engrams_types=EngramType.SHORTTERM) 144 | stm.fire_count = fire_count 145 | ltm = Engrams(torch.randn(batch_size, num_ltm, 32), engrams_types=EngramType.LONGTERM) 146 | memoria.engrams = stm + ltm 147 | 148 | memoria._memorize_shortterm_memory_as_longterm_memory() 149 | 150 | assert len(memoria.engrams) == batch_size * (num_stm + num_ltm) 151 | 152 | 153 | def test_remind(): 154 | num_reminded_stm = 2 155 | ltm_search_depth = 3 156 | stm_capacity = 100 157 | memoria = Memoria( 158 | num_reminded_stm=num_reminded_stm, 159 | ltm_search_depth=ltm_search_depth, 160 | stm_capacity=stm_capacity, 161 | initial_lifespan=100, 162 | num_final_ltms=100, 163 | ) 164 | 165 | batch_size = 3 166 | memory_length = 50 167 | hidden_dim = 32 168 | working_memory = torch.randn(batch_size, memory_length, hidden_dim) 169 | memoria.add_working_memory(working_memory) 170 | outputs, indices = memoria.remind() 171 | memoria.adjust_lifespan_and_memories(indices, torch.ones_like(indices, dtype=float)) 172 | assert len(memoria.engrams.get_shortterm_memory()[0]) == batch_size * memory_length 173 | assert outputs.size(1) == 0 174 | 175 | working_memory = torch.randn(batch_size, memory_length, hidden_dim) 176 | memoria.add_working_memory(working_memory) 177 | outputs, indices = memoria.remind() 178 | memoria.adjust_lifespan_and_memories(indices, torch.ones_like(indices, dtype=float)) 179 | assert len(memoria.engrams.get_shortterm_memory()[0]) == batch_size * memory_length * 2 180 | assert outputs.size(1) > 0 181 | 182 | working_memory = torch.randn(batch_size, memory_length, hidden_dim) 183 | memoria.add_working_memory(working_memory) 184 | outputs, indices = memoria.remind() 185 | memoria.adjust_lifespan_and_memories(indices, torch.ones_like(indices, dtype=float)) 186 | assert len(memoria.engrams.get_shortterm_memory()[0]) == batch_size * memory_length * 2 187 | assert outputs.size(1) > 0 188 | 189 | 190 | def test_reset_memory(): 191 | memoria = Memoria( 192 | num_reminded_stm=10, 193 | ltm_search_depth=3, 194 | stm_capacity=100, 195 | initial_lifespan=100, 196 | num_final_ltms=100, 197 | ) 198 | memoria.add_working_memory(torch.randn(3, 10, 32)) 199 | memoria.reset_memory() 200 | assert memoria.engrams == Engrams.empty() 201 | -------------------------------------------------------------------------------- /tests/test_sparse_tensor.py: -------------------------------------------------------------------------------- 1 | import 
torch 2 | 3 | from memoria.sparse_tensor import SparseTensor 4 | 5 | 6 | def test_from_tensor(): 7 | tensor = torch.tensor([[1, 0, 0], [0, 2, 0], [0, 0, 3]], dtype=torch.int32) 8 | sparse_tensor = SparseTensor.from_tensor(tensor) 9 | assert sparse_tensor.indices.tolist() == [[0, 0], [1, 1], [2, 2]] 10 | assert sparse_tensor.values.tolist() == [1, 2, 3] 11 | 12 | 13 | def test_get_item(): 14 | tensor = torch.tensor([[1, 0, 0], [0, 2, 0], [0, 0, 3]], dtype=torch.int32) 15 | sparse_tensor = SparseTensor.from_tensor(tensor) 16 | 17 | assert sparse_tensor.indices.tolist() == [[0, 0], [1, 1], [2, 2]] 18 | assert sparse_tensor.values.tolist() == [1, 2, 3] 19 | 20 | selected = sparse_tensor[0, 0] 21 | assert isinstance(selected, torch.Tensor) 22 | assert selected.item() == 1 23 | 24 | selected = sparse_tensor[0, 2] 25 | assert isinstance(selected, torch.Tensor) 26 | assert selected.item() == 0 27 | 28 | selected = sparse_tensor[1] 29 | assert selected.shape == (3,) 30 | assert selected.indices.tolist() == [[1]] 31 | assert selected.values.tolist() == [2] 32 | assert selected.tolist() == [0, 2, 0] 33 | 34 | selected = sparse_tensor[:, 2] 35 | assert selected.shape == (3,) 36 | assert selected.indices.tolist() == [[2]] 37 | assert selected.values.tolist() == [3] 38 | assert selected.tolist() == [0, 0, 3] 39 | 40 | selected = sparse_tensor[torch.tensor([0, 2])] 41 | assert selected.shape == (2, 3) 42 | assert selected.indices.tolist() == [[0, 0], [1, 2]] 43 | assert selected.values.tolist() == [1, 3] 44 | assert selected.tolist() == [[1, 0, 0], [0, 0, 3]] 45 | 46 | selected = sparse_tensor[torch.tensor([0, 2]), 2] 47 | assert selected.shape == (2,) 48 | assert selected.indices.tolist() == [[1]] 49 | assert selected.values.tolist() == [3] 50 | assert selected.tolist() == [0, 3] 51 | 52 | selected = sparse_tensor[torch.tensor([[[0, 2]]])] 53 | assert selected.shape == (1, 1, 2, 3) 54 | assert selected.indices.tolist() == [[0, 0, 0, 0], [0, 0, 1, 2]] 55 | assert selected.values.tolist() == [1, 3] 56 | assert selected.tolist() == [[[[1, 0, 0], [0, 0, 3]]]] 57 | 58 | selected = sparse_tensor[torch.tensor([[0, 1], [2, 0]]), torch.tensor([[1, 2], [2, 0]])] 59 | assert selected.shape == (2, 2) 60 | assert selected.indices.tolist() == [[1, 0], [1, 1]] 61 | assert selected.values.tolist() == [3, 1] 62 | assert selected.tolist() == [[0, 0], [3, 1]] 63 | 64 | selected = sparse_tensor[0:2] 65 | assert selected.shape == (2, 3) 66 | assert selected.indices.tolist() == [[0, 0], [1, 1]] 67 | assert selected.values.tolist() == [1, 2] 68 | assert selected.tolist() == [[1, 0, 0], [0, 2, 0]] 69 | 70 | selected = sparse_tensor[0:1, 1:3] 71 | assert selected.shape == (1, 2) 72 | assert selected.indices.tolist() == [] 73 | assert selected.values.tolist() == [] 74 | assert selected.tolist() == [[0, 0]] 75 | 76 | selected = sparse_tensor[0:2, 1:3] 77 | assert selected.shape == (2, 2) 78 | assert selected.indices.tolist() == [[1, 0]] 79 | assert selected.values.tolist() == [2] 80 | assert selected.tolist() == [[0, 0], [2, 0]] 81 | 82 | selected = sparse_tensor[torch.tensor([0, 2]), 1:3] 83 | assert selected.shape == (2, 2) 84 | assert selected.indices.tolist() == [[1, 1]] 85 | assert selected.values.tolist() == [3] 86 | assert selected.tolist() == [[0, 0], [0, 3]] 87 | 88 | 89 | def test_set_item(): 90 | tensor = torch.tensor([[1, 0, 0], [0, 2, 0], [0, 0, 3]], dtype=torch.int32) 91 | sparse_tensor = SparseTensor.from_tensor(tensor) 92 | 93 | sparse_tensor[0, 0] = torch.tensor(90) 94 | assert sparse_tensor.indices.tolist() == 
[[1, 1], [2, 2], [0, 0]] 95 | assert sparse_tensor.values.tolist() == [2, 3, 90] 96 | assert sparse_tensor.to_dense().tolist() == [[90, 0, 0], [0, 2, 0], [0, 0, 3]] 97 | 98 | sparse_tensor[0, 0] += 10 99 | assert sparse_tensor.indices.tolist() == [[1, 1], [2, 2], [0, 0]] 100 | assert sparse_tensor.values.tolist() == [2, 3, 100] 101 | assert sparse_tensor.to_dense().tolist() == [[100, 0, 0], [0, 2, 0], [0, 0, 3]] 102 | 103 | sparse_tensor = SparseTensor.from_tensor(tensor) 104 | sparse_tensor[0, 0] = 10 105 | assert sparse_tensor.indices.tolist() == [[1, 1], [2, 2], [0, 0]] 106 | assert sparse_tensor.values.tolist() == [2, 3, 10] 107 | assert sparse_tensor.to_dense().tolist() == [[10, 0, 0], [0, 2, 0], [0, 0, 3]] 108 | 109 | sparse_tensor[1] = 20 110 | assert sparse_tensor.indices.tolist() == [[2, 2], [0, 0], [1, 0], [1, 1], [1, 2]] 111 | assert sparse_tensor.values.tolist() == [3, 10, 20, 20, 20] 112 | assert sparse_tensor.to_dense().tolist() == [[10, 0, 0], [20, 20, 20], [0, 0, 3]] 113 | 114 | sparse_tensor[:, 2] = 30 115 | assert sparse_tensor.indices.tolist() == [[0, 0], [1, 0], [1, 1], [0, 2], [1, 2], [2, 2]] 116 | assert sparse_tensor.values.tolist() == [10, 20, 20, 30, 30, 30] 117 | assert sparse_tensor.to_dense().tolist() == [[10, 0, 30], [20, 20, 30], [0, 0, 30]] 118 | 119 | sparse_tensor = SparseTensor.from_tensor(tensor) 120 | sparse_tensor[torch.tensor([0, 2])] = 40 121 | assert sparse_tensor.indices.tolist() == [[1, 1], [0, 0], [0, 1], [0, 2], [2, 0], [2, 1], [2, 2]] 122 | assert sparse_tensor.values.tolist() == [2, 40, 40, 40, 40, 40, 40] 123 | assert sparse_tensor.to_dense().tolist() == [[40, 40, 40], [0, 2, 0], [40, 40, 40]] 124 | 125 | sparse_tensor = SparseTensor.from_tensor(tensor) 126 | sparse_tensor[torch.tensor([0, 2]), 2] = 50 127 | assert sparse_tensor.indices.tolist() == [[0, 0], [1, 1], [0, 2], [2, 2]] 128 | assert sparse_tensor.values.tolist() == [1, 2, 50, 50] 129 | assert sparse_tensor.to_dense().tolist() == [[1, 0, 50], [0, 2, 0], [0, 0, 50]] 130 | 131 | sparse_tensor = SparseTensor.from_tensor(tensor) 132 | sparse_tensor[torch.tensor([[[0, 2]]])] = 60 133 | assert sparse_tensor.indices.tolist() == [[1, 1], [0, 0], [0, 1], [0, 2], [2, 0], [2, 1], [2, 2]] 134 | assert sparse_tensor.values.tolist() == [2, 60, 60, 60, 60, 60, 60] 135 | assert sparse_tensor.to_dense().tolist() == [[60, 60, 60], [0, 2, 0], [60, 60, 60]] 136 | 137 | sparse_tensor = SparseTensor.from_tensor(tensor) 138 | sparse_tensor[torch.tensor([[0, 1], [2, 0]]), torch.tensor([[1, 2], [2, 0]])] = 70 139 | assert sparse_tensor.indices.tolist() == [[1, 1], [0, 1], [1, 2], [2, 2], [0, 0]] 140 | assert sparse_tensor.values.tolist() == [2, 70, 70, 70, 70] 141 | assert sparse_tensor.to_dense().tolist() == [[70, 70, 0], [0, 2, 70], [0, 0, 70]] 142 | 143 | sparse_tensor = SparseTensor.from_tensor(tensor) 144 | sparse_tensor[torch.tensor([[0, 1], [2, 0]]), torch.tensor([[1, 2], [2, 0]])] = torch.tensor([[70, 80], [90, 100]]) 145 | assert sparse_tensor.indices.tolist() == [[1, 1], [0, 1], [1, 2], [2, 2], [0, 0]] 146 | assert sparse_tensor.values.tolist() == [2, 70, 80, 90, 100] 147 | assert sparse_tensor.to_dense().tolist() == [[100, 70, 0], [0, 2, 80], [0, 0, 90]] 148 | 149 | sparse_tensor = SparseTensor.from_tensor(tensor) 150 | sparse_tensor[0:2] = 80 151 | assert sparse_tensor.indices.tolist() == [[2, 2], [0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]] 152 | assert sparse_tensor.values.tolist() == [3, 80, 80, 80, 80, 80, 80] 153 | assert sparse_tensor.to_dense().tolist() == [[80, 80, 80], [80, 80, 80], 
[0, 0, 3]] 154 | 155 | 156 | def test_diagonal(): 157 | tensor = torch.randn(2, 5, 3, 5) 158 | sparse_tensor = SparseTensor.from_tensor(tensor) 159 | 160 | assert (tensor.diagonal(dim1=1, dim2=3) == sparse_tensor.diagonal(dim1=1, dim2=3).to_dense()).all() 161 | 162 | 163 | def test_equals(): 164 | tensor = torch.randn(2, 5, 3, 5) 165 | sparse_tensor = SparseTensor.from_tensor(tensor) 166 | 167 | assert tensor == sparse_tensor 168 | assert (tensor == sparse_tensor.to_dense()).all() 169 | assert sparse_tensor != SparseTensor.from_tensor(torch.randn(2, 5, 3, 5)) 170 | 171 | 172 | def test_add(): 173 | tensor = torch.randint(0, 5, [2, 5, 3, 5]) 174 | tensor2 = torch.randint(0, 5, [2, 5, 3, 5]) 175 | sparse_tensor = SparseTensor.from_tensor(tensor) 176 | sparse_tensor2 = SparseTensor.from_tensor(tensor2) 177 | 178 | assert (tensor + 1 == (sparse_tensor + 1).to_dense()).all() 179 | assert (tensor + tensor == (sparse_tensor + sparse_tensor).to_dense()).all() 180 | assert (tensor + tensor2 == (sparse_tensor + sparse_tensor2).to_dense()).all() 181 | 182 | 183 | def test_unsqueeze(): 184 | tensor = torch.randn(2, 5, 3, 5) 185 | sparse_tensor = SparseTensor.from_tensor(tensor) 186 | 187 | assert sparse_tensor.unsqueeze(0).shape == (1, 2, 5, 3, 5) 188 | assert sparse_tensor.unsqueeze(1).shape == (2, 1, 5, 3, 5) 189 | assert sparse_tensor.unsqueeze(2).shape == (2, 5, 1, 3, 5) 190 | assert sparse_tensor.unsqueeze(3).shape == (2, 5, 3, 1, 5) 191 | assert sparse_tensor.unsqueeze(4).shape == (2, 5, 3, 5, 1) 192 | 193 | 194 | def test_to(): 195 | tensor = torch.randn(2, 5, 3, 5) 196 | sparse_tensor = SparseTensor.from_tensor(tensor) 197 | 198 | assert sparse_tensor.to(torch.device("cpu")) == sparse_tensor 199 | assert (sparse_tensor.to(torch.device("cpu")).to_dense() == tensor).all() 200 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from memoria.utils import super_unique 4 | 5 | 6 | def test_super_unique(): 7 | x = torch.tensor( 8 | [ 9 | [[0, 1, 0, 2, 1], [1, 4, 1, 2, 1], [3, 4, 1, 4, 3], [2, 0, 0, 3, 2]], 10 | [[0, 1, 4, 2, 4], [1, 4, 3, 2, 0], [2, 3, 4, 2, 0], [4, 4, 4, 3, 0]], 11 | [[2, 3, 2, 3, 3], [2, 0, 1, 3, 0], [1, 3, 2, 0, 0], [1, 2, 3, 0, 0]], 12 | ], 13 | dtype=torch.int32, 14 | ) 15 | assert ( 16 | super_unique(x, dim=1) 17 | == torch.tensor( 18 | [ 19 | [[1, 1, 1, 2, 1], [3, 4, 0, 4, 3], [2, 0, -1, 3, 2], [0, -1, -1, -1, -1]], 20 | [[2, 1, 4, 2, 4], [4, 4, 3, 3, 0], [0, 3, -1, -1, -1], [1, -1, -1, -1, -1]], 21 | [[1, 2, 1, 0, 0], [2, 3, 3, 3, 3], [-1, 0, 2, -1, -1], [-1, -1, -1, -1, -1]], 22 | ] 23 | ) 24 | ).all() 25 | 26 | x = torch.tensor([[2, 3, 4, 3, 0], [3, 1, 3, 1, 0], [4, 3, 2, 2, 4], [2, 2, 2, 0, 3]], dtype=torch.int32) 27 | assert ( 28 | super_unique(x, dim=0) == torch.tensor([[2, 1, 2, 1, 0], [4, 3, 4, 3, 3], [3, 2, 3, 2, 4], [-1, -1, -1, 0, -1]]) 29 | ).all() 30 | --------------------------------------------------------------------------------
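
Note: the snippet below is a minimal usage sketch of the core Memoria loop, assembled only from the calls exercised in tests/test_memoria.py above; the hyperparameter values and tensor shapes are arbitrary illustrations rather than recommended settings.

    import torch
    from memoria import Memoria

    memoria = Memoria(
        num_reminded_stm=2, ltm_search_depth=3, stm_capacity=100, initial_lifespan=100, num_final_ltms=100
    )
    hidden_states = torch.randn(3, 50, 32)  # [Batch, SegmentLength, HiddenDim], placeholder features
    memoria.add_working_memory(hidden_states)  # store the current segment as working-memory engrams
    reminded, indices = memoria.remind()  # retrieve related short-term and long-term engrams
    # reward the reminded engrams; here every retrieved engram gets a uniform lifespan gain
    memoria.adjust_lifespan_and_memories(indices, torch.ones_like(indices, dtype=float))
    memoria.reset_memory()  # clear all engrams, e.g. between independent sequences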