├── .circleci
│   └── config.yml
├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── .vscode
│   ├── extensions.json
│   └── settings.json
├── CITATION.bib
├── LICENSE
├── README.md
├── experiment
│   ├── .gitignore
│   ├── README.md
│   ├── configs
│   │   ├── compressive-former-enwik8.json
│   │   ├── compressive-former-sort.json
│   │   ├── compressive-former.json
│   │   ├── gpt2-enwik8.json
│   │   ├── gpt2-synthetic.json
│   │   ├── gpt2-xl.json
│   │   ├── gpt2.json
│   │   ├── infinity-gpt2-enwik8.json
│   │   ├── infinity-gpt2-sort.json
│   │   ├── infinity-gpt2-synthetic.json
│   │   ├── infinity-gpt2.json
│   │   ├── memoria-enwik8.json
│   │   ├── memoria-gpt2-large.json
│   │   ├── memoria-gpt2-medium.json
│   │   ├── memoria-gpt2-sort.json
│   │   ├── memoria-gpt2-synthetic.json
│   │   ├── memoria-gpt2-xl.json
│   │   ├── memoria-gpt2.json
│   │   ├── transfo-xl-enwik8.json
│   │   ├── transfo-xl-sort.json
│   │   ├── transfo-xl-synthetic.json
│   │   └── transfo-xl.json
│   ├── eval_classification.py
│   ├── eval_language_modeling.py
│   ├── eval_synthetic.py
│   ├── longseq_formers
│   │   ├── __init__.py
│   │   ├── data
│   │   │   ├── __init__.py
│   │   │   ├── enwik8.py
│   │   │   ├── hyperpartisan.py
│   │   │   ├── pg19.py
│   │   │   └── wikitext103.py
│   │   ├── dataset
│   │   │   ├── __init__.py
│   │   │   ├── classification.py
│   │   │   ├── language_modeling.py
│   │   │   └── synthetic.py
│   │   ├── model
│   │   │   ├── __init__.py
│   │   │   ├── compressive_former
│   │   │   │   ├── __init__.py
│   │   │   │   └── modeling_compressive_transformer.py
│   │   │   ├── gpt2_with_memoria
│   │   │   │   ├── __init__.py
│   │   │   │   └── modeling_gpt2_with_memoria.py
│   │   │   ├── infinity_gpt2
│   │   │   │   ├── __init__.py
│   │   │   │   ├── basis_functions.py
│   │   │   │   ├── configuration_infinity_gpt2.py
│   │   │   │   ├── continuous_softmax.py
│   │   │   │   ├── continuous_sparsemax.py
│   │   │   │   ├── long_term_attention.py
│   │   │   │   └── modeling_infinity_gpt2.py
│   │   │   ├── memoria_bert
│   │   │   │   ├── __init__.py
│   │   │   │   ├── configuration_memoria_bert.py
│   │   │   │   └── modeling_memoria_bert.py
│   │   │   └── memoria_roberta
│   │   │       ├── __init__.py
│   │   │       ├── configuration_memoria_roberta.py
│   │   │       └── modeling_memoria_roberta.py
│   │   ├── task
│   │   │   ├── __init__.py
│   │   │   ├── classification.py
│   │   │   ├── language_modeling.py
│   │   │   └── synthetic.py
│   │   └── utils.py
│   ├── requirements.txt
│   ├── train_classification.py
│   ├── train_language_modeling.py
│   └── train_synthetic.py
├── images
│   └── Memoria-Engrams.gif
├── memoria
│   ├── __init__.py
│   ├── abstractor.py
│   ├── engram.py
│   ├── history_manager.py
│   ├── memoria.py
│   ├── sparse_tensor.py
│   ├── types.py
│   └── utils.py
├── pyproject.toml
├── requirements-dev.txt
├── requirements.txt
├── setup.py
└── tests
    ├── __init__.py
    ├── test_abstractor.py
    ├── test_engram.py
    ├── test_history_manager.py
    ├── test_memoria.py
    ├── test_sparse_tensor.py
    └── test_utils.py
/.circleci/config.yml:
--------------------------------------------------------------------------------
1 | version: 2.1
2 |
3 | orbs:
4 | codecov: codecov/codecov@4.0.1
5 |
6 | executors:
7 | python-executor:
8 | working_directory: ~/memoria
9 | docker:
10 | - image: circleci/python:3.10
11 |
12 | commands:
13 | install-packages:
14 | steps:
15 | - checkout
16 |
17 | - restore_cache:
18 | key: deps-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
19 |
20 | - run:
21 | name: Create Virtual Environment and Install Dependencies
22 | command: |
23 | virtualenv env
24 | source env/bin/activate
25 | pip install -r requirements.txt -r requirements-dev.txt
26 |
27 | - save_cache:
28 | key: deps-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
29 | paths:
30 | - "env"
31 |
32 | jobs:
33 | run-test:
34 | executor: python-executor
35 | steps:
36 | - install-packages
37 |
38 | - run:
39 | name: Run Tests and Coverage
40 | command: |
41 | source env/bin/activate
42 | pytest --cov --cov-branch --cov-report=xml
43 |
44 | - codecov/upload
45 |
46 | check-linting:
47 | executor: python-executor
48 | steps:
49 | - install-packages
50 |
51 | - run:
52 | name: Run black, isort
53 | command: |
54 | source env/bin/activate
55 | black --check memoria tests
56 | isort memoria tests
57 |
58 | workflows:
59 | main:
60 | jobs:
61 | - run-test
62 | - check-linting
63 |
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries
3 |
4 | # This workflow uses actions that are not certified by GitHub.
5 | # They are provided by a third-party and are governed by
6 | # separate terms of service, privacy policy, and support
7 | # documentation.
8 |
9 | name: Upload Python Package
10 |
11 | on:
12 | release:
13 | types: [published]
14 |
15 | permissions:
16 | contents: read
17 |
18 | jobs:
19 | deploy:
20 |
21 | runs-on: ubuntu-latest
22 |
23 | steps:
24 | - uses: actions/checkout@v4
25 | - name: Set up Python
26 | uses: actions/setup-python@v3
27 | with:
28 | python-version: '3.x'
29 | - name: Install dependencies
30 | run: |
31 | python -m pip install --upgrade pip
32 | pip install build
33 | - name: Build package
34 | run: python -m build
35 | - name: Publish package
36 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
37 | with:
38 | user: __token__
39 | password: ${{ secrets.PYPI_API_TOKEN }}
40 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # Environments
7 | .env
8 | .venv
9 | env/
10 | venv/
11 | ENV/
12 | env.bak/
13 | venv.bak/
14 |
15 | # personal vscode settings
16 | .vscode/launch.json
17 | .vscode/tasks.json
18 |
19 | # log file
20 | *.log
21 |
22 | # build byproducts
23 | build/
24 | dist/
25 | memoria_pytorch.egg-info/
26 |
--------------------------------------------------------------------------------
/.vscode/extensions.json:
--------------------------------------------------------------------------------
1 | {
2 | "recommendations": [
3 | "ms-python.python",
4 | "ms-pyright.pyright",
5 | ]
6 | }
7 |
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "python.pythonPath": "env/bin/python",
3 | "python.formatting.provider": "black",
4 | "[python]": {
5 | "editor.codeActionsOnSave": {
6 | "source.organizeImports": "explicit"
7 | }
8 | },
9 | "editor.formatOnSave": true,
10 | "files.trimTrailingWhitespace": true,
11 | "files.insertFinalNewline": true,
12 | "python.testing.pytestArgs": [
13 | "./tests"
14 | ],
15 | "python.testing.unittestEnabled": false,
16 | "python.testing.pytestEnabled": true
17 | }
18 |
--------------------------------------------------------------------------------
/CITATION.bib:
--------------------------------------------------------------------------------
1 | @InProceedings{pmlr-v235-park24a,
2 | title = {Memoria: Resolving Fateful Forgetting Problem through Human-Inspired Memory Architecture},
3 | author = {Park, Sangjun and Bak, Jinyeong},
4 | booktitle = {Proceedings of the 41st International Conference on Machine Learning},
5 | pages = {39587--39615},
6 | year = {2024},
7 | editor = {Salakhutdinov, Ruslan and Kolter, Zico and Heller, Katherine and Weller, Adrian and Oliver, Nuria and Scarlett, Jonathan and Berkenkamp, Felix},
8 | volume = {235},
9 | series = {Proceedings of Machine Learning Research},
10 | month = {21--27 Jul},
11 | publisher = {PMLR},
12 | pdf = {https://raw.githubusercontent.com/mlresearch/v235/main/assets/park24a/park24a.pdf},
13 | url = {https://proceedings.mlr.press/v235/park24a.html},
14 | abstract = {Making neural networks remember over the long term has been a longstanding issue. Although several external memory techniques have been introduced, most focus on retaining recent information in the short term. Regardless of its importance, information tends to be fatefully forgotten over time. We present Memoria, a memory system for artificial neural networks, drawing inspiration from humans and applying various neuroscientific and psychological theories. The experimental results prove the effectiveness of Memoria in the diverse tasks of sorting, language modeling, and classification, surpassing conventional techniques. Engram analysis reveals that Memoria exhibits the primacy, recency, and temporal contiguity effects which are characteristics of human memory.}
15 | }
16 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 ParkSangJun
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Memoria
2 |
3 | [](https://opensource.org/licenses/MIT)
4 | [](https://github.com/psf/black)
5 | [](https://pycqa.github.io/isort/)
6 | [](https://dl.circleci.com/status-badge/redirect/gh/cosmoquester/memoria/tree/master)
7 | [](https://codecov.io/gh/cosmoquester/memoria)
8 |
9 |
10 |
11 | Making neural networks remember over the long term has been a longstanding issue. Although several external memory techniques have been introduced, most focus on retaining recent information in the short term. Regardless of its importance, information tends to be fatefully forgotten over time. We present Memoria, a memory system for artificial neural networks, drawing inspiration from humans and applying various neuroscientific and psychological theories. The experimental results prove the effectiveness of Memoria in the diverse tasks of sorting, language modeling, and classification, surpassing conventional techniques. Engram analysis reveals that Memoria exhibits the primacy, recency, and temporal contiguity effects which are characteristics of human memory.
12 |
13 | Memoria is an independent module that can be applied to neural network models in various ways, and the experiment code of the paper is in the `experiment` directory.
14 |
15 | My paper [Memoria: Resolving Fateful Forgetting Problem through Human-Inspired Memory Architecture](https://icml.cc/virtual/2024/poster/32668) was accepted to the **International Conference on Machine Learning (ICML) 2024 as a Spotlight paper**.
16 | The full text of the paper can be accessed from [OpenReview](https://openreview.net/forum?id=yTz0u4B8ug) or [ArXiv](https://arxiv.org/abs/2310.03052).
17 |
18 | ## Installation
19 |
20 | ```sh
21 | $ pip install memoria-pytorch
22 | ```
23 |
24 | You can install Memoria with the pip command above.
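
After installation, a quick import check confirms the package is available (`Memoria` and `EngramType` are the two classes used throughout the tutorial below):

```python
# Sanity check: the package installs under the import name `memoria`.
from memoria import Memoria, EngramType

print(Memoria, EngramType)
```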
25 |
26 | ## Tutorial
27 |
28 | This is a tutorial to help you understand the concept and mechanism of Memoria.
29 |
30 | #### 1. Import Memoria and Set Parameters
31 |
32 | ```python
33 | import torch
34 | from memoria import Memoria, EngramType
35 |
36 | torch.manual_seed(42)
37 |
38 | # Memoria Parameters
39 | num_reminded_stm = 4
40 | stm_capacity = 16
41 | ltm_search_depth = 5
42 | initial_lifespan = 3
43 | num_final_ltms = 4
44 |
45 | # Data Parameters
46 | batch_size = 2
47 | sequence_length = 8
48 | hidden_dim = 64
49 | ```
50 |
51 | #### 2. Initialize Memoria and Dummy Data
52 |
53 | - Random dummy data and lifespan deltas are used for simplicity.
54 |
55 | ```python
56 | memoria = Memoria(
57 | num_reminded_stm=num_reminded_stm,
58 | stm_capacity=stm_capacity,
59 | ltm_search_depth=ltm_search_depth,
60 | initial_lifespan=initial_lifespan,
61 | num_final_ltms=num_final_ltms,
62 | )
63 | data = torch.rand(batch_size, sequence_length, hidden_dim)
64 | ```
65 |
66 | #### 3. Add Data as Working Memory
67 |
68 | ```python
69 | # Add data as working memory
70 | memoria.add_working_memory(data)
71 | ```
72 |
73 | ```python
74 | # Expected values
75 | >>> len(memoria.engrams)
76 | 16
77 | >>> memoria.engrams.data.shape
78 | torch.Size([2, 8, 64])
79 | >>> memoria.engrams.lifespan
80 | tensor([[3., 3., 3., 3., 3., 3., 3., 3.],
81 | [3., 3., 3., 3., 3., 3., 3., 3.]])
82 | ```
83 |
84 | #### 4. Remind Memories
85 |
86 | - Empty memories are reminded because there are no engrams in STM/LTM yet
87 |
88 | ```python
89 | reminded_memories, reminded_indices = memoria.remind()
90 | ```
91 |
92 | ```python
93 | # No reminded memories because there are no STM/LTM engrams yet
94 | >>> reminded_memories
95 | tensor([], size=(2, 0, 64))
96 | >>> reminded_indices
97 | tensor([], size=(2, 0), dtype=torch.int64)
98 | ```
99 |
100 | #### 5. Adjust Lifespan and Memories
101 |
102 | - In this step, no engrams earn lifespan because there are no reminded memories
103 |
104 | ```python
105 | memoria.adjust_lifespan_and_memories(reminded_indices, torch.zeros_like(reminded_indices))
106 | ```
107 |
108 | ```python
109 | # Lifespan decreases for all engrams & working memories have changed into short-term memory
110 | >>> memoria.engrams.lifespan
111 | tensor([[2., 2., 2., 2., 2., 2., 2., 2.],
112 | [2., 2., 2., 2., 2., 2., 2., 2.]])
113 | >>> memoria.engrams.engrams_types
114 | tensor([[2, 2, 2, 2, 2, 2, 2, 2],
115 | [2, 2, 2, 2, 2, 2, 2, 2]], dtype=torch.uint8)
116 | >>> EngramType.SHORTTERM
117 | <EngramType.SHORTTERM: 2>
118 | ```
119 |
120 | #### 6. Repeat one more time
121 |
122 | - Now that there are some engrams in STM, remind and adjustment will work on the STM
123 |
124 | ```python
125 | data2 = torch.rand(batch_size, sequence_length, hidden_dim)
126 | memoria.add_working_memory(data2)
127 | ```
128 |
129 | ```python
130 | >>> len(memoria.engrams)
131 | 32
132 | >>> memoria.engrams.lifespan
133 | tensor([[2., 2., 2., 2., 2., 2., 2., 2., 3., 3., 3., 3., 3., 3., 3., 3.],
134 | [2., 2., 2., 2., 2., 2., 2., 2., 3., 3., 3., 3., 3., 3., 3., 3.]])
135 | ```
136 |
137 | ```python
138 | reminded_memories, reminded_indices = memoria.remind()
139 | ```
140 |
141 | ```python
142 | # Remind memories from STM
143 | >>> reminded_memories.shape
144 | torch.Size([2, 6, 64])
145 | >>> reminded_indices.shape
146 | torch.Size([2, 6])
147 | >>> reminded_indices
148 | tensor([[ 0, 6, 4, 3, 2, -1],
149 | [ 0, 7, 6, 5, 4, -1]])
150 | ```
151 |
152 | ```python
153 | # Increase lifespan of all the reminded engrams by 5
154 | memoria.adjust_lifespan_and_memories(reminded_indices, torch.full_like(reminded_indices, 5))
155 | ```
156 |
157 | ```python
158 | # Reminded engrams gained 5 lifespan; the other engrams got older
159 | >>> memoria.engrams.lifespan
161 | tensor([[6., 1., 6., 6., 6., 1., 6., 1., 2., 2., 2., 2., 2., 2., 2., 2.],
162 | [6., 1., 1., 1., 6., 6., 6., 6., 2., 2., 2., 2., 2., 2., 2., 2.]])
163 | ```
164 |
165 | #### 7. Repeat
166 |
167 | - Repeat 10 times to see the dynamics of LTM
168 |
169 | ```python
170 | # This is the default process for utilizing Memoria
171 | for _ in range(10):
172 | data = torch.rand(batch_size, sequence_length, hidden_dim)
173 | memoria.add_working_memory(data)
174 |
175 | reminded_memories, reminded_indices = memoria.remind()
176 |
177 | lifespan_delta = torch.randint_like(reminded_indices, 0, 6).float()
178 |
179 | memoria.adjust_lifespan_and_memories(reminded_indices, lifespan_delta)
180 | ```
181 |
182 | ```python
183 | # After 10 iterations, some engrams have changed into long-term memory and gained large lifespans
184 | # An engram type of zero means that engram has been deleted
185 | >>> len(memoria.engrams)
186 | 72
187 | >>> memoria.engrams.engrams_types
188 | tensor([[3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 2, 2, 2,
189 | 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
190 | [0, 0, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 2, 2, 2,
191 | 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]], dtype=torch.uint8)
192 | >>> EngramType.LONGTERM
193 | <EngramType.LONGTERM: 3>
194 | >>> EngramType.NULL
195 | <EngramType.NULL: 0>
196 | >>> memoria.engrams.lifespan
197 | tensor([[ 9., 1., 8., 2., 16., 5., 13., 7., 7., 3., 3., 4., 3., 3.,
198 | 4., 2., 2., 1., 1., 1., 1., 1., 1., 1., 2., 6., 1., 1.,
199 | 2., 2., 2., 2., 2., 2., 2., 2.],
200 | [-1., -1., 3., 2., 19., 21., 11., 6., 14., 1., 5., 1., 5., 1.,
201 | 5., 1., 1., 8., 2., 1., 1., 1., 2., 1., 1., 1., 1., 1.,
202 | 2., 2., 2., 2., 2., 2., 2., 2.]])
203 | ```
204 |
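Beyond this toy loop, the same three calls are what you wire into an actual model. The following is a minimal integration sketch, not the exact experiment code: reminded engrams are concatenated to the current hidden states as extra context, and the attention each engram receives is used as its lifespan delta. The `step` helper, the dot-product attention, and the scaling factor of 5 are illustrative assumptions.

```python
import torch

from memoria import Memoria

hidden_dim = 64
memoria = Memoria(
    num_reminded_stm=4,
    stm_capacity=16,
    ltm_search_depth=5,
    initial_lifespan=3,
    num_final_ltms=4,
)


def step(hidden_states: torch.Tensor) -> torch.Tensor:
    """One segment step: memorize -> remind -> attend -> reward attended engrams."""
    memoria.add_working_memory(hidden_states)
    reminded_memories, reminded_indices = memoria.remind()

    if reminded_memories.size(1) > 0:
        # Attention from the current hidden states to the reminded engrams.
        scores = hidden_states @ reminded_memories.transpose(1, 2) / hidden_dim**0.5
        attention = scores.softmax(dim=-1)
        # Engrams that receive more attention live longer (illustrative scaling).
        lifespan_delta = attention.sum(dim=1) * 5.0
        context = torch.cat([reminded_memories, hidden_states], dim=1)
    else:
        lifespan_delta = torch.zeros_like(reminded_indices, dtype=torch.float)
        context = hidden_states

    memoria.adjust_lifespan_and_memories(reminded_indices, lifespan_delta)
    return context  # would be fed to the rest of the model


for _ in range(3):
    extended_context = step(torch.rand(2, 8, hidden_dim))
```
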
205 | # Citation
206 |
207 | ```bibtex
208 | @InProceedings{pmlr-v235-park24a,
209 | title = {Memoria: Resolving Fateful Forgetting Problem through Human-Inspired Memory Architecture},
210 | author = {Park, Sangjun and Bak, Jinyeong},
211 | booktitle = {Proceedings of the 41st International Conference on Machine Learning},
212 | pages = {39587--39615},
213 | year = {2024},
214 | editor = {Salakhutdinov, Ruslan and Kolter, Zico and Heller, Katherine and Weller, Adrian and Oliver, Nuria and Scarlett, Jonathan and Berkenkamp, Felix},
215 | volume = {235},
216 | series = {Proceedings of Machine Learning Research},
217 | month = {21--27 Jul},
218 | publisher = {PMLR},
219 | pdf = {https://raw.githubusercontent.com/mlresearch/v235/main/assets/park24a/park24a.pdf},
220 | url = {https://proceedings.mlr.press/v235/park24a.html},
221 | abstract = {Making neural networks remember over the long term has been a longstanding issue. Although several external memory techniques have been introduced, most focus on retaining recent information in the short term. Regardless of its importance, information tends to be fatefully forgotten over time. We present Memoria, a memory system for artificial neural networks, drawing inspiration from humans and applying various neuroscientific and psychological theories. The experimental results prove the effectiveness of Memoria in the diverse tasks of sorting, language modeling, and classification, surpassing conventional techniques. Engram analysis reveals that Memoria exhibits the primacy, recency, and temporal contiguity effects which are characteristics of human memory.}
222 | }
223 | ```
224 |
--------------------------------------------------------------------------------
/experiment/.gitignore:
--------------------------------------------------------------------------------
1 | lightning_logs/
2 |
--------------------------------------------------------------------------------
/experiment/README.md:
--------------------------------------------------------------------------------
1 | # Memoria Experiment
2 |
3 | This directory contains the model architectures, data loaders, config files, and training and evaluation scripts used to conduct the experiments in my paper. You can reproduce my research by referring to the Memoria paper, or develop your own ideas on top of it.
4 |
5 | ## Package Install
6 |
7 | You should install the required packages before running the code.
8 |
9 | ```sh
10 | $ pip install -r requirements.txt
11 | ```
12 |
13 | ## Structure
14 |
15 | ```
16 | experiment
17 | ├── configs
18 | └── longseq_formers
19 | ├── data
20 | ├── dataset
21 | ├── model
22 | │ ├── compressive_former
23 | │ ├── gpt2_with_memoria
24 | │ ├── infinity_gpt2
25 | │ ├── memoria_bert
26 | │ └── memoria_roberta
27 | └── task
28 | ```
29 | - The `longseq_formers` directory is the main package for the experiments. It contains the data loaders, training tasks, and model architectures.
30 | - The `configs` directory includes config files for language modeling and the synthetic (sorting) task for multiple models; the snippet below shows how to inspect one.
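
As a quick way to see what a config controls, any of these files can be read with the standard `json` module (a minimal sketch; the keys shown come from `configs/memoria-gpt2.json`):

```python
import json

# Load one of the provided experiment configs and list its Memoria-specific settings.
with open("configs/memoria-gpt2.json") as f:
    config = json.load(f)

memoria_settings = {key: value for key, value in config.items() if key.startswith("memoria_")}
print(config["model_type"])  # gpt2_with_memoria
print(memoria_settings)      # memoria_stm_capacity, memoria_ltm_search_depth, ...
```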
31 |
32 | ## Models
33 |
34 | You can load models from the `longseq_formers.model` module independently of the training or evaluation scripts.
35 |
36 | ```python
37 | import torch
38 | from longseq_formers.model import MemoriaBertModel
39 |
40 | memoria_bert = MemoriaBertModel.from_pretrained("bert-base-uncased")
41 | input_ids = torch.randint(0, 10, [1,10])
42 | outputs = memoria_bert(input_ids)
43 | ```
44 |
45 | ```python
46 | import torch
47 | from longseq_formers.model import GPT2WithMemoriaLMHeadModel
48 |
49 | memoria_gpt2 = GPT2WithMemoriaLMHeadModel.from_pretrained("gpt2")
50 | input_ids = torch.randint(0, 10, [1,10])
51 | outputs = memoria_gpt2(input_ids)
52 | ```
53 |
54 | ## Train
55 |
56 | You can train a model with the training script for each task. With the `--help` option, you can see the options for training or evaluation. Because all the datasets except the sorting task are loaded from the web, you don't have to download any dataset separately.
57 |
58 | ```sh
59 | $ python train_language_modeling.py --help
60 | usage: train [-h] [--model-config MODEL_CONFIG] [--model MODEL] [--model-type MODEL_TYPE] [--tokenizer TOKENIZER] [--dataset {wikitext103,pg19,enwik8}]
61 | [--batch-size BATCH_SIZE] [--valid-batch-size VALID_BATCH_SIZE] [--accumulate-grad-batches ACCUMULATE_GRAD_BATCHES] [--max-length MAX_LENGTH] [--epochs EPOCHS]
62 | [--learning-rate LEARNING_RATE] [--warmup-rate WARMUP_RATE] [--max-grad-norm MAX_GRAD_NORM] [--seed SEED] [--shuffle] [--test-ckpt {best,last}]
63 | [--output-dir OUTPUT_DIR] [--gpus GPUS] [--logging-interval LOGGING_INTERVAL] [--valid-interval VALID_INTERVAL] [--wandb-run-name WANDB_RUN_NAME]
64 | [--wandb-entity WANDB_ENTITY] [--wandb-project WANDB_PROJECT]
65 |
66 | Train & Test Language Modeling
67 |
68 | optional arguments:
69 | -h, --help show this help message and exit
70 |
71 | Train Parameter:
72 | --model-config MODEL_CONFIG
73 | huggingface model config
74 | --model MODEL huggingface model
75 | --model-type MODEL_TYPE
76 | specific model type
77 | --tokenizer TOKENIZER
78 | huggingface tokenizer
79 | --dataset {wikitext103,pg19,enwik8}
80 | dataset name
81 | --batch-size BATCH_SIZE
82 | global training batch size
83 | --valid-batch-size VALID_BATCH_SIZE
84 | validation batch size
85 | --accumulate-grad-batches ACCUMULATE_GRAD_BATCHES
86 | the number of gradident accumulation steps
87 | --max-length MAX_LENGTH
88 | max sequence length
89 | --epochs EPOCHS the number of training epochs
90 | --learning-rate LEARNING_RATE
91 | learning rate
92 | --warmup-rate WARMUP_RATE
93 | warmup step rate
94 | --max-grad-norm MAX_GRAD_NORM
95 | maximum gradient norm
96 | --seed SEED random seed
97 | --shuffle shuffle data order
98 | --test-ckpt {best,last}
99 | checkpoint type for testing
100 |
101 | Personal Options:
102 | --output-dir OUTPUT_DIR
103 | output directory path to save artifacts
104 | --gpus GPUS the number of gpus, use all devices by default
105 | --logging-interval LOGGING_INTERVAL
106 | logging interval
107 | --valid-interval VALID_INTERVAL
108 | validation interval rate
109 |
110 | Wandb Options:
111 | --wandb-run-name WANDB_RUN_NAME
112 | wanDB run name
113 | --wandb-entity WANDB_ENTITY
114 | wanDB entity name
115 | --wandb-project WANDB_PROJECT
116 | wanDB project name
117 | ```
118 |
119 | ```sh
120 | $ python train_language_modeling.py --model gpt2
121 | [2023-09-29 21:31:29,995] ====== Arguements ======
122 | [2023-09-29 21:31:29,995] model_config : None
123 | [2023-09-29 21:31:29,995] model : gpt2
124 | [2023-09-29 21:31:29,995] model_type : None
125 | [2023-09-29 21:31:29,995] tokenizer : None
126 | [2023-09-29 21:31:29,995] dataset : wikitext103
127 | [2023-09-29 21:31:29,995] batch_size : 8
128 | [2023-09-29 21:31:29,995] valid_batch_size : 1
129 | [2023-09-29 21:31:29,995] accumulate_grad_batches : 1
130 | [2023-09-29 21:31:29,995] max_length : 150
131 | [2023-09-29 21:31:29,995] epochs : 6
132 | [2023-09-29 21:31:29,995] learning_rate : 0.0002
133 | ...
134 | ```
135 | - You can start training with just this command, without any downloads.
136 |
137 | ```sh
138 | $ python train_language_modeling.py --model-config configs/memoria-gpt2.json --tokenizer gpt2 --output-dir trained-model
139 | [2023-09-29 21:43:27,347] [+] Save output to "trained-model"
140 | [2023-09-29 21:43:27,347] ====== Arguements ======
141 | [2023-09-29 21:43:27,347] model_config : configs/memoria-gpt2.json
142 | [2023-09-29 21:43:27,347] model : None
143 | [2023-09-29 21:43:27,347] model_type : None
144 | [2023-09-29 21:43:27,347] tokenizer : gpt2
145 | [2023-09-29 21:43:27,347] dataset : wikitext103
146 | [2023-09-29 21:43:27,347] batch_size : 8
147 | [2023-09-29 21:43:27,347] valid_batch_size : 1
148 | [2023-09-29 21:43:27,347] accumulate_grad_batches : 1
149 | [2023-09-29 21:43:27,347] max_length : 150
150 | [2023-09-29 21:43:27,347] epochs : 6
151 | [2023-09-29 21:43:27,347] learning_rate : 0.0002
152 | ...
153 | ```
154 | - You can train the Memoria GPT-2 model simply by adding the `--model-type gpt2_with_memoria` option or the `--model-config configs/memoria-gpt2.json` option.
155 | - To save a model checkpoint, add the `--output-dir [OUTPUT-DIR]` option. In the example above, the checkpoint will be saved in the `trained-model` directory.
156 | - Refer to the help description and the Memoria paper for detailed hyperparameters; another example command is shown below.
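
For example, the same script can target another dataset from the `--dataset` choices (an illustrative invocation assembled from the options in the help output above; the values shown are just the defaults, not the hyperparameters used in the paper, and the output directory name is arbitrary):

```sh
$ python train_language_modeling.py \
    --model-config configs/memoria-gpt2.json \
    --tokenizer gpt2 \
    --dataset pg19 \
    --batch-size 8 \
    --max-length 150 \
    --output-dir trained-memoria-pg19
```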
157 |
158 | ## Evaluation
159 |
160 | ```sh
161 | $ python eval_language_modeling.py --model trained-model/checkpoint/last.ckpt
162 | [2023-09-29 21:45:03,214] ====== Arguements ======
163 | [2023-09-29 21:45:03,214] model : trained-model/checkpoint/last.ckpt
164 | [2023-09-29 21:45:03,214] tokenizer : gpt2
165 | [2023-09-29 21:45:03,214] dataset : wikitext103
166 | [2023-09-29 21:45:03,214] valid_batch_size : 1
167 | [2023-09-29 21:45:03,214] max_length : 512
168 | [2023-09-29 21:45:03,214] seed : 42
169 | [2023-09-29 21:45:03,214] [+] Set Random Seed to 42
170 | Global seed set to 42
171 | [2023-09-29 21:45:03,237] [+] GPU: 1
172 | [2023-09-29 21:45:03,237] [+] Load Tokenizer: "gpt2"
173 | ...
174 | ```
175 | - You should pass the saved model checkpoint with the `--model [MODEL-CHECKPOINT]` option.
176 |
--------------------------------------------------------------------------------
/experiment/configs/compressive-former-enwik8.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_type": "compressive_transformer",
3 | "attn_dropout": 0.1,
4 | "attn_layer_dropout": 0.1,
5 | "cmem_len": 256,
6 | "cmem_ratio": 4,
7 | "depth": 12,
8 | "dim": 512,
9 | "emb_dim": null,
10 | "enhanced_recurrence": true,
11 | "ff_dropout": 0.1,
12 | "ff_glu": false,
13 | "gru_gated_residual": false,
14 | "heads": 8,
15 | "mem_len": 256,
16 | "memory_layers": null,
17 | "mogrify_gru": false,
18 | "vocab_size": 204,
19 | "reconstruction_attn_dropout": 0.0,
20 | "reconstruction_loss_weight": 1.0,
21 | "seq_len": 512,
22 | "transformers_version": "4.25.1"
23 | }
24 |
--------------------------------------------------------------------------------
/experiment/configs/compressive-former-sort.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_type": "compressive_transformer",
3 | "attn_dropout": 0.1,
4 | "attn_layer_dropout": 0.1,
5 | "cmem_len": 512,
6 | "cmem_ratio": 4,
7 | "depth": 4,
8 | "dim": 512,
9 | "emb_dim": null,
10 | "enhanced_recurrence": true,
11 | "ff_dropout": 0.1,
12 | "ff_glu": false,
13 | "gru_gated_residual": false,
14 | "heads": 4,
15 | "mem_len": 512,
16 | "memory_layers": null,
17 | "mogrify_gru": false,
18 | "num_tokens": 21,
19 | "reconstruction_attn_dropout": 0.0,
20 | "reconstruction_loss_weight": 1.0,
21 | "seq_len": 1024,
22 | "transformers_version": "4.25.1"
23 | }
24 |
--------------------------------------------------------------------------------
/experiment/configs/compressive-former.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_type": "compressive_transformer",
3 | "attn_dropout": 0.1,
4 | "attn_layer_dropout": 0.1,
5 | "cmem_len": 75,
6 | "cmem_ratio": 4,
7 | "depth": 12,
8 | "dim": 768,
9 | "emb_dim": null,
10 | "enhanced_recurrence": true,
11 | "ff_dropout": 0.1,
12 | "ff_glu": false,
13 | "gru_gated_residual": false,
14 | "heads": 12,
15 | "mem_len": 75,
16 | "memory_layers": null,
17 | "mogrify_gru": false,
18 | "vocab_size": 50257,
19 | "reconstruction_attn_dropout": 0.0,
20 | "reconstruction_loss_weight": 1.0,
21 | "seq_len": 512,
22 | "transformers_version": "4.25.1"
23 | }
24 |
--------------------------------------------------------------------------------
/experiment/configs/gpt2-enwik8.json:
--------------------------------------------------------------------------------
1 | {
2 | "activation_function": "gelu_new",
3 | "attn_pdrop": 0.1,
4 | "bos_token_id": 50256,
5 | "embd_pdrop": 0.1,
6 | "eos_token_id": 50256,
7 | "initializer_range": 0.02,
8 | "layer_norm_epsilon": 1e-05,
9 | "model_type": "gpt2",
10 | "n_embd": 512,
11 | "n_head": 8,
12 | "n_inner": null,
13 | "n_layer": 12,
14 | "n_positions": 1024,
15 | "reorder_and_upcast_attn": false,
16 | "resid_pdrop": 0.1,
17 | "scale_attn_by_inverse_layer_idx": false,
18 | "scale_attn_weights": true,
19 | "summary_activation": null,
20 | "summary_first_dropout": 0.1,
21 | "summary_proj_to_labels": true,
22 | "summary_type": "cls_index",
23 | "summary_use_proj": true,
24 | "transformers_version": "4.25.1",
25 | "use_cache": true,
26 | "vocab_size": 204
27 | }
28 |
--------------------------------------------------------------------------------
/experiment/configs/gpt2-synthetic.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_type": "gpt2_with_memoria",
3 | "activation_function": "gelu_new",
4 | "attn_pdrop": 0.1,
5 | "bos_token_id": 50256,
6 | "embd_pdrop": 0.1,
7 | "eos_token_id": 50256,
8 | "initializer_range": 0.02,
9 | "layer_norm_epsilon": 1e-05,
10 | "n_ctx": 1024,
11 | "n_embd": 512,
12 | "n_head": 4,
13 | "n_inner": null,
14 | "n_layer": 4,
15 | "n_positions": 1024,
16 | "reorder_and_upcast_attn": false,
17 | "resid_pdrop": 0.1,
18 | "scale_attn_by_inverse_layer_idx": false,
19 | "scale_attn_weights": true,
20 | "summary_activation": null,
21 | "summary_first_dropout": 0.1,
22 | "summary_proj_to_labels": true,
23 | "summary_type": "cls_index",
24 | "summary_use_proj": true,
25 | "task_specific_params": {
26 | "text-generation": {
27 | "do_sample": true,
28 | "max_length": 50
29 | }
30 | },
31 | "transformers_version": "4.25.1",
32 | "use_cache": true,
33 | "vocab_size": 50257
34 | }
35 |
--------------------------------------------------------------------------------
/experiment/configs/gpt2-xl.json:
--------------------------------------------------------------------------------
1 | {
2 | "_name_or_path": "gpt2-xl",
3 | "activation_function": "gelu_new",
4 | "architectures": [
5 | "GPT2LMHeadModel"
6 | ],
7 | "attn_pdrop": 0.1,
8 | "bos_token_id": 50256,
9 | "embd_pdrop": 0.1,
10 | "eos_token_id": 50256,
11 | "initializer_range": 0.02,
12 | "layer_norm_epsilon": 1e-05,
13 | "model_type": "gpt2",
14 | "n_ctx": 1024,
15 | "n_embd": 1600,
16 | "n_head": 25,
17 | "n_inner": null,
18 | "n_layer": 48,
19 | "n_positions": 1024,
20 | "output_past": true,
21 | "reorder_and_upcast_attn": false,
22 | "resid_pdrop": 0.1,
23 | "scale_attn_by_inverse_layer_idx": false,
24 | "scale_attn_weights": true,
25 | "summary_activation": null,
26 | "summary_first_dropout": 0.1,
27 | "summary_proj_to_labels": true,
28 | "summary_type": "cls_index",
29 | "summary_use_proj": true,
30 | "task_specific_params": {
31 | "text-generation": {
32 | "do_sample": true,
33 | "max_length": 50
34 | }
35 | },
36 | "transformers_version": "4.25.1",
37 | "use_cache": true,
38 | "vocab_size": 50257
39 | }
40 |
--------------------------------------------------------------------------------
/experiment/configs/gpt2.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_type": "gpt2_with_memoria",
3 | "activation_function": "gelu_new",
4 | "attn_pdrop": 0.1,
5 | "bos_token_id": 50256,
6 | "embd_pdrop": 0.1,
7 | "eos_token_id": 50256,
8 | "initializer_range": 0.02,
9 | "layer_norm_epsilon": 1e-05,
10 | "n_ctx": 1024,
11 | "n_embd": 512,
12 | "n_head": 4,
13 | "n_inner": null,
14 | "n_layer": 4,
15 | "n_positions": 1024,
16 | "reorder_and_upcast_attn": false,
17 | "resid_pdrop": 0.1,
18 | "scale_attn_by_inverse_layer_idx": false,
19 | "scale_attn_weights": true,
20 | "summary_activation": null,
21 | "summary_first_dropout": 0.1,
22 | "summary_proj_to_labels": true,
23 | "summary_type": "cls_index",
24 | "summary_use_proj": true,
25 | "task_specific_params": {
26 | "text-generation": {
27 | "do_sample": true,
28 | "max_length": 50
29 | }
30 | },
31 | "transformers_version": "4.25.1",
32 | "use_cache": true,
33 | "vocab_size": 50257
34 | }
35 |
--------------------------------------------------------------------------------
/experiment/configs/infinity-gpt2-enwik8.json:
--------------------------------------------------------------------------------
1 | {
2 | "activation_function": "gelu_new",
3 | "attn_drop": 0.1,
4 | "attn_pdrop": 0.1,
5 | "bos_token_id": 50256,
6 | "detach_recursive_outputs": true,
7 | "embd_pdrop": 0.1,
8 | "eos_token_id": 50256,
9 | "gradient_checkpointing": false,
10 | "initializer_range": 0.02,
11 | "kl_lambda": 1e-06,
12 | "layer_norm_epsilon": 1e-05,
13 | "longterm_attention_dropout": 0.1,
14 | "mask_dropout": 0.1,
15 | "mask_type": "cnn",
16 | "memory_length": 512,
17 | "model_type": "infinity_gpt2",
18 | "mu_0": -1.0,
19 | "n_ctx": 1024,
20 | "n_embd": 512,
21 | "n_head": 8,
22 | "n_inner": null,
23 | "n_layer": 12,
24 | "n_positions": 1024,
25 | "normalize_function": "softmax",
26 | "num_basis": 512,
27 | "num_samples": 512,
28 | "resid_pdrop": 0.1,
29 | "sigma_0": 0.05,
30 | "summary_activation": null,
31 | "summary_first_dropout": 0.1,
32 | "summary_proj_to_labels": true,
33 | "summary_type": "cls_index",
34 | "summary_use_proj": true,
35 | "task_specific_params": {
36 | "text-generation": {
37 | "do_sample": true,
38 | "max_length": 50
39 | }
40 | },
41 | "tau": 0.5,
42 | "transformers_version": "4.25.1",
43 | "use_affines": true,
44 | "use_cache": true,
45 | "use_kl_regularizer": true,
46 | "use_sticky_memories": true,
47 | "vocab_size": 204
48 | }
49 |
--------------------------------------------------------------------------------
/experiment/configs/infinity-gpt2-sort.json:
--------------------------------------------------------------------------------
1 | {
2 | "activation_function": "gelu_new",
3 | "attn_drop": 0.1,
4 | "attn_pdrop": 0.1,
5 | "bos_token_id": 50256,
6 | "detach_recursive_outputs": true,
7 | "embd_pdrop": 0.1,
8 | "eos_token_id": 50256,
9 | "gradient_checkpointing": false,
10 | "initializer_range": 0.02,
11 | "kl_lambda": 1e-06,
12 | "layer_norm_epsilon": 1e-05,
13 | "longterm_attention_dropout": 0.1,
14 | "mask_dropout": 0.1,
15 | "mask_type": "cnn",
16 | "memory_length": 1024,
17 | "model_type": "infinity_gpt2",
18 | "mu_0": -1.0,
19 | "n_ctx": 1024,
20 | "n_embd": 512,
21 | "n_head": 4,
22 | "n_inner": 2048,
23 | "n_layer": 4,
24 | "n_positions": 1024,
25 | "normalize_function": "softmax",
26 | "num_basis": 1024,
27 | "num_samples": 1024,
28 | "resid_pdrop": 0.1,
29 | "sigma_0": 0.05,
30 | "summary_activation": null,
31 | "summary_first_dropout": 0.1,
32 | "summary_proj_to_labels": true,
33 | "summary_type": "cls_index",
34 | "summary_use_proj": true,
35 | "task_specific_params": {
36 | "text-generation": {
37 | "do_sample": true,
38 | "max_length": 50
39 | }
40 | },
41 | "tau": 0.5,
42 | "transformers_version": "4.25.1",
43 | "use_affines": true,
44 | "use_cache": true,
45 | "use_kl_regularizer": true,
46 | "use_sticky_memories": true,
47 | "vocab_size": 50257
48 | }
49 |
--------------------------------------------------------------------------------
/experiment/configs/infinity-gpt2-synthetic.json:
--------------------------------------------------------------------------------
1 | {
2 | "activation_function": "gelu_new",
3 | "attn_drop": 0.1,
4 | "attn_pdrop": 0.1,
5 | "bos_token_id": 50256,
6 | "detach_recursive_outputs": true,
7 | "embd_pdrop": 0.1,
8 | "eos_token_id": 50256,
9 | "gradient_checkpointing": false,
10 | "initializer_range": 0.02,
11 | "kl_lambda": 1e-06,
12 | "layer_norm_epsilon": 1e-05,
13 | "longterm_attention_dropout": 0.1,
14 | "mask_dropout": 0.1,
15 | "mask_type": "cnn",
16 | "memory_length": 18,
17 | "model_type": "infinity_gpt2",
18 | "mu_0": -1.0,
19 | "n_ctx": 1024,
20 | "n_embd": 512,
21 | "n_head": 4,
22 | "n_inner": null,
23 | "n_layer": 4,
24 | "n_positions": 1024,
25 | "normalize_function": "softmax",
26 | "num_basis": 18,
27 | "num_samples": 18,
28 | "resid_pdrop": 0.1,
29 | "sigma_0": 0.05,
30 | "summary_activation": null,
31 | "summary_first_dropout": 0.1,
32 | "summary_proj_to_labels": true,
33 | "summary_type": "cls_index",
34 | "summary_use_proj": true,
35 | "task_specific_params": {
36 | "text-generation": {
37 | "do_sample": true,
38 | "max_length": 50
39 | }
40 | },
41 | "tau": 0.5,
42 | "transformers_version": "4.25.1",
43 | "use_affines": true,
44 | "use_cache": true,
45 | "use_kl_regularizer": true,
46 | "use_sticky_memories": true,
47 | "vocab_size": 50257
48 | }
49 |
--------------------------------------------------------------------------------
/experiment/configs/infinity-gpt2.json:
--------------------------------------------------------------------------------
1 | {
2 | "activation_function": "gelu_new",
3 | "attn_drop": 0.1,
4 | "attn_pdrop": 0.1,
5 | "bos_token_id": 50256,
6 | "detach_recursive_outputs": true,
7 | "embd_pdrop": 0.1,
8 | "eos_token_id": 50256,
9 | "gradient_checkpointing": false,
10 | "initializer_range": 0.02,
11 | "kl_lambda": 1e-06,
12 | "layer_norm_epsilon": 1e-05,
13 | "longterm_attention_dropout": 0.1,
14 | "mask_dropout": 0.1,
15 | "mask_type": "cnn",
16 | "memory_length": 150,
17 | "model_type": "infinity_gpt2",
18 | "mu_0": -1.0,
19 | "n_ctx": 1024,
20 | "n_embd": 768,
21 | "n_head": 12,
22 | "n_inner": null,
23 | "n_layer": 12,
24 | "n_positions": 1024,
25 | "normalize_function": "softmax",
26 | "num_basis": 150,
27 | "num_samples": 150,
28 | "resid_pdrop": 0.1,
29 | "sigma_0": 0.05,
30 | "summary_activation": null,
31 | "summary_first_dropout": 0.1,
32 | "summary_proj_to_labels": true,
33 | "summary_type": "cls_index",
34 | "summary_use_proj": true,
35 | "task_specific_params": {
36 | "text-generation": {
37 | "do_sample": true,
38 | "max_length": 50
39 | }
40 | },
41 | "tau": 0.5,
42 | "transformers_version": "4.25.1",
43 | "use_affines": true,
44 | "use_cache": true,
45 | "use_kl_regularizer": true,
46 | "use_sticky_memories": true,
47 | "vocab_size": 50257
48 | }
49 |
--------------------------------------------------------------------------------
/experiment/configs/memoria-enwik8.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_type": "gpt2_with_memoria",
3 | "activation_function": "gelu_new",
4 | "attn_pdrop": 0.1,
5 | "bos_token_id": 50256,
6 | "embd_pdrop": 0.1,
7 | "eos_token_id": 50256,
8 | "initializer_range": 0.02,
9 | "layer_norm_epsilon": 1e-05,
10 | "memoria_stm_capacity": 1536,
11 | "memoria_num_memories": 128,
12 | "memoria_initial_lifespan": 9,
13 | "memoria_lifespan_extend_scale": 8.0,
14 | "memoria_ltm_search_depth": 10,
15 | "memoria_reset_period": 500,
16 | "memoria_num_reminded_stm": 192,
17 | "memoria_num_reminded_ltm": 192,
18 | "memoria_device": null,
19 | "n_ctx": 1024,
20 | "n_embd": 512,
21 | "n_head": 8,
22 | "n_inner": null,
23 | "n_layer": 12,
24 | "n_positions": 1024,
25 | "reorder_and_upcast_attn": false,
26 | "resid_pdrop": 0.1,
27 | "scale_attn_by_inverse_layer_idx": false,
28 | "scale_attn_weights": true,
29 | "summary_activation": null,
30 | "summary_first_dropout": 0.1,
31 | "summary_proj_to_labels": true,
32 | "summary_type": "cls_index",
33 | "summary_use_proj": true,
34 | "task_specific_params": {
35 | "text-generation": {
36 | "do_sample": true,
37 | "max_length": 50
38 | }
39 | },
40 | "transformers_version": "4.25.1",
41 | "use_cache": true,
42 | "vocab_size": 204
43 | }
44 |
--------------------------------------------------------------------------------
/experiment/configs/memoria-gpt2-large.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_type": "gpt2_with_memoria",
3 | "activation_function": "gelu_new",
4 | "attn_pdrop": 0.1,
5 | "bos_token_id": 50256,
6 | "embd_pdrop": 0.1,
7 | "eos_token_id": 50256,
8 | "initializer_range": 0.02,
9 | "layer_norm_epsilon": 1e-05,
10 | "memoria_stm_capacity": 400,
11 | "memoria_num_memories": 50,
12 | "memoria_initial_lifespan": 9,
13 | "memoria_lifespan_extend_scale": 8.0,
14 | "memoria_ltm_search_depth": 10,
15 | "memoria_reset_period": 500,
16 | "memoria_num_reminded_stm": 50,
17 | "memoria_num_reminded_ltm": 50,
18 | "memoria_device": null,
19 | "n_ctx": 1024,
20 | "n_embd": 1280,
21 | "n_head": 20,
22 | "n_inner": null,
23 | "n_layer": 36,
24 | "n_positions": 1024,
25 | "reorder_and_upcast_attn": false,
26 | "resid_pdrop": 0.1,
27 | "scale_attn_by_inverse_layer_idx": false,
28 | "scale_attn_weights": true,
29 | "summary_activation": null,
30 | "summary_first_dropout": 0.1,
31 | "summary_proj_to_labels": true,
32 | "summary_type": "cls_index",
33 | "summary_use_proj": true,
34 | "task_specific_params": {
35 | "text-generation": {
36 | "do_sample": true,
37 | "max_length": 50
38 | }
39 | },
40 | "transformers_version": "4.25.1",
41 | "use_cache": true,
42 | "vocab_size": 50257
43 | }
44 |
--------------------------------------------------------------------------------
/experiment/configs/memoria-gpt2-medium.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_type": "gpt2_with_memoria",
3 | "activation_function": "gelu_new",
4 | "attn_pdrop": 0.1,
5 | "bos_token_id": 50256,
6 | "embd_pdrop": 0.1,
7 | "eos_token_id": 50256,
8 | "initializer_range": 0.02,
9 | "layer_norm_epsilon": 1e-05,
10 | "memoria_stm_capacity": 400,
11 | "memoria_num_memories": 50,
12 | "memoria_initial_lifespan": 9,
13 | "memoria_lifespan_extend_scale": 8.0,
14 | "memoria_ltm_search_depth": 10,
15 | "memoria_reset_period": 500,
16 | "memoria_num_reminded_stm": 50,
17 | "memoria_num_reminded_ltm": 50,
18 | "memoria_device": null,
19 | "n_ctx": 1024,
20 | "n_embd": 1024,
21 | "n_head": 16,
22 | "n_inner": null,
23 | "n_layer": 24,
24 | "n_positions": 1024,
25 | "reorder_and_upcast_attn": false,
26 | "resid_pdrop": 0.1,
27 | "scale_attn_by_inverse_layer_idx": false,
28 | "scale_attn_weights": true,
29 | "summary_activation": null,
30 | "summary_first_dropout": 0.1,
31 | "summary_proj_to_labels": true,
32 | "summary_type": "cls_index",
33 | "summary_use_proj": true,
34 | "task_specific_params": {
35 | "text-generation": {
36 | "do_sample": true,
37 | "max_length": 50
38 | }
39 | },
40 | "transformers_version": "4.25.1",
41 | "use_cache": true,
42 | "vocab_size": 50257
43 | }
44 |
--------------------------------------------------------------------------------
/experiment/configs/memoria-gpt2-sort.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_type": "gpt2_with_memoria",
3 | "activation_function": "gelu_new",
4 | "attn_pdrop": 0.1,
5 | "bos_token_id": 50256,
6 | "embd_pdrop": 0.1,
7 | "eos_token_id": 50256,
8 | "initializer_range": 0.02,
9 | "layer_norm_epsilon": 1e-05,
10 | "memoria_stm_capacity": 512,
11 | "memoria_num_memories": 128,
12 | "memoria_initial_lifespan": 5,
13 | "memoria_lifespan_extend_scale": 8.0,
14 | "memoria_ltm_search_depth": 10,
15 | "memoria_reset_period": 500,
16 | "memoria_num_reminded_stm": 256,
17 | "memoria_num_reminded_ltm": 640,
18 | "memoria_device": null,
19 | "n_ctx": 10000,
20 | "n_embd": 512,
21 | "n_head": 4,
22 | "n_inner": 2048,
23 | "n_layer": 4,
24 | "n_positions": 10000,
25 | "reorder_and_upcast_attn": false,
26 | "resid_pdrop": 0.1,
27 | "scale_attn_by_inverse_layer_idx": false,
28 | "scale_attn_weights": true,
29 | "summary_activation": null,
30 | "summary_first_dropout": 0.1,
31 | "summary_proj_to_labels": true,
32 | "summary_type": "cls_index",
33 | "summary_use_proj": true,
34 | "task_specific_params": {
35 | "text-generation": {
36 | "do_sample": true,
37 | "max_length": 50
38 | }
39 | },
40 | "transformers_version": "4.25.1",
41 | "use_cache": true,
42 | "vocab_size": 50257
43 | }
44 |
--------------------------------------------------------------------------------
/experiment/configs/memoria-gpt2-synthetic.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_type": "gpt2_with_memoria",
3 | "activation_function": "gelu_new",
4 | "attn_pdrop": 0.1,
5 | "bos_token_id": 50256,
6 | "embd_pdrop": 0.1,
7 | "eos_token_id": 50256,
8 | "initializer_range": 0.02,
9 | "layer_norm_epsilon": 1e-05,
10 | "memoria_stm_capacity": 32,
11 | "memoria_num_memories": null,
12 | "memoria_initial_lifespan": 5,
13 | "memoria_lifespan_extend_scale": 6.0,
14 | "memoria_ltm_search_depth": 30,
15 | "memoria_reset_period": 500,
16 | "memoria_num_reminded_stm": 8,
17 | "memoria_num_reminded_ltm": 8,
18 | "memoria_device": null,
19 | "n_ctx": 1024,
20 | "n_embd": 512,
21 | "n_head": 4,
22 | "n_inner": null,
23 | "n_layer": 4,
24 | "n_positions": 1024,
25 | "reorder_and_upcast_attn": false,
26 | "resid_pdrop": 0.1,
27 | "scale_attn_by_inverse_layer_idx": false,
28 | "scale_attn_weights": true,
29 | "summary_activation": null,
30 | "summary_first_dropout": 0.1,
31 | "summary_proj_to_labels": true,
32 | "summary_type": "cls_index",
33 | "summary_use_proj": true,
34 | "task_specific_params": {
35 | "text-generation": {
36 | "do_sample": true,
37 | "max_length": 50
38 | }
39 | },
40 | "transformers_version": "4.25.1",
41 | "use_cache": true,
42 | "vocab_size": 50257
43 | }
44 |
--------------------------------------------------------------------------------
/experiment/configs/memoria-gpt2-xl.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_type": "gpt2_with_memoria",
3 | "activation_function": "gelu_new",
4 | "attn_pdrop": 0.1,
5 | "bos_token_id": 50256,
6 | "embd_pdrop": 0.1,
7 | "eos_token_id": 50256,
8 | "initializer_range": 0.02,
9 | "layer_norm_epsilon": 1e-05,
10 | "memoria_stm_capacity": 400,
11 | "memoria_num_memories": 50,
12 | "memoria_initial_lifespan": 9,
13 | "memoria_lifespan_extend_scale": 8.0,
14 | "memoria_ltm_search_depth": 10,
15 | "memoria_reset_period": 500,
16 | "memoria_num_reminded_stm": 50,
17 | "memoria_num_reminded_ltm": 50,
18 | "memoria_device": null,
19 | "n_ctx": 1024,
20 | "n_embd": 1600,
21 | "n_head": 25,
22 | "n_inner": null,
23 | "n_layer": 48,
24 | "n_positions": 1024,
25 | "reorder_and_upcast_attn": false,
26 | "resid_pdrop": 0.1,
27 | "scale_attn_by_inverse_layer_idx": false,
28 | "scale_attn_weights": true,
29 | "summary_activation": null,
30 | "summary_first_dropout": 0.1,
31 | "summary_proj_to_labels": true,
32 | "summary_type": "cls_index",
33 | "summary_use_proj": true,
34 | "task_specific_params": {
35 | "text-generation": {
36 | "do_sample": true,
37 | "max_length": 50
38 | }
39 | },
40 | "transformers_version": "4.25.1",
41 | "use_cache": true,
42 | "vocab_size": 50257
43 | }
44 |
--------------------------------------------------------------------------------
/experiment/configs/memoria-gpt2.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_type": "gpt2_with_memoria",
3 | "activation_function": "gelu_new",
4 | "attn_pdrop": 0.1,
5 | "bos_token_id": 50256,
6 | "embd_pdrop": 0.1,
7 | "eos_token_id": 50256,
8 | "initializer_range": 0.02,
9 | "layer_norm_epsilon": 1e-05,
10 | "memoria_stm_capacity": 400,
11 | "memoria_num_memories": 50,
12 | "memoria_initial_lifespan": 9,
13 | "memoria_lifespan_extend_scale": 8.0,
14 | "memoria_ltm_search_depth": 10,
15 | "memoria_reset_period": 500,
16 | "memoria_num_reminded_stm": 50,
17 | "memoria_num_reminded_ltm": 50,
18 | "memoria_device": null,
19 | "n_ctx": 1024,
20 | "n_embd": 768,
21 | "n_head": 12,
22 | "n_inner": null,
23 | "n_layer": 12,
24 | "n_positions": 1024,
25 | "reorder_and_upcast_attn": false,
26 | "resid_pdrop": 0.1,
27 | "scale_attn_by_inverse_layer_idx": false,
28 | "scale_attn_weights": true,
29 | "summary_activation": null,
30 | "summary_first_dropout": 0.1,
31 | "summary_proj_to_labels": true,
32 | "summary_type": "cls_index",
33 | "summary_use_proj": true,
34 | "task_specific_params": {
35 | "text-generation": {
36 | "do_sample": true,
37 | "max_length": 50
38 | }
39 | },
40 | "transformers_version": "4.25.1",
41 | "use_cache": true,
42 | "vocab_size": 50257
43 | }
44 |
--------------------------------------------------------------------------------
/experiment/configs/transfo-xl-enwik8.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_type": "transfo-xl",
3 | "adaptive": true,
4 | "attn_type": 0,
5 | "clamp_len": 1000,
6 | "cutoffs": [
7 | 204
8 | ],
9 | "d_embed": 512,
10 | "d_head": 64,
11 | "d_inner": 2048,
12 | "d_model": 512,
13 | "div_val": 4,
14 | "dropatt": 0.0,
15 | "dropout": 0.1,
16 | "eos_token_id": 0,
17 | "init": "normal",
18 | "init_range": 0.01,
19 | "init_std": 0.02,
20 | "layer_norm_epsilon": 1e-05,
21 | "mem_len": 512,
22 | "n_head": 8,
23 | "n_layer": 12,
24 | "pre_lnorm": false,
25 | "proj_init_std": 0.01,
26 | "same_length": true,
27 | "sample_softmax": -1,
28 | "untie_r": true,
29 | "vocab_size": 204
30 | }
31 |
--------------------------------------------------------------------------------
/experiment/configs/transfo-xl-sort.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_type": "transfo-xl",
3 | "adaptive": true,
4 | "attn_type": 0,
5 | "clamp_len": 1000,
6 | "cutoffs": [
7 | 21
8 | ],
9 | "d_embed": 512,
10 | "d_head": 128,
11 | "d_inner": 2048,
12 | "d_model": 512,
13 | "div_val": 4,
14 | "dropatt": 0.0,
15 | "dropout": 0.1,
16 | "eos_token_id": 0,
17 | "init": "normal",
18 | "init_range": 0.01,
19 | "init_std": 0.02,
20 | "layer_norm_epsilon": 1e-05,
21 | "mem_len": 1024,
22 | "n_head": 4,
23 | "n_layer": 4,
24 | "pre_lnorm": false,
25 | "proj_init_std": 0.01,
26 | "same_length": true,
27 | "sample_softmax": -1,
28 | "untie_r": true,
29 | "vocab_size": 50257
30 | }
31 |
--------------------------------------------------------------------------------
/experiment/configs/transfo-xl-synthetic.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_type": "transfo-xl",
3 | "adaptive": true,
4 | "attn_type": 0,
5 | "clamp_len": 1000,
6 | "cutoffs": [
7 | 10
8 | ],
9 | "d_embed": 512,
10 | "d_head": 64,
11 | "d_inner": 2048,
12 | "d_model": 512,
13 | "div_val": 4,
14 | "dropatt": 0.0,
15 | "dropout": 0.1,
16 | "eos_token_id": 0,
17 | "init": "normal",
18 | "init_range": 0.01,
19 | "init_std": 0.02,
20 | "layer_norm_epsilon": 1e-05,
21 | "mem_len": 100,
22 | "n_head": 4,
23 | "n_layer": 4,
24 | "pre_lnorm": false,
25 | "proj_init_std": 0.01,
26 | "same_length": true,
27 | "sample_softmax": -1,
28 | "untie_r": true,
29 | "vocab_size": 50257
30 | }
31 |
--------------------------------------------------------------------------------
/experiment/configs/transfo-xl.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_type": "transfo-xl",
3 | "adaptive": true,
4 | "attn_type": 0,
5 | "clamp_len": 1000,
6 | "cutoffs": [
7 | 50257
8 | ],
9 | "d_embed": 768,
10 | "d_head": 64,
11 | "d_inner": 3072,
12 | "d_model": 768,
13 | "div_val": 4,
14 | "dropatt": 0.0,
15 | "dropout": 0.1,
16 | "eos_token_id": 0,
17 | "init": "normal",
18 | "init_range": 0.01,
19 | "init_std": 0.02,
20 | "layer_norm_epsilon": 1e-05,
21 | "mem_len": 150,
22 | "n_head": 12,
23 | "n_layer": 12,
24 | "pre_lnorm": false,
25 | "proj_init_std": 0.01,
26 | "same_length": true,
27 | "sample_softmax": -1,
28 | "untie_r": true,
29 | "vocab_size": 50257
30 | }
31 |
--------------------------------------------------------------------------------
/experiment/eval_classification.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from typing import Dict
3 |
4 | import pytorch_lightning as pl
5 | import torch
6 | from longseq_formers.data import CLASSIFICATION_DATASETS, load_hyperpartisan_data
7 | from longseq_formers.dataset import ClassificationDataset
8 | from longseq_formers.task import Classification
9 | from longseq_formers.utils import get_logger
10 | from torch.utils.data import DataLoader
11 | from transformers import AutoTokenizer
12 |
13 | # fmt: off
14 | parser = argparse.ArgumentParser(prog="train_classification", description="Train & Test Long Sequence Classification")
15 |
16 | g = parser.add_argument_group("Train Parameter")
17 | g.add_argument("--model", type=str, required=True, help="lightning checkpoint")
18 | g.add_argument("--tokenizer", type=str, required=True, help="huggingface tokenizer")
19 | g.add_argument("--dataset", type=str, default="hyperpartisan", choices=CLASSIFICATION_DATASETS, help="dataset name")
20 | g.add_argument("--valid-batch-size", type=int, default=1, help="validation batch size")
21 | g.add_argument("--max-length", type=int, default=512, help="max sequence length")
22 | g.add_argument("--memory-length", type=int, default=512, help="max sequence length for bert one inference on infinity former")
23 | g.add_argument("--seed", type=int, default=42, help="random seed")
24 | g.add_argument("--not-truncate", action="store_false", dest="truncation", help="not truncate sequence")
25 | g.add_argument("--segment-size", type=int, help="segment size for infinity former")
26 | # fmt: on
27 |
28 |
29 | def main(args: argparse.Namespace) -> dict[str, float]:
30 | logger = get_logger("evaluate_classification")
31 |
32 | logger.info(f"[+] Set Random Seed to {args.seed}")
33 | pl.seed_everything(args.seed, workers=True)
34 |
35 | logger.info(f'[+] Load Tokenizer: "{args.tokenizer}"')
36 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
37 |
38 | logger.info(f'[+] Use Dataset: "{args.dataset}"')
39 | if args.dataset == "hyperpartisan":
40 | datasets = load_hyperpartisan_data()
41 |
42 | valid_dataset = ClassificationDataset(datasets["dev"])
43 | test_dataset = ClassificationDataset(datasets["test"])
44 |
45 | logger.info(f"[+] # of valid examples: {len(valid_dataset)}")
46 | logger.info(f"[+] # of test examples: {len(test_dataset)}")
47 |
48 | logger.info(f'[+] Load Model: "{args.model}"')
49 | classification = Classification.load_from_checkpoint(
50 | args.model, tokenizer=tokenizer, max_length=args.max_length, truncation=args.truncation
51 | )
52 |
53 | collate_fn = ClassificationDataset.pad_collate_fn if not args.truncation else None
54 | valid_dataloader = DataLoader(valid_dataset, batch_size=args.valid_batch_size, collate_fn=collate_fn)
55 | test_dataloader = DataLoader(test_dataset, batch_size=args.valid_batch_size, collate_fn=collate_fn)
56 |
57 | tester = pl.Trainer(accelerator="gpu" if torch.cuda.device_count() else None, devices=1)
58 |
59 | pl.seed_everything(args.seed, workers=True)
60 | result1 = tester.test(classification, valid_dataloader)[0]
61 |
62 | pl.seed_everything(args.seed, workers=True)
63 | result2 = tester.test(classification, test_dataloader)[0]
64 |
65 | print(result1)
66 | print(result2)
67 |
68 |
69 | if __name__ == "__main__":
70 | main(parser.parse_args())
71 | exit(0)
72 |
--------------------------------------------------------------------------------
/experiment/eval_language_modeling.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from typing import Dict
3 |
4 | import pytorch_lightning as pl
5 | import torch
6 | from longseq_formers.data import (
7 | LANGUAGE_MODELING_DATASETS,
8 | enwik8_tokenize,
9 | load_enwik8_data,
10 | load_pg19_data,
11 | load_wikitext103_data,
12 | )
13 | from longseq_formers.dataset import LanguageModelingDataset, text_to_tokens
14 | from longseq_formers.task import LanguageModeling
15 | from longseq_formers.utils import get_logger
16 | from torch.utils.data import DataLoader
17 | from transformers import AutoTokenizer
18 |
19 | # fmt: off
20 | parser = argparse.ArgumentParser(prog="evaluate", description="Evaluate Language Modeling")
21 |
22 | g = parser.add_argument_group("Eval Parameter")
23 | g.add_argument("--model", type=str, required=True, help="huggingface model")
24 | g.add_argument("--tokenizer", type=str, default="gpt2", help="huggingface tokenizer")
25 | g.add_argument("--dataset", type=str, default="wikitext103", choices=LANGUAGE_MODELING_DATASETS, help="dataset name")
26 | g.add_argument("--valid-batch-size", type=int, default=1, help="validation batch size")
27 | g.add_argument("--max-length", type=int, default=512, help="max sequence length")
28 | g.add_argument("--seed", type=int, default=42, help="random seed")
29 | # fmt: on
30 |
31 |
32 | def main(args: argparse.Namespace) -> dict[str, float]:
33 | logger = get_logger("test_language_modeling")
34 |
35 | logger.info(" ====== Arguements ======")
36 | for k, v in vars(args).items():
37 | logger.info(f"{k:25}: {v}")
38 |
39 | logger.info(f"[+] Set Random Seed to {args.seed}")
40 | pl.seed_everything(args.seed, workers=True)
41 |
42 | gpus = torch.cuda.device_count()
43 | logger.info(f"[+] GPU: {gpus}")
44 |
45 | if args.tokenizer is None:
46 | logger.info(f"[+] Use tokenizer same as model: {args.model}")
47 | args.tokenizer = args.model
48 | if args.dataset == "enwik8":
49 | logger.info(f"[+] Use character tokenizer for enwik8 dataset")
50 | tokenizer = enwik8_tokenize
51 | else:
52 | logger.info(f'[+] Load Tokenizer: "{args.tokenizer}"')
53 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
54 |
55 | logger.info(f'[+] Use Dataset: "{args.dataset}"')
56 | if args.dataset == "wikitext103":
57 | data = load_wikitext103_data()
58 | elif args.dataset == "pg19":
59 | data = load_pg19_data()
60 | elif args.dataset == "enwik8":
61 | data = load_enwik8_data()
62 | else:
63 | raise ValueError(f"dataset `{args.dataset}` is not valid!")
64 |
65 | dev_tokens = text_to_tokens(data["dev"], tokenizer, args.valid_batch_size, args.max_length)
66 | test_tokens = text_to_tokens(data["test"], tokenizer, args.valid_batch_size, args.max_length)
67 |
68 | valid_dataset = LanguageModelingDataset(dev_tokens)
69 | test_dataset = LanguageModelingDataset(test_tokens)
70 |
71 | logger.info(f"[+] # of batched valid examples: {len(valid_dataset)}")
72 | logger.info(f"[+] # of batched test examples: {len(test_dataset)}")
73 |
74 | language_modeling = LanguageModeling.load_from_checkpoint(args.model)
75 |
76 |     # Use batch size 1 because the dataset is already batched
77 | # train_dataloader = DataLoader(train_dataset, batch_size=1, collate_fn=LanguageModelingDataset.collate_fn)
78 | valid_dataloader = DataLoader(valid_dataset, batch_size=1, collate_fn=LanguageModelingDataset.collate_fn)
79 | test_dataloader = DataLoader(test_dataset, batch_size=1, collate_fn=LanguageModelingDataset.collate_fn)
80 | tester = pl.Trainer(accelerator="gpu" if gpus else None, devices=1)
81 |
82 | pl.seed_everything(args.seed, workers=True)
83 | result1 = tester.test(language_modeling, valid_dataloader)[0]
84 |
85 | pl.seed_everything(args.seed, workers=True)
86 | result2 = tester.test(language_modeling, test_dataloader)[0]
87 |
88 | print(result1)
89 | print(result2)
90 |
91 |
92 | if __name__ == "__main__":
93 | main(parser.parse_args())
94 | exit(0)
95 |
--------------------------------------------------------------------------------
/experiment/eval_synthetic.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from typing import Dict
3 |
4 | import pytorch_lightning as pl
5 | import torch
6 | from longseq_formers.dataset.synthetic import SyntheticDataset, parse_syntetic_data
7 | from longseq_formers.task import Synthetic
8 | from longseq_formers.utils import get_logger
9 | from torch.utils.data import DataLoader
10 |
11 | # fmt: off
12 | parser = argparse.ArgumentParser(prog="train_synthetic", description="Train & Test Synthetic Task")
13 |
14 | g = parser.add_argument_group("Train Parameter")
15 | g.add_argument("--model", type=str, required=True, help="model checkpoint")
16 | g.add_argument("--dataset", type=str, required=True, help="dataset name")
17 | g.add_argument("--valid-batch-size", type=int, default=1, help="validation batch size")
18 | g.add_argument("--max-length", type=int, default=150, help="max sequence length")
19 | g.add_argument("--seed", type=int, default=42, help="random seed")
20 | g.add_argument("--shuffle", action="store_true", help="shuffle data order")
21 | # fmt: on
22 |
23 |
24 | def main(args: argparse.Namespace) -> dict[str, float]:
25 | logger = get_logger("eval_synthetic_task")
26 |
27 | logger.info(" ====== Arguements ======")
28 | for k, v in vars(args).items():
29 | logger.info(f"{k:25}: {v}")
30 |
31 | logger.info(f"[+] Set Random Seed to {args.seed}")
32 | pl.seed_everything(args.seed, workers=True)
33 |
34 | logger.info(f'[+] Use Dataset: "{args.dataset}"')
35 | _, vocab_size, _, dev_examples, test_examples = parse_syntetic_data(args.dataset)
36 |
37 | valid_dataset = SyntheticDataset(dev_examples)
38 | test_dataset = SyntheticDataset(test_examples)
39 |
40 | logger.info(f"[+] # of batched valid examples: {len(valid_dataset)}")
41 | logger.info(f"[+] # of batched test examples: {len(test_dataset)}")
42 |
43 | valid_dataloader = DataLoader(valid_dataset, batch_size=args.valid_batch_size)
44 | test_dataloader = DataLoader(test_dataset, batch_size=args.valid_batch_size)
45 |
46 | synthetic_task = Synthetic.load_from_checkpoint(args.model, vocab_size=vocab_size)
47 |
48 | logger.info(f"[+] Start Evaluation")
49 |
50 | tester = pl.Trainer(accelerator="gpu" if torch.cuda.device_count() else None, devices=1)
51 |
52 | pl.seed_everything(args.seed, workers=True)
53 | result1 = tester.test(synthetic_task, valid_dataloader)[0]
54 |
55 | pl.seed_everything(args.seed, workers=True)
56 | result2 = tester.test(synthetic_task, test_dataloader)[0]
57 |
58 | print(result1)
59 | print(result2)
60 |
61 |
62 | if __name__ == "__main__":
63 | main(parser.parse_args())
64 | exit(0)
65 |
--------------------------------------------------------------------------------
/experiment/longseq_formers/__init__.py:
--------------------------------------------------------------------------------
1 | from . import data, dataset, model, task
2 |
3 | __all__ = ["data", "dataset", "model", "task"]
4 |
--------------------------------------------------------------------------------
/experiment/longseq_formers/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .enwik8 import enwik8_tokenize, load_enwik8_data
2 | from .hyperpartisan import load_hyperpartisan_data
3 | from .pg19 import load_pg19_data
4 | from .wikitext103 import load_wikitext103_data
5 |
6 | CLASSIFICATION_DATASETS = ["hyperpartisan"]
7 | LANGUAGE_MODELING_DATASETS = ["wikitext103", "pg19", "enwik8"]
8 | DATASETS = CLASSIFICATION_DATASETS + LANGUAGE_MODELING_DATASETS
9 |
10 | __all__ = [
11 | "enwik8_tokenize",
12 | "load_enwik8_data",
13 | "load_hyperpartisan_data",
14 | "load_pg19_data",
15 | "load_wikitext103_data",
16 | "DATASETS",
17 | "CLASSIFICATION_DATASETS",
18 | "LANGUAGE_MODELING_DATASETS",
19 | ]
20 |
--------------------------------------------------------------------------------
/experiment/longseq_formers/data/enwik8.py:
--------------------------------------------------------------------------------
1 | """
2 | Referred to https://github.com/salesforce/awd-lstm-lm/blob/master/data/enwik8/prep_enwik8.py
3 | """
4 |
5 | from typing import Dict
6 |
7 | from datasets import Dataset, DatasetDict, load_dataset
8 |
9 | # fmt: off
10 | CHAR_INDICES = ['9', '32', '33', '34', '35', '36', '37', '38', '39', '40', '41', '42', '43', '44', '45', '46', '47', '48', '49', '50', '51', '52', '53', '54', '55', '56', '57', '58', '59', '60', '61', '62', '63', '64', '65', '66', '67', '68', '69', '70', '71', '72', '73', '74', '75', '76', '77', '78', '79', '80', '81', '82', '83', '84', '85', '86', '87', '88', '89', '90', '91', '92', '93', '94', '95', '96', '97', '98', '99', '100', '101', '102', '103', '104', '105', '106', '107', '108', '109', '110', '111', '112', '113', '114', '115', '116', '117', '118', '119', '120', '121', '122', '123', '124', '125', '126', '128', '129', '130', '131', '132', '133', '134', '135', '136', '137', '138', '139', '140', '141', '142', '143', '144', '145', '146', '147', '148', '149', '150', '151', '152', '153', '154', '155', '156', '157', '158', '159', '160', '161', '162', '163', '164', '165', '166', '167', '168', '169', '170', '171', '172', '173', '174', '175', '176', '177', '178', '179', '180', '181', '182', '183', '184', '185', '186', '187', '188', '189', '190', '191', '194', '195', '196', '197', '198', '199', '200', '201', '202', '203', '204', '205', '206', '207', '208', '209', '210', '211', '212', '213', '214', '215', '216', '217', '218', '219', '220', '222', '224', '225', '226', '227', '228', '229', '230', '231', '232', '233', '234', '235', '236', '237', '239', '240']
11 | CHAR_TO_INDEX = {c: i for i, c in enumerate(CHAR_INDICES)}
12 | VOCAB_SIZE = len(CHAR_INDICES)
13 | assert VOCAB_SIZE == 204
14 | # fmt: on
15 |
16 |
17 | def enwik8_tokenize(text: str) -> Dict:
18 | input_ids = [CHAR_TO_INDEX[c] for c in text.split()]
19 | return {"input_ids": input_ids, "attention_mask": [1.0] * len(input_ids)}
20 |
21 |
22 | def load_enwik8_data() -> DatasetDict:
23 | dataset = load_dataset("enwik8", "enwik8-raw", revision="a3d620ecedec0d39511d1dfdc3a27a69e648be84")["train"]
24 |
25 | num_test_chars = 5000000
26 |
27 | def _preprocess(data):
28 | whole_text = data["text"]
29 | whole_bytes = whole_text.encode()
30 |
31 | train_data = whole_bytes[: -2 * num_test_chars]
32 | valid_data = whole_bytes[-2 * num_test_chars : -num_test_chars]
33 | test_data = whole_bytes[-num_test_chars:]
34 |
35 | train, dev, test = (
36 | " ".join([str(c) if c != ord("\n") else "\n" for c in part]) for part in (train_data, valid_data, test_data)
37 | )
38 |
39 | return {"train": train, "dev": dev, "test": test}
40 |
41 | dataset = dataset.map(_preprocess, remove_columns=dataset.column_names, load_from_cache_file=True)
42 |
43 | train = dataset["train"][0]
44 | dev = dataset["dev"][0]
45 | test = dataset["test"][0]
46 |
47 | def _gen(source):
48 | yield {"text": source}
49 |
50 | train_dataset = Dataset.from_generator(_gen, gen_kwargs={"source": train})
51 | dev_dataset = Dataset.from_generator(_gen, gen_kwargs={"source": dev})
52 | test_dataset = Dataset.from_generator(_gen, gen_kwargs={"source": test})
53 | dataset = DatasetDict({"train": train_dataset, "dev": dev_dataset, "test": test_dataset})
54 | return dataset
55 |
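A minimal usage sketch (illustration only, not part of the file above) for the character-level tokenizer, assuming the experiment package is on the import path; the byte codes "72 105" are chosen purely for illustration and encode "H" and "i".

from longseq_formers.data.enwik8 import CHAR_TO_INDEX, VOCAB_SIZE, enwik8_tokenize

# load_enwik8_data() stores text as space-separated byte codes, so the tokenizer
# simply splits on whitespace and maps each code to its index in CHAR_INDICES.
encoded = enwik8_tokenize("72 105")            # byte codes for "H" and "i"
assert encoded["input_ids"] == [CHAR_TO_INDEX["72"], CHAR_TO_INDEX["105"]]
assert encoded["attention_mask"] == [1.0, 1.0]
assert VOCAB_SIZE == 204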
--------------------------------------------------------------------------------
/experiment/longseq_formers/data/hyperpartisan.py:
--------------------------------------------------------------------------------
1 | import re
2 | from collections import defaultdict
3 | from typing import Dict, List
4 |
5 | import datasets
6 | from bs4 import BeautifulSoup
7 |
8 | from ..dataset.classification import ClassificationDatum
9 |
10 | # fmt: off
11 | # hp-splits from longformer (https://github.com/allenai/longformer/blob/master/scripts/hp-splits.json)
12 | HYPERPARTISAN_SPLITS = {
13 | "train": [239, 342, 401, 424, 518, 374, 457, 81, 208, 216, 112, 77, 448, 596, 388, 505, 362, 180, 587, 398, 636, 297, 363, 389, 148, 567, 163, 549, 472, 26, 427, 227, 213, 470, 346, 383, 585, 352, 22, 20, 390, 3, 97, 439, 637, 197, 392, 480, 225, 414, 333, 561, 615, 359, 598, 107, 12, 195, 54, 459, 23, 455, 624, 233, 17, 499, 307, 416, 578, 568, 220, 334, 65, 73, 170, 215, 447, 446, 606, 276, 502, 534, 582, 241, 425, 356, 192, 301, 514, 589, 466, 207, 82, 201, 391, 366, 476, 594, 477, 126, 393, 508, 158, 483, 604, 206, 15, 353, 372, 512, 543, 330, 290, 539, 444, 399, 410, 169, 125, 487, 74, 381, 479, 556, 292, 576, 224, 173, 441, 205, 29, 559, 509, 552, 317, 231, 296, 643, 524, 209, 433, 397, 488, 18, 553, 149, 380, 168, 484, 234, 586, 486, 555, 232, 246, 373, 139, 458, 157, 644, 257, 91, 53, 59, 341, 159, 36, 109, 2, 106, 485, 258, 422, 404, 313, 402, 183, 419, 283, 87, 351, 75, 187, 310, 320, 19, 304, 38, 471, 129, 66, 151, 266, 268, 548, 328, 405, 371, 580, 51, 492, 474, 510, 468, 396, 308, 408, 526, 622, 511, 63, 274, 531, 128, 368, 599, 426, 43, 360, 541, 454, 263, 407, 138, 76, 530, 517, 165, 641, 436, 493, 326, 194, 202, 546, 238, 382, 92, 52, 120, 437, 71, 504, 532, 237, 314, 625, 617, 605, 171, 331, 456, 607, 542, 55, 475, 584, 251, 611, 40, 122, 100, 570, 338, 137, 597, 101, 324, 95, 577, 31, 116, 176, 145, 211, 236, 627, 143, 638, 620, 219, 10, 60, 198, 7, 293, 452, 590, 579, 141, 558, 160, 214, 166, 593, 538, 33, 364, 635, 119, 250, 223, 319, 619, 339, 616, 618, 284, 533, 603, 302, 49, 588, 572, 575, 515, 21, 1, 103, 150, 529, 506, 69, 343, 323, 482, 222, 535, 188, 14, 299, 489, 108, 140, 39, 420, 285, 86, 554, 259, 564, 400, 269, 281, 248, 272, 24, 629, 130, 226, 525, 80, 117, 115, 305, 370, 465, 186, 93, 113, 46, 461, 378, 184, 336, 50, 309, 48, 72, 495, 131, 507, 325, 298, 412, 406, 240, 278, 212, 279, 5, 90, 181, 8, 288, 61, 300, 174, 608, 58, 520, 449, 218, 294, 354, 494, 417, 99, 154, 89, 527, 273, 11, 162, 610, 179, 56, 613, 329, 377, 335, 253, 501, 442, 252, 614, 327, 98, 88, 631, 609, 547, 376, 581, 621, 152, 228, 4, 565, 540, 132, 110, 191, 30, 6, 189, 303, 270, 255, 415, 172, 64, 267, 503, 78, 118, 235, 435, 167, 453, 282, 573, 291, 642, 123, 395, 551, 94, 450, 478, 311, 289, 153, 102, 421, 277, 583, 164, 244, 229, 178, 217, 523, 96, 280, 68, 497, 430, 190, 516, 445, 428, 633, 536, 434, 387, 355, 528, 287, 144, 210, 295, 385, 185, 467, 256, 44, 83, 67, 175, 204, 602, 42, 358, 384, 28, 45, 569, 127, 47, 491, 265, 463, 121, 135, 460],
14 | "dev": [182, 438, 545, 286, 142, 27, 394, 261, 411, 0, 79, 550, 640, 254, 560, 386, 62, 440, 104, 473, 155, 432, 124, 133, 136, 519, 322, 318, 245, 249, 612, 349, 623, 591, 429, 306, 592, 375, 203, 544, 312, 114, 41, 344, 571, 134, 462, 347, 464, 566, 350, 199, 562, 357, 361, 521, 574, 315, 243, 601, 260, 409, 337, 177],
15 | "test": [537, 517, 23, 459, 593, 258, 227, 16, 204, 367, 159, 142, 214, 82, 182, 564, 411, 600, 610, 306, 21, 434, 625, 197, 202, 489, 404, 400, 551, 320, 36, 435, 344, 183, 134, 19, 253, 231, 383, 572, 201, 528, 15, 116, 265, 221, 462, 342, 465, 436, 490, 442, 547, 282, 535, 256, 160, 140, 555, 51, 540, 165, 504, 181, 147],
16 | }
17 | # fmt: on
18 |
19 |
20 | def load_hyperpartisan_data() -> dict[str, list[ClassificationDatum]]:
21 | """Load Hyperpartisan dataset
22 |
23 | Returns:
24 | datasets like below
25 | {
26 | "train": [{
27 | "text": "...",
28 | "label": 1
29 | }, {...}, ...],
30 | "dev": ...,
31 | "test": ...
32 | }
33 | """
34 | data = datasets.load_dataset(
35 | "hyperpartisan_news_detection", "byarticle", revision="c315cc4a12a27cde08fd55c0beda41ced8b75923"
36 | )["train"]
37 |
38 | split_datasets = defaultdict(list)
39 | for split, indices in HYPERPARTISAN_SPLITS.items():
40 | for index in indices:
41 | datum = data[index]
42 | normalized_text = BeautifulSoup(datum["text"], "html.parser").get_text()
43 | text = re.sub(r"\s+", " ", normalized_text)
44 | label = int(datum["hyperpartisan"])
45 | split_datasets[split].append({"text": text, "label": label})
46 | return split_datasets
47 |
--------------------------------------------------------------------------------
/experiment/longseq_formers/data/pg19.py:
--------------------------------------------------------------------------------
1 | import datasets
2 |
3 |
4 | def load_pg19_data(train_dataset_percent: int = 7) -> datasets.Dataset:
5 | dataset = datasets.load_dataset(
6 | "pg19",
7 | revision="dd75f494ab94328d0ce92c05390ab91a96920a9d",
8 | split={
9 | "train": f"train[:{train_dataset_percent}%]",
10 | "dev": "validation",
11 | "test": "test",
12 | },
13 | )
14 | return dataset
15 |
--------------------------------------------------------------------------------
/experiment/longseq_formers/data/wikitext103.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | import datasets
4 |
5 |
6 | def load_wikitext103_data() -> datasets.Dataset:
7 | dataset = datasets.load_dataset(
8 | "wikitext",
9 | "wikitext-103-raw-v1",
10 | revision="dfd72879b14bf51e8f831b4b092c4f58f356a70f",
11 | split={"train": f"train", "dev": "validation", "test": "test"},
12 | )
13 |
14 | def _join_segment_text(example):
15 | whole_text = "".join(example["text"])
16 | start_idxs = [m.start() - 1 for m in re.finditer(r"\n\s*= [^=]+ =\s*\n", whole_text)]
17 | all_idxs = [0] + start_idxs + [len(whole_text)]
18 | segments = [whole_text[all_idxs[i] : all_idxs[i + 1]].strip() for i in range(len(all_idxs) - 1)]
19 | return {"text": segments}
20 |
21 | dataset = dataset.map(
22 | _join_segment_text,
23 | load_from_cache_file=True,
24 | batched=True,
25 | batch_size=len(dataset["train"]),
26 | drop_last_batch=False,
27 | remove_columns=["text"],
28 | )
29 | return dataset
30 |
--------------------------------------------------------------------------------
/experiment/longseq_formers/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | from .classification import ClassificationDataset, ClassificationDatum
2 | from .language_modeling import LanguageModelingDataset, text_to_tokens
3 | from .synthetic import SyntheticDataset
4 |
5 | __all__ = [
6 | "ClassificationDataset",
7 | "ClassificationDatum",
8 | "LanguageModelingDataset",
9 | "text_to_tokens",
10 | "SyntheticDataset",
11 | ]
12 |
--------------------------------------------------------------------------------
/experiment/longseq_formers/dataset/classification.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, TypedDict
2 |
3 | import torch
4 | from torch.nn.utils.rnn import pad_sequence
5 | from transformers import AutoTokenizer
6 |
7 |
8 | class ClassificationDatum(TypedDict):
9 | text: str
10 | label: int
11 |
12 |
13 | class ClassificationDataset(torch.utils.data.Dataset):
14 | """ClassificationDataset
15 |
16 | Attributes:
17 | data: data for text classification
18 | tokenizer: huggingface tokenizer
19 | max_length: token max length
20 | """
21 |
22 | def __init__(
23 | self, data: list[ClassificationDatum], tokenizer: AutoTokenizer, max_length: int, truncation: bool = True
24 | ) -> None:
25 | super().__init__()
26 |
27 | self.data = data
28 | self.tokenizer = tokenizer
29 | self.max_length = max_length
30 | self.truncation = truncation
31 |
32 | def __len__(self) -> int:
33 | return len(self.data)
34 |
35 | def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
36 | item = self.data[index]
37 | text = item["text"]
38 | label = item["label"]
39 |
40 | inputs = self.tokenizer(
41 | text,
42 | add_special_tokens=True,
43 | max_length=self.max_length,
44 | truncation=self.truncation,
45 | padding="max_length",
46 | return_token_type_ids=True,
47 | return_tensors="pt",
48 | )
49 | inputs = {k: v.squeeze(dim=0) for k, v in inputs.items()}
50 | inputs["labels"] = torch.tensor(label)
51 |
52 | return inputs
53 |
54 | @staticmethod
55 | def pad_collate_fn(batch: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]:
56 | # [NumTimeSteps, BatchSize, MaxSequenceLength]
57 | padded_batch = {k: [item[k] for item in batch] for k in batch[0].keys()}
58 | for k in padded_batch:
59 | if k == "labels":
60 | padded_batch[k] = torch.stack(padded_batch[k], dim=0)
61 | else:
62 | padded_batch[k] = pad_sequence(padded_batch[k], batch_first=True)
63 |
64 | return padded_batch
65 |
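A small sketch (illustration only, assuming the experiment package is importable) of how pad_collate_fn behaves when truncation is disabled and examples differ in length; the toy tensors below are hypothetical.

import torch
from longseq_formers.dataset.classification import ClassificationDataset

batch = [
    {"input_ids": torch.tensor([1, 2, 3]), "labels": torch.tensor(0)},
    {"input_ids": torch.tensor([4, 5]), "labels": torch.tensor(1)},
]
out = ClassificationDataset.pad_collate_fn(batch)
assert out["input_ids"].shape == (2, 3)      # shorter example is zero-padded
assert out["labels"].tolist() == [0, 1]      # labels are stacked, not padded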
--------------------------------------------------------------------------------
/experiment/longseq_formers/dataset/language_modeling.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Optional
2 |
3 | import datasets
4 | import torch
5 | from transformers import AutoTokenizer
6 |
7 |
8 | def text_to_tokens(
9 | dataset: datasets.Dataset,
10 | tokenizer: AutoTokenizer,
11 | batch_size: int,
12 | max_length: int,
13 | batch_size_per_device: Optional[int] = None,
14 | ) -> datasets.Dataset:
15 | """Tokenize a series of text into tokens and chunk
16 | the processed datasets will be cached automatically.
17 |
18 | Args:
19 | dataset: huggingface dataset containing "text" field
20 | tokenizer: huggingface tokenizer
21 | batch_size: batch size, in same batch, there's sequential dataset.
22 | max_length: max length of each example. the remainder will be dropped.
23 | batch_size_per_device: batch size per device with using DDP.
24 | Return:
25 | huggingface input dictionary.
26 | the values shaped [NumExamples, BatchSize, MaxLength]
27 | """
28 |
29 | def _tokenize(example):
30 | return tokenizer(example["text"])
31 |
32 | token_dataset = dataset.map(_tokenize, remove_columns=dataset.column_names, load_from_cache_file=True)
33 |
34 | def _segment(example):
35 | num_segments = len(example["input_ids"]) // max_length
36 | return {
37 | "data": [
38 | {k: v[i * max_length : (i + 1) * max_length] for k, v in example.items()} for i in range(num_segments)
39 | ],
40 | "is_end": [False] * (num_segments - 1) + [True] if num_segments else [],
41 | }
42 |
43 | segment_dataset = token_dataset.map(_segment, remove_columns=token_dataset.column_names, load_from_cache_file=True)
44 |
45 | def _merge(examples):
46 | data = examples["data"]
47 | is_ends = examples["is_end"]
48 | merged = {k: [example[k] for datum in data for example in datum] for k in data[0][0].keys()}
49 | merged["is_end"] = [v for is_end in is_ends for v in is_end]
50 | return merged
51 |
52 | merge_dataset = segment_dataset.map(
53 | _merge,
54 | remove_columns=segment_dataset.column_names,
55 | load_from_cache_file=True,
56 | batched=True,
57 | batch_size=len(segment_dataset),
58 | )
59 |
60 | num_examples = len(merge_dataset) // batch_size
61 |
62 | def _batching(example):
63 | return {
64 | k: [v[i : num_examples * batch_size : num_examples] for i in range(num_examples)]
65 | for k, v in example.items()
66 | }
67 |
68 | batch_dataset = merge_dataset.map(_batching, load_from_cache_file=True, batched=True, batch_size=len(merge_dataset))
69 |
70 | def _rebatching_for_multi_device(example):
71 | return {
72 | k: [v[0][i : i + batch_size_per_device] for i in range(0, batch_size, batch_size_per_device)]
73 | for k, v in example.items()
74 | }
75 |
76 | if batch_size_per_device is not None and batch_size != batch_size_per_device:
77 | batch_dataset = batch_dataset.map(
78 | _rebatching_for_multi_device, load_from_cache_file=True, batched=True, batch_size=1
79 | )
80 | batch_dataset.set_format(type="torch", columns=batch_dataset.column_names)
81 | return batch_dataset
82 |
83 |
84 | class LanguageModelingDataset(torch.utils.data.Dataset):
85 | def __init__(self, data: datasets.Dataset) -> None:
86 | super().__init__()
87 |
88 | self.data = {k: data[k] for k in data.column_names}
89 |
90 | def __len__(self) -> int:
91 | return len(self.data["input_ids"])
92 |
93 | def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
94 | # [BatchSize, MaxLength]
95 | inputs = {k: v[index] for k, v in self.data.items()}
96 | inputs["labels"] = inputs["input_ids"]
97 | return inputs
98 |
99 | @staticmethod
100 | def collate_fn(batches: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]:
101 | """Select first item becuase batch size is 1"""
102 | assert len(batches) == 1
103 | return batches[0]
104 |
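A plain-Python sketch (illustration only) of the stride-based batching performed by _batching inside text_to_tokens: the segment stream is split into batch_size contiguous chunks, and batch lane b advances through chunk b one segment per step, so every lane stays sequential across steps.

segments = list(range(12))                    # toy segment ids in document order
batch_size = 3
num_examples = len(segments) // batch_size    # 4 steps per lane

steps = [segments[i : num_examples * batch_size : num_examples] for i in range(num_examples)]
assert steps[0] == [0, 4, 8]                          # step 0 holds the head of each chunk
assert [step[1] for step in steps] == [4, 5, 6, 7]    # lane 1 is a contiguous stream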
--------------------------------------------------------------------------------
/experiment/longseq_formers/dataset/synthetic.py:
--------------------------------------------------------------------------------
1 | import json
2 | from typing import Dict, List, Tuple
3 |
4 | import torch
5 |
6 |
7 | def parse_syntetic_data(
8 | path: str,
9 | ) -> Tuple[int, int, list[dict[str, list[int]]], list[dict[str, list[int]]], list[dict[str, list[int]]]]:
10 | with open(path, "r") as f:
11 | data = json.load(f)
12 |
13 | prompt_length = data["prompt_length"]
14 | vocab_size = data["vocab_size"]
15 | train_examples = data["train"]
16 | dev_examples = data["dev"]
17 | test_examples = data["test"]
18 |
19 | return prompt_length, vocab_size, train_examples, dev_examples, test_examples
20 |
21 |
22 | class SyntheticDataset(torch.utils.data.Dataset):
23 | def __init__(self, examples: list[dict[str, list[int]]]) -> None:
24 | super().__init__()
25 |
26 | self.data = examples
27 |
28 | def __len__(self) -> int:
29 | return len(self.data)
30 |
31 | def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
32 | example = self.data[index]
33 | input_ids = example["prompt_ids"] + example["target_ids"][:-1]
34 | labels = [-100] * (len(example["prompt_ids"]) - 1) + example["target_ids"]
35 | return {"input_ids": torch.tensor(input_ids), "labels": torch.tensor(labels)}
36 |
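A worked illustration (hypothetical token ids) of the input/label alignment built in __getitem__: every prompt token except the last is masked with -100, and the first target token is predicted from the last prompt token.

example = {"prompt_ids": [5, 6, 7], "target_ids": [8, 9]}

input_ids = example["prompt_ids"] + example["target_ids"][:-1]              # [5, 6, 7, 8]
labels = [-100] * (len(example["prompt_ids"]) - 1) + example["target_ids"]  # [-100, -100, 8, 9]

assert len(input_ids) == len(labels)
assert labels[2] == 8   # the last prompt token's position predicts the first target token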
--------------------------------------------------------------------------------
/experiment/longseq_formers/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .compressive_former import CompressiveFormerConfig, CompressiveFormerLMHeadModel, CompressiveFormerModel
2 | from .gpt2_with_memoria import GPT2WithMemoriaConfig, GPT2WithMemoriaLMHeadModel, GPT2WithMemoriaModel
3 | from .infinity_gpt2 import InfinityGPT2Config, InfinityGPT2LMHeadModel, InfinityGPT2Model
4 | from .memoria_bert import MemoriaBertConfig, MemoriaBertForSequenceClassification, MemoriaBertModel
5 | from .memoria_roberta import MemoriaRobertaConfig, MemoriaRobertaForSequenceClassification, MemoriaRobertaModel
6 |
7 | __all__ = [
8 | "CompressiveFormerConfig",
9 | "CompressiveFormerLMHeadModel",
10 | "CompressiveFormerModel",
11 | "GPT2WithMemoriaConfig",
12 | "GPT2WithMemoriaLMHeadModel",
13 | "GPT2WithMemoriaModel",
14 | "InfinityGPT2Config",
15 | "InfinityGPT2LMHeadModel",
16 | "InfinityGPT2Model",
17 | "MemoriaBertConfig",
18 | "MemoriaBertForSequenceClassification",
19 | "MemoriaBertModel",
20 | "MemoriaRobertaConfig",
21 | "MemoriaRobertaForSequenceClassification",
22 | "MemoriaRobertaModel",
23 | ]
24 |
--------------------------------------------------------------------------------
/experiment/longseq_formers/model/compressive_former/__init__.py:
--------------------------------------------------------------------------------
1 | from .modeling_compressive_transformer import (
2 | CompressiveFormerConfig,
3 | CompressiveFormerLMHeadModel,
4 | CompressiveFormerModel,
5 | )
6 |
7 | __all__ = [
8 | "CompressiveFormerConfig",
9 | "CompressiveFormerLMHeadModel",
10 | "CompressiveFormerModel",
11 | ]
12 |
--------------------------------------------------------------------------------
/experiment/longseq_formers/model/gpt2_with_memoria/__init__.py:
--------------------------------------------------------------------------------
1 | from .modeling_gpt2_with_memoria import GPT2WithMemoriaConfig, GPT2WithMemoriaLMHeadModel, GPT2WithMemoriaModel
2 |
3 | __all__ = [
4 | "GPT2WithMemoriaConfig",
5 | "GPT2WithMemoriaLMHeadModel",
6 | "GPT2WithMemoriaModel",
7 | ]
8 |
--------------------------------------------------------------------------------
/experiment/longseq_formers/model/infinity_gpt2/__init__.py:
--------------------------------------------------------------------------------
1 | from .configuration_infinity_gpt2 import InfinityGPT2Config
2 | from .modeling_infinity_gpt2 import InfinityGPT2LMHeadModel, InfinityGPT2Model
3 |
4 | __all__ = [
5 | "InfinityGPT2Config",
6 | "InfinityGPT2LMHeadModel",
7 | "InfinityGPT2Model",
8 | ]
9 |
--------------------------------------------------------------------------------
/experiment/longseq_formers/model/infinity_gpt2/basis_functions.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 |
5 |
6 | class BasisFunctions(object):
7 | def __init__(self):
8 | pass
9 |
10 | def __len__(self):
11 | """Number of basis functions."""
12 | pass
13 |
14 | def evaluate(self, t):
15 | pass
16 |
17 | def integrate_t2_times_psi(self, a, b):
18 | """Compute integral int_a^b (t**2) * psi(t)."""
19 | pass
20 |
21 | def integrate_t_times_psi(self, a, b):
22 | """Compute integral int_a^b t * psi(t)."""
23 | pass
24 |
25 | def integrate_psi(self, a, b):
26 | """Compute integral int_a^b psi(t)."""
27 | pass
28 |
29 |
30 | class PowerBasisFunctions(BasisFunctions):
31 | """Function phi(t) = t**degree."""
32 |
33 | def __init__(self, degree):
34 | self.degree = degree.unsqueeze(0)
35 |
36 | def __len__(self):
37 | """Number of basis functions."""
38 | return self.degree.size(1)
39 |
40 | def evaluate(self, t):
41 | return t**self.degree
42 |
43 | def integrate_t2_times_psi(self, a, b):
44 | """Compute integral int_a^b (t**2) * psi(t)."""
45 | return (b ** (self.degree + 3) - a ** (self.degree + 3)) / (self.degree + 3)
46 |
47 | def integrate_t_times_psi(self, a, b):
48 | """Compute integral int_a^b t * psi(t)."""
49 | return (b ** (self.degree + 2) - a ** (self.degree + 2)) / (self.degree + 2)
50 |
51 | def integrate_psi(self, a, b):
52 | """Compute integral int_a^b psi(t)."""
53 | return (b ** (self.degree + 1) - a ** (self.degree + 1)) / (self.degree + 1)
54 |
55 | def __repr__(self):
56 | return f"PowerBasisFunction(degree={self.degree})"
57 |
58 |
59 | class SineBasisFunctions(BasisFunctions):
60 | """Function phi(t) = sin(omega*t)."""
61 |
62 | def __init__(self, omega):
63 | self.omega = omega.unsqueeze(0)
64 |
65 | def __repr__(self):
66 | return f"SineBasisFunction(omega={self.omega})"
67 |
68 | def __len__(self):
69 | """Number of basis functions."""
70 | return self.omega.size(1)
71 |
72 | def evaluate(self, t):
73 | return torch.sin(self.omega * t)
74 |
75 | def integrate_t2_times_psi(self, a, b):
76 | """Compute integral int_a^b (t**2) * psi(t)."""
77 | # The antiderivative of (t**2)*sin(omega*t) is
78 | # ((2-(t**2)*(omega**2))*cos(omega*t) + 2*omega*t*sin(omega*t)) / omega**3. # noqa
79 | return (
80 | (2 - (b**2) * (self.omega**2)) * torch.cos(self.omega * b)
81 | + 2 * self.omega * b * torch.sin(self.omega * b)
82 | - (2 - (a**2) * (self.omega**2)) * torch.cos(self.omega * a)
83 | - 2 * self.omega * a * torch.sin(self.omega * a)
84 | ) / (self.omega**3)
85 |
86 | def integrate_t_times_psi(self, a, b):
87 | """Compute integral int_a^b t * psi(t)."""
88 | # The antiderivative of t*sin(omega*t) is
89 | # (sin(omega*t) - omega*t*cos(omega*t)) / omega**2.
90 | return (
91 | torch.sin(self.omega * b)
92 | - self.omega * b * torch.cos(self.omega * b)
93 | - torch.sin(self.omega * a)
94 | + self.omega * a * torch.cos(self.omega * a)
95 | ) / (self.omega**2)
96 |
97 | def integrate_psi(self, a, b):
98 | """Compute integral int_a^b psi(t)."""
99 | # The antiderivative of sin(omega*t) is -cos(omega*t)/omega.
100 | return (-torch.cos(self.omega * b) + torch.cos(self.omega * a)) / self.omega
101 |
102 |
103 | class CosineBasisFunctions(BasisFunctions):
104 | """Function phi(t) = cos(omega*t)."""
105 |
106 | def __init__(self, omega):
107 | self.omega = omega.unsqueeze(0)
108 |
109 | def __repr__(self):
110 | return f"CosineBasisFunction(omega={self.omega})"
111 |
112 | def __len__(self):
113 | """Number of basis functions."""
114 | return self.omega.size(1)
115 |
116 | def evaluate(self, t):
117 | return torch.cos(self.omega * t)
118 |
119 | def integrate_t2_times_psi(self, a, b):
120 | """Compute integral int_a^b (t**2) * psi(t)."""
121 | # The antiderivative of (t**2)*cos(omega*t) is
122 |         # (((t**2)*(omega**2)-2)*sin(omega*t) + 2*omega*t*cos(omega*t)) / omega**3. # noqa
123 | return (
124 | ((b**2) * (self.omega**2) - 2) * torch.sin(self.omega * b)
125 | + 2 * self.omega * b * torch.cos(self.omega * b)
126 | - ((a**2) * (self.omega**2) - 2) * torch.sin(self.omega * a)
127 | - 2 * self.omega * a * torch.cos(self.omega * a)
128 | ) / (self.omega**3)
129 |
130 | def integrate_t_times_psi(self, a, b):
131 | """Compute integral int_a^b t * psi(t)."""
132 | # The antiderivative of t*cos(omega*t) is
133 | # (cos(omega*t) + omega*t*sin(omega*t)) / omega**2.
134 | return (
135 | torch.cos(self.omega * b)
136 | + self.omega * b * torch.sin(self.omega * b)
137 | - torch.cos(self.omega * a)
138 | - self.omega * a * torch.sin(self.omega * a)
139 | ) / (self.omega**2)
140 |
141 | def integrate_psi(self, a, b):
142 | """Compute integral int_a^b psi(t)."""
143 | # The antiderivative of cos(omega*t) is sin(omega*t)/omega.
144 | return (torch.sin(self.omega * b) - torch.sin(self.omega * a)) / self.omega
145 |
146 |
147 | class GaussianBasisFunctions(BasisFunctions):
148 | """Function phi(t) = Gaussian(t; mu, sigma_sq)
149 |
150 | Attributes:
151 | mu: mu shaped [1, NumBasis]
152 | sigma: sigma shaped [1, NumBasis]
153 | """
154 |
155 | def __init__(self, mu, sigma):
156 | self.mu = mu.unsqueeze(0)
157 | self.sigma = sigma.unsqueeze(0)
158 |
159 | def __repr__(self):
160 | return f"GaussianBasisFunction(mu={self.mu}, sigma={self.sigma})"
161 |
162 | def __len__(self):
163 | """Number of basis functions."""
164 | return self.mu.size(1)
165 |
166 | def _phi(self, t):
167 | return 1.0 / math.sqrt(2 * math.pi) * torch.exp(-0.5 * t**2)
168 |
169 | def _Phi(self, t):
170 | return 0.5 * (1 + torch.erf(t / math.sqrt(2)))
171 |
172 | def _integrate_product_of_gaussians(self, mu, sigma_sq):
173 | sigma = torch.sqrt(self.sigma**2 + sigma_sq)
174 | return self._phi((mu - self.mu) / sigma) / sigma
175 |
176 | def evaluate(self, t):
177 | """Return Gaussian Function value
178 |
179 | Args:
180 | t: [BatchSize, NumBasis] or [BatchSize, 1] or scalar
181 | considered same value for all basis if NumBasis shape is none.
182 | Return:
183 | Gaussian function value shaped [BatchSize, NumBasis]
184 | """
185 | return self._phi((t - self.mu) / self.sigma) / self.sigma
186 |
187 | def integrate_t2_times_psi(self, a, b):
188 | """Compute integral int_a^b (t**2) * psi(t)."""
189 | return (
190 | (self.mu**2 + self.sigma**2)
191 | * (self._Phi((b - self.mu) / self.sigma) - self._Phi((a - self.mu) / self.sigma))
192 | - (self.sigma * (b + self.mu) * self._phi((b - self.mu) / self.sigma))
193 | + (self.sigma * (a + self.mu) * self._phi((a - self.mu) / self.sigma))
194 | )
195 |
196 | def integrate_t_times_psi(self, a, b):
197 | """Compute integral int_a^b t * psi(t)."""
198 | return self.mu * (
199 | self._Phi((b - self.mu) / self.sigma) - self._Phi((a - self.mu) / self.sigma)
200 | ) - self.sigma * (self._phi((b - self.mu) / self.sigma) - self._phi((a - self.mu) / self.sigma))
201 |
202 | def integrate_psi(self, a, b):
203 | """Compute integral int_a^b psi(t)."""
204 | return self._Phi((b - self.mu) / self.sigma) - self._Phi((a - self.mu) / self.sigma)
205 |
206 | def integrate_t2_times_psi_gaussian(self, mu, sigma_sq):
207 | """Compute integral int N(t; mu, sigma_sq) * t**2 * psi(t)."""
208 | S_tilde = self._integrate_product_of_gaussians(mu, sigma_sq)
209 | mu_tilde = (self.mu * sigma_sq + mu * self.sigma**2) / (self.sigma**2 + sigma_sq)
210 | sigma_sq_tilde = ((self.sigma**2) * sigma_sq) / (self.sigma**2 + sigma_sq)
211 | return S_tilde * (mu_tilde**2 + sigma_sq_tilde)
212 |
213 | def integrate_t_times_psi_gaussian(self, mu, sigma_sq):
214 | """Compute integral int N(t; mu, sigma_sq) * t * psi(t)."""
215 | S_tilde = self._integrate_product_of_gaussians(mu, sigma_sq)
216 | mu_tilde = (self.mu * sigma_sq + mu * self.sigma**2) / (self.sigma**2 + sigma_sq)
217 | return S_tilde * mu_tilde
218 |
219 | def integrate_psi_gaussian(self, mu, sigma_sq):
220 | """Compute integral int N(t; mu, sigma_sq) * psi(t)."""
221 | return self._integrate_product_of_gaussians(mu, sigma_sq)
222 |
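A quick numerical sanity check (illustration only, assuming the module is importable) for GaussianBasisFunctions: with a standard Gaussian basis, integrate_psi over [-1, 1] should equal Phi(1) - Phi(-1), roughly 0.6827.

import torch
from longseq_formers.model.infinity_gpt2.basis_functions import GaussianBasisFunctions

basis = GaussianBasisFunctions(mu=torch.tensor([0.0]), sigma=torch.tensor([1.0]))
a = torch.tensor([[-1.0]])
b = torch.tensor([[1.0]])

value = basis.integrate_psi(a, b)             # shape [1, 1]
assert torch.allclose(value, torch.tensor(0.6827), atol=1e-3)
assert len(basis) == 1                        # one basis function in this toy setup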
--------------------------------------------------------------------------------
/experiment/longseq_formers/model/infinity_gpt2/configuration_infinity_gpt2.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """ OpenAI GPT-2 configuration """
17 |
18 | from transformers import AutoConfig
19 | from transformers.configuration_utils import PretrainedConfig
20 | from transformers.utils import logging
21 |
22 | logger = logging.get_logger(__name__)
23 |
24 |
25 | class InfinityGPT2Config(PretrainedConfig):
26 | """
27 | This is the configuration class to store the configuration of a :class:`~transformers.GPT2Model` or a
28 | :class:`~transformers.TFGPT2Model`. It is used to instantiate a GPT-2 model according to the specified arguments,
29 | defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration
30 |     to that of the GPT-2 small architecture.
31 |
32 | Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
33 | outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.
34 |
35 |
36 | Args:
37 | vocab_size (:obj:`int`, `optional`, defaults to 50257):
38 | Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the
39 | :obj:`inputs_ids` passed when calling :class:`~transformers.GPT2Model` or
40 | :class:`~transformers.TFGPT2Model`.
41 | n_positions (:obj:`int`, `optional`, defaults to 1024):
42 | The maximum sequence length that this model might ever be used with. Typically set this to something large
43 | just in case (e.g., 512 or 1024 or 2048).
44 | n_ctx (:obj:`int`, `optional`, defaults to 1024):
45 | Dimensionality of the causal mask (usually same as n_positions).
46 | n_embd (:obj:`int`, `optional`, defaults to 768):
47 | Dimensionality of the embeddings and hidden states.
48 | n_layer (:obj:`int`, `optional`, defaults to 12):
49 | Number of hidden layers in the Transformer encoder.
50 | n_head (:obj:`int`, `optional`, defaults to 12):
51 | Number of attention heads for each attention layer in the Transformer encoder.
52 | n_inner (:obj:`int`, `optional`, defaults to None):
53 | Dimensionality of the inner feed-forward layers. :obj:`None` will set it to 4 times n_embd
54 | activation_function (:obj:`str`, `optional`, defaults to :obj:`"gelu"`):
55 | Activation function, to be selected in the list :obj:`["relu", "silu", "gelu", "tanh", "gelu_new"]`.
56 | resid_pdrop (:obj:`float`, `optional`, defaults to 0.1):
57 | The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
58 | embd_pdrop (:obj:`int`, `optional`, defaults to 0.1):
59 | The dropout ratio for the embeddings.
60 | attn_pdrop (:obj:`float`, `optional`, defaults to 0.1):
61 | The dropout ratio for the attention.
62 | layer_norm_epsilon (:obj:`float`, `optional`, defaults to 1e-5):
63 | The epsilon to use in the layer normalization layers
64 | initializer_range (:obj:`float`, `optional`, defaults to 0.02):
65 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
66 | summary_type (:obj:`string`, `optional`, defaults to :obj:`"cls_index"`):
67 | Argument used when doing sequence summary, used in the models :class:`~transformers.GPT2DoubleHeadsModel`
68 | and :class:`~transformers.TFGPT2DoubleHeadsModel`.
69 |
70 | Has to be one of the following options:
71 |
72 | - :obj:`"last"`: Take the last token hidden state (like XLNet).
73 | - :obj:`"first"`: Take the first token hidden state (like BERT).
74 | - :obj:`"mean"`: Take the mean of all tokens hidden states.
75 | - :obj:`"cls_index"`: Supply a Tensor of classification token position (like GPT/GPT-2).
76 | - :obj:`"attn"`: Not implemented now, use multi-head attention.
77 | summary_use_proj (:obj:`bool`, `optional`, defaults to :obj:`True`):
78 | Argument used when doing sequence summary, used in the models :class:`~transformers.GPT2DoubleHeadsModel`
79 | and :class:`~transformers.TFGPT2DoubleHeadsModel`.
80 |
81 | Whether or not to add a projection after the vector extraction.
82 | summary_activation (:obj:`str`, `optional`):
83 | Argument used when doing sequence summary. Used in for the multiple choice head in
84 | :class:`~transformers.GPT2DoubleHeadsModel`.
85 |
86 | Pass :obj:`"tanh"` for a tanh activation to the output, any other value will result in no activation.
87 | summary_proj_to_labels (:obj:`bool`, `optional`, defaults to :obj:`True`):
88 | Argument used when doing sequence summary, used in the models :class:`~transformers.GPT2DoubleHeadsModel`
89 | and :class:`~transformers.TFGPT2DoubleHeadsModel`.
90 |
91 | Whether the projection outputs should have :obj:`config.num_labels` or :obj:`config.hidden_size` classes.
92 | summary_first_dropout (:obj:`float`, `optional`, defaults to 0.1):
93 | Argument used when doing sequence summary, used in the models :class:`~transformers.GPT2DoubleHeadsModel`
94 | and :class:`~transformers.TFGPT2DoubleHeadsModel`.
95 |
96 | The dropout ratio to be used after the projection and activation.
97 | gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
98 | Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.
99 | use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
100 | Whether or not the model should return the last key/values attentions (not used by all models).
101 |
102 | Example::
103 |
104 | >>> from transformers import GPT2Model, GPT2Config
105 |
106 | >>> # Initializing a GPT2 configuration
107 | >>> configuration = GPT2Config()
108 |
109 | >>> # Initializing a model from the configuration
110 | >>> model = GPT2Model(configuration)
111 |
112 | >>> # Accessing the model configuration
113 | >>> configuration = model.config
114 | """
115 |
116 | model_type = "infinity_gpt2"
117 | keys_to_ignore_at_inference = ["past_key_values"]
118 |
119 | def __init__(
120 | self,
121 | vocab_size=50257,
122 | n_positions=512,
123 | n_ctx=512,
124 | n_embd=1024,
125 | n_layer=24,
126 | n_head=16,
127 | n_inner=None,
128 | activation_function="gelu_new",
129 | resid_pdrop=0.1,
130 | embd_pdrop=0.1,
131 | attn_drop=0.1,
132 | attn_pdrop=0.1,
133 | layer_norm_epsilon=1e-5,
134 | initializer_range=0.02,
135 | summary_type="cls_index",
136 | summary_use_proj=True,
137 | summary_activation=None,
138 | summary_proj_to_labels=True,
139 | summary_first_dropout=0.1,
140 | gradient_checkpointing=False,
141 | use_cache=True,
142 | bos_token_id=50256,
143 | eos_token_id=50256,
144 | memory_length=150,
145 | num_basis=150,
146 | num_samples=150,
147 | tau=0.5,
148 | normalize_function="softmax",
149 | mask_type="cnn",
150 | mask_dropout=0.1,
151 | longterm_attention_dropout=0.1,
152 | use_affines=True,
153 | use_kl_regularizer=True,
154 | use_sticky_memories=True,
155 | mu_0=-1.0,
156 | sigma_0=0.05,
157 | kl_lambda=1e-6,
158 | detach_recursive_outputs=True,
159 | **kwargs
160 | ):
161 | super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
162 |
163 | self.vocab_size = vocab_size
164 | self.n_ctx = n_ctx
165 | self.n_positions = n_positions
166 | self.n_embd = n_embd
167 | self.n_layer = n_layer
168 | self.n_head = n_head
169 | self.n_inner = n_inner
170 | self.activation_function = activation_function
171 | self.resid_pdrop = resid_pdrop
172 | self.embd_pdrop = embd_pdrop
173 | self.attn_drop = attn_drop
174 | self.attn_pdrop = attn_pdrop
175 | self.layer_norm_epsilon = layer_norm_epsilon
176 | self.initializer_range = initializer_range
177 | self.summary_type = summary_type
178 | self.summary_use_proj = summary_use_proj
179 | self.summary_activation = summary_activation
180 | self.summary_first_dropout = summary_first_dropout
181 | self.summary_proj_to_labels = summary_proj_to_labels
182 | self.gradient_checkpointing = gradient_checkpointing
183 | self.use_cache = use_cache
184 |
185 | self.bos_token_id = bos_token_id
186 | self.eos_token_id = eos_token_id
187 |
188 | self.memory_length = memory_length
189 | self.num_basis = num_basis
190 | self.num_samples = num_samples
191 | self.tau = tau
192 | self.normalize_function = normalize_function
193 | self.mask_type = mask_type
194 | self.mask_dropout = mask_dropout
195 | self.longterm_attention_dropout = longterm_attention_dropout
196 | self.use_affines = use_affines
197 | self.use_kl_regularizer = use_kl_regularizer
198 | self.use_sticky_memories = use_sticky_memories
199 | self.mu_0 = mu_0
200 | self.sigma_0 = sigma_0
201 | self.kl_lambda = kl_lambda
202 | self.detach_recursive_outputs = detach_recursive_outputs
203 |
204 | @property
205 | def max_position_embeddings(self):
206 | return self.n_positions
207 |
208 | @property
209 | def hidden_size(self):
210 | return self.n_embd
211 |
212 | @property
213 | def num_attention_heads(self):
214 | return self.n_head
215 |
216 | @property
217 | def num_hidden_layers(self):
218 | return self.n_layer
219 |
220 |
221 | InfinityGPT2Config.register_for_auto_class()
222 | AutoConfig.register("infinity_gpt2", InfinityGPT2Config)
223 |
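A brief sketch (illustration only, arbitrary values, assuming the configuration module is importable) of building the config with the infinity-former-specific fields and the GPT-2 style property aliases.

from longseq_formers.model.infinity_gpt2.configuration_infinity_gpt2 import InfinityGPT2Config

config = InfinityGPT2Config(memory_length=64, num_basis=64, use_sticky_memories=False)
assert config.model_type == "infinity_gpt2"
assert config.hidden_size == config.n_embd    # hidden_size is exposed as a property alias
assert config.memory_length == 64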
--------------------------------------------------------------------------------
/experiment/longseq_formers/model/infinity_gpt2/continuous_softmax.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch.autograd.function import FunctionCtx
6 |
7 | from .basis_functions import GaussianBasisFunctions
8 |
9 |
10 | class ContinuousSoftmaxFunction(torch.autograd.Function):
11 | @classmethod
12 | def _expectation_phi_psi(cls, ctx: FunctionCtx, mu: torch.FloatTensor, sigma_sq: torch.FloatTensor):
13 | """Compute expectation of phi(t) * psi(t).T under N(mu, sigma_sq)."""
14 | num_basis = [len(basis_functions) for basis_functions in ctx.psi]
15 | total_basis = sum(num_basis)
16 | V = torch.zeros((mu.shape[0], 2, total_basis), dtype=ctx.dtype, device=ctx.device)
17 | offsets = torch.cumsum(torch.tensor(num_basis, dtype=torch.int, device=ctx.device), dim=0)
18 | start = 0
19 | for j, basis_functions in enumerate(ctx.psi):
20 | V[:, 0, start : offsets[j]] = basis_functions.integrate_t_times_psi_gaussian(mu, sigma_sq)
21 | V[:, 1, start : offsets[j]] = basis_functions.integrate_t2_times_psi_gaussian(mu, sigma_sq)
22 | start = offsets[j]
23 | return V
24 |
25 | @classmethod
26 | def _expectation_psi(
27 | cls, ctx: FunctionCtx, mu: torch.FloatTensor, sigma_sq: torch.FloatTensor
28 | ) -> torch.FloatTensor:
29 | """Compute expectation of psi under N(mu, sigma_sq).
30 |
31 | Args:
32 | mu: mu of distribution shaped [BatchSize, 1]
33 | sigma_sq: sigma_sq of distribution shaped [BatchSize, 1]
34 | Return:
35 |             integrated result shaped [BatchSize, TotalBasis]
36 | """
37 | psi: list[GaussianBasisFunctions] = ctx.psi
38 | num_basis = [len(basis_functions) for basis_functions in psi]
39 | total_basis = sum(num_basis)
40 | r = torch.zeros(mu.shape[0], total_basis, dtype=ctx.dtype, device=ctx.device)
41 | offsets = torch.cumsum(torch.tensor(num_basis, dtype=torch.int, device=ctx.device), dim=0)
42 | start = 0
43 | for j, basis_functions in enumerate(psi):
44 | r[:, start : offsets[j]] = basis_functions.integrate_psi_gaussian(mu, sigma_sq)
45 | start = offsets[j]
46 | return r
47 |
48 | @classmethod
49 | def _expectation_phi(cls, ctx: FunctionCtx, mu: torch.FloatTensor, sigma_sq: torch.FloatTensor):
50 | """Compute expectation of phi under N(mu, sigma_sq)."""
51 | v = torch.zeros(mu.shape[0], 2, dtype=ctx.dtype, device=ctx.device)
52 | v[:, 0] = mu.squeeze(1)
53 | v[:, 1] = (mu**2 + sigma_sq).squeeze(1)
54 | return v
55 |
56 | @classmethod
57 | def forward(
58 | cls, ctx: FunctionCtx, theta: torch.FloatTensor, psi: list[GaussianBasisFunctions]
59 | ) -> torch.FloatTensor:
60 | """
61 | We assume a Gaussian.
62 | We have:
63 | theta = [mu/sigma**2, -1/(2*sigma**2)],
64 | phi(t) = [t, t**2],
65 | p(t) = Gaussian(t; mu, sigma**2).
66 |
67 | Args:
68 | theta: shaped [BatchSize, 2]
69 | psi: list of basis functions
70 | """
71 | ctx.dtype = theta.dtype
72 | ctx.device = theta.device
73 | ctx.psi = psi
74 | # sigma_sq, mu: [BatchSize, 1]
75 | sigma_sq = (-0.5 / theta[:, 1]).unsqueeze(1)
76 | mu = theta[:, 0].unsqueeze(1) * sigma_sq
77 |
78 | r = cls._expectation_psi(ctx, mu, sigma_sq)
79 | ctx.save_for_backward(mu, sigma_sq, r)
80 | return r
81 |
82 | @classmethod
83 | def backward(cls, ctx: FunctionCtx, grad_output):
84 | mu, sigma_sq, r = ctx.saved_tensors
85 | J = cls._expectation_phi_psi(ctx, mu, sigma_sq)
86 | e_phi = cls._expectation_phi(ctx, mu, sigma_sq)
87 | e_psi = cls._expectation_psi(ctx, mu, sigma_sq)
88 | J -= torch.bmm(e_phi.unsqueeze(2), e_psi.unsqueeze(1))
89 | grad_input = torch.matmul(J, grad_output.unsqueeze(2)).squeeze(2)
90 | return grad_input, None
91 |
92 |
93 | class ContinuousSoftmax(nn.Module):
94 | def __init__(self, psi: Optional[list[GaussianBasisFunctions]] = None):
95 | super(ContinuousSoftmax, self).__init__()
96 | self.psi = psi
97 |
98 | def forward(self, theta):
99 | return ContinuousSoftmaxFunction.apply(theta, self.psi)
100 |
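A small sketch (illustration only, assuming the modules above are importable) of the natural-parameter encoding used by ContinuousSoftmax: theta = [mu / sigma**2, -1 / (2 * sigma**2)] describes N(mu, sigma**2), and the layer returns the expectation of each basis function under that density.

import math

import torch
from longseq_formers.model.infinity_gpt2.basis_functions import GaussianBasisFunctions
from longseq_formers.model.infinity_gpt2.continuous_softmax import ContinuousSoftmax

mu, sigma = 0.0, 1.0
theta = torch.tensor([[mu / sigma**2, -1.0 / (2 * sigma**2)]])
psi = [GaussianBasisFunctions(torch.tensor([0.0]), torch.tensor([1.0]))]

layer = ContinuousSoftmax(psi=psi)
out = layer(theta)                             # shape [1, 1]
# E_{N(0,1)}[N(t; 0, 1)] = 1 / (2 * sqrt(pi)) ~= 0.2821
assert torch.allclose(out, torch.tensor(1.0 / (2.0 * math.sqrt(math.pi))), atol=1e-4)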
--------------------------------------------------------------------------------
/experiment/longseq_formers/model/infinity_gpt2/continuous_sparsemax.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class ContinuousSparsemaxFunction(torch.autograd.Function):
6 | @classmethod
7 | def _integrate_phi_times_psi(cls, ctx, a, b):
8 | """Compute integral int_a^b phi(t) * psi(t).T."""
9 | num_basis = [len(basis_functions) for basis_functions in ctx.psi]
10 | total_basis = sum(num_basis)
11 | V = torch.zeros((a.shape[0], 2, total_basis), dtype=ctx.dtype, device=ctx.device)
12 | offsets = torch.cumsum(torch.tensor(num_basis, dtype=torch.int, device=ctx.device), dim=0)
13 | start = 0
14 | for j, basis_functions in enumerate(ctx.psi):
15 | V[:, 0, start : offsets[j]] = basis_functions.integrate_t_times_psi(a, b)
16 | V[:, 1, start : offsets[j]] = basis_functions.integrate_t2_times_psi(a, b)
17 | start = offsets[j]
18 | return V
19 |
20 | @classmethod
21 | def _integrate_psi(cls, ctx, a, b):
22 | """Compute integral int_a^b psi(t)."""
23 | num_basis = [len(basis_functions) for basis_functions in ctx.psi]
24 | total_basis = sum(num_basis)
25 | v = torch.zeros(a.shape[0], total_basis, dtype=ctx.dtype, device=ctx.device)
26 | offsets = torch.cumsum(torch.tensor(num_basis, dtype=torch.int, device=ctx.device), dim=0)
27 | start = 0
28 | for j, basis_functions in enumerate(ctx.psi):
29 | v[:, start : offsets[j]] = basis_functions.integrate_psi(a, b)
30 | start = offsets[j]
31 | return v
32 |
33 | @classmethod
34 | def _integrate_phi(cls, ctx, a, b):
35 | """Compute integral int_a^b phi(t)."""
36 | v = torch.zeros(a.shape[0], 2, dtype=ctx.dtype, device=ctx.device)
37 | v[:, 0] = ((b**2 - a**2) / 2).squeeze(1)
38 | v[:, 1] = ((b**3 - a**3) / 3).squeeze(1)
39 | return v
40 |
41 | @classmethod
42 | def forward(cls, ctx, theta, psi):
43 | # We assume a truncated parabola.
44 | # We have:
45 | # theta = [mu/sigma**2, -1/(2*sigma**2)],
46 | # phi(t) = [t, t**2],
47 | # p(t) = [theta.dot(phi(t)) - A]_+,
48 | # supported on [mu - a, mu + a].
49 | ctx.dtype = theta.dtype
50 | ctx.device = theta.device
51 | ctx.psi = psi
52 | sigma = torch.sqrt(-0.5 / theta[:, 1])
53 | mu = theta[:, 0] * sigma**2
54 | A = -0.5 * (3.0 / (2 * sigma)) ** (2.0 / 3)
55 | a = torch.sqrt(-2 * A) * sigma
56 | A += mu**2 / (2 * sigma**2)
57 | left = (mu - a).unsqueeze(1)
58 | right = (mu + a).unsqueeze(1)
59 | V = cls._integrate_phi_times_psi(ctx, left, right)
60 | u = cls._integrate_psi(ctx, left, right)
61 | r = torch.matmul(theta.unsqueeze(1), V).squeeze(1) - A.unsqueeze(1) * u
62 | ctx.save_for_backward(mu, a, V, u)
63 | return r
64 |
65 | @classmethod
66 | def backward(cls, ctx, grad_output):
67 | mu, a, V, u = ctx.saved_tensors
68 | # J.T = int_{-a}^{+a} phi(t+mu)*psi(t+mu).T
69 | # - (int_{-a}^{+a} phi(t+mu)) * (int_{-a}^{+a} psi(t+mu).T) / (2*a)
70 | left = (mu - a).unsqueeze(1)
71 | right = (mu + a).unsqueeze(1)
72 | i_phi = cls._integrate_phi(ctx, left, right)
73 | ger = torch.bmm(i_phi.unsqueeze(2), u.unsqueeze(1))
74 | # ger = torch.einsum('bi,bj->bij', (i_phi, u))
75 | J = V - ger / (2 * a.unsqueeze(1).unsqueeze(2))
76 | grad_input = torch.matmul(J, grad_output.unsqueeze(2)).squeeze(2)
77 | return grad_input, None
78 |
79 |
80 | class ContinuousSparsemax(nn.Module):
81 | def __init__(self, psi=None):
82 | super(ContinuousSparsemax, self).__init__()
83 | self.psi = psi
84 |
85 | def forward(self, theta):
86 | return ContinuousSparsemaxFunction.apply(theta, self.psi)
87 |
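ContinuousSparsemax consumes the same theta encoding, but the resulting density is a truncated parabola supported on [mu - a, mu + a], so mass outside that window is exactly zero. A minimal call sketch follows (illustration only, same import assumptions as the softmax sketch above).

import torch
from longseq_formers.model.infinity_gpt2.basis_functions import GaussianBasisFunctions
from longseq_formers.model.infinity_gpt2.continuous_sparsemax import ContinuousSparsemax

theta = torch.tensor([[0.0, -0.5]])            # mu = 0, sigma = 1
psi = [GaussianBasisFunctions(torch.tensor([0.0]), torch.tensor([1.0]))]

layer = ContinuousSparsemax(psi=psi)
out = layer(theta)                             # one value per basis function
assert out.shape == (1, 1)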
--------------------------------------------------------------------------------
/experiment/longseq_formers/model/memoria_bert/__init__.py:
--------------------------------------------------------------------------------
1 | from .configuration_memoria_bert import MemoriaBertConfig
2 | from .modeling_memoria_bert import MemoriaBertForSequenceClassification, MemoriaBertModel
3 |
4 | __all__ = [
5 | "MemoriaBertConfig",
6 | "MemoriaBertForSequenceClassification",
7 | "MemoriaBertModel",
8 | ]
9 |
--------------------------------------------------------------------------------
/experiment/longseq_formers/model/memoria_bert/configuration_memoria_bert.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """ BERT model configuration"""
17 |
18 | from typing import Optional
19 |
20 | from transformers import AutoConfig
21 | from transformers.configuration_utils import PretrainedConfig
22 | from transformers.utils import logging
23 |
24 | logger = logging.get_logger(__name__)
25 |
26 |
27 | class MemoriaBertConfig(PretrainedConfig):
28 | r"""
29 | Args:
30 | vocab_size (`int`, *optional*, defaults to 30522):
31 | Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the
32 | `inputs_ids` passed when calling [`BertModel`] or [`TFBertModel`].
33 | hidden_size (`int`, *optional*, defaults to 768):
34 | Dimensionality of the encoder layers and the pooler layer.
35 | num_hidden_layers (`int`, *optional*, defaults to 12):
36 | Number of hidden layers in the Transformer encoder.
37 | num_attention_heads (`int`, *optional*, defaults to 12):
38 | Number of attention heads for each attention layer in the Transformer encoder.
39 | intermediate_size (`int`, *optional*, defaults to 3072):
40 | Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
41 | hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
42 | The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
43 | `"relu"`, `"silu"` and `"gelu_new"` are supported.
44 | hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
45 | The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
46 | attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
47 | The dropout ratio for the attention probabilities.
48 | max_position_embeddings (`int`, *optional*, defaults to 512):
49 | The maximum sequence length that this model might ever be used with. Typically set this to something large
50 | just in case (e.g., 512 or 1024 or 2048).
51 | type_vocab_size (`int`, *optional*, defaults to 2):
52 | The vocabulary size of the `token_type_ids` passed when calling [`BertModel`] or [`TFBertModel`].
53 | initializer_range (`float`, *optional*, defaults to 0.02):
54 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
55 | layer_norm_eps (`float`, *optional*, defaults to 1e-12):
56 | The epsilon used by the layer normalization layers.
57 | position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
58 | Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
59 | positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
60 | [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
61 | For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
62 | with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
63 | use_cache (`bool`, *optional*, defaults to `True`):
64 | Whether or not the model should return the last key/values attentions (not used by all models). Only
65 | relevant if `config.is_decoder=True`.
66 | classifier_dropout (`float`, *optional*):
67 | The dropout ratio for the classification head.
68 |     """
69 | model_type = "memoria_bert"
70 |
71 | def __init__(
72 | self,
73 | vocab_size=30522,
74 | hidden_size=768,
75 | num_hidden_layers=12,
76 | num_attention_heads=12,
77 | intermediate_size=3072,
78 | hidden_act="gelu",
79 | hidden_dropout_prob=0.1,
80 | attention_probs_dropout_prob=0.1,
81 | max_position_embeddings=512,
82 | type_vocab_size=2,
83 | initializer_range=0.02,
84 | layer_norm_eps=1e-12,
85 | pad_token_id=0,
86 | position_embedding_type="absolute",
87 | use_cache=True,
88 | classifier_dropout=None,
89 | memory_layer_index: int = 9,
90 |         memoria_num_memories: int = 64,
91 |         memoria_lifespan_extend_scale: float = 8.0,
92 |         memoria_num_reminded_stm: int = 64,
93 |         memoria_num_reminded_ltm: int = 64,
94 | memoria_stm_capacity: int = 128,
95 | memoria_ltm_search_depth: int = 10,
96 | memoria_initial_lifespan: int = 12,
97 | memoria_reset_period: int = 500,
98 | memoria_device: Optional[str] = None,
99 | **kwargs
100 | ):
101 | super().__init__(pad_token_id=pad_token_id, **kwargs)
102 |
103 | self.vocab_size = vocab_size
104 | self.hidden_size = hidden_size
105 | self.num_hidden_layers = num_hidden_layers
106 | self.num_attention_heads = num_attention_heads
107 | self.hidden_act = hidden_act
108 | self.intermediate_size = intermediate_size
109 | self.hidden_dropout_prob = hidden_dropout_prob
110 | self.attention_probs_dropout_prob = attention_probs_dropout_prob
111 | self.max_position_embeddings = max_position_embeddings
112 | self.type_vocab_size = type_vocab_size
113 | self.initializer_range = initializer_range
114 | self.layer_norm_eps = layer_norm_eps
115 | self.position_embedding_type = position_embedding_type
116 | self.use_cache = use_cache
117 | self.classifier_dropout = classifier_dropout
118 |
119 | self.memory_layer_index: int = memory_layer_index
120 | self.memoria_num_memories: int = memoria_num_memories
121 | self.memoria_lifespan_extend_scale: float = memoria_lifespan_extend_scale
122 | self.memoria_num_reminded_stm: int = memoria_num_reminded_stm
123 | self.memoria_num_reminded_ltm: int = memoria_num_reminded_ltm
124 | self.memoria_stm_capacity: int = memoria_stm_capacity
125 | self.memoria_ltm_search_depth: int = memoria_ltm_search_depth
126 | self.memoria_initial_lifespan: int = memoria_initial_lifespan
127 | self.memoria_reset_period: int = memoria_reset_period
128 | self.memoria_device: Optional[str] = memoria_device
129 |
130 |
131 | MemoriaBertConfig.register_for_auto_class()
132 | AutoConfig.register("memoria_bert", MemoriaBertConfig)
133 |
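As a quick orientation for the memoria-specific fields above, the sketch below constructs a `MemoriaBertConfig` with a few of them overridden and reads a value back. It is a minimal illustration only: the numbers are arbitrary rather than tuned settings, and it assumes the `experiment` package layout shown in this repository is importable.

```python
# Minimal sketch: building a MemoriaBertConfig with custom Memoria hyperparameters.
# The chosen values are illustrative, not recommended defaults.
from longseq_formers.model.memoria_bert import MemoriaBertConfig

config = MemoriaBertConfig(
    memory_layer_index=9,         # encoder layer index used for memory (per the field name)
    memoria_num_memories=32,      # memories created per step
    memoria_stm_capacity=64,      # short-term memory capacity
    memoria_ltm_search_depth=10,  # long-term memory search depth
    memoria_initial_lifespan=12,  # initial engram lifespan
)

assert config.model_type == "memoria_bert"
print(config.memoria_stm_capacity)  # 64
```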
--------------------------------------------------------------------------------
/experiment/longseq_formers/model/memoria_roberta/__init__.py:
--------------------------------------------------------------------------------
1 | from .configuration_memoria_roberta import MemoriaRobertaConfig
2 | from .modeling_memoria_roberta import MemoriaRobertaForSequenceClassification, MemoriaRobertaModel
3 |
4 | __all__ = [
5 | "MemoriaRobertaConfig",
6 | "MemoriaRobertaForSequenceClassification",
7 | "MemoriaRobertaModel",
8 | ]
9 |
--------------------------------------------------------------------------------
/experiment/longseq_formers/model/memoria_roberta/configuration_memoria_roberta.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """ RoBERTa configuration"""
17 |
18 | from typing import Optional
19 |
20 | from transformers import AutoConfig
21 | from transformers.configuration_utils import PretrainedConfig
22 | from transformers.utils import logging
23 |
24 | logger = logging.get_logger(__name__)
25 |
26 |
27 | class MemoriaRobertaConfig(PretrainedConfig):
28 | r"""
29 | This is the configuration class to store the configuration of a [`RobertaModel`] or a [`TFRobertaModel`]. It is
30 | used to instantiate a RoBERTa model according to the specified arguments, defining the model architecture.
31 | Instantiating a configuration with the defaults will yield a similar configuration to that of the RoBERTa
32 | [roberta-base](https://huggingface.co/roberta-base) architecture.
33 |
34 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
35 | documentation from [`PretrainedConfig`] for more information.
36 |
37 |
38 | Args:
39 | vocab_size (`int`, *optional*, defaults to 30522):
40 | Vocabulary size of the RoBERTa model. Defines the number of different tokens that can be represented by the
41 | `inputs_ids` passed when calling [`RobertaModel`] or [`TFRobertaModel`].
42 | hidden_size (`int`, *optional*, defaults to 768):
43 | Dimensionality of the encoder layers and the pooler layer.
44 | num_hidden_layers (`int`, *optional*, defaults to 12):
45 | Number of hidden layers in the Transformer encoder.
46 | num_attention_heads (`int`, *optional*, defaults to 12):
47 | Number of attention heads for each attention layer in the Transformer encoder.
48 | intermediate_size (`int`, *optional*, defaults to 3072):
49 | Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
50 | hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
51 | The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
52 | `"relu"`, `"silu"` and `"gelu_new"` are supported.
53 | hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
54 | The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
55 | attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
56 | The dropout ratio for the attention probabilities.
57 | max_position_embeddings (`int`, *optional*, defaults to 512):
58 | The maximum sequence length that this model might ever be used with. Typically set this to something large
59 | just in case (e.g., 512 or 1024 or 2048).
60 | type_vocab_size (`int`, *optional*, defaults to 2):
61 | The vocabulary size of the `token_type_ids` passed when calling [`RobertaModel`] or [`TFRobertaModel`].
62 | initializer_range (`float`, *optional*, defaults to 0.02):
63 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
64 | layer_norm_eps (`float`, *optional*, defaults to 1e-12):
65 | The epsilon used by the layer normalization layers.
66 | position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
67 | Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
68 | positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
69 | [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
70 | For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
71 | with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
72 | use_cache (`bool`, *optional*, defaults to `True`):
73 | Whether or not the model should return the last key/values attentions (not used by all models). Only
74 | relevant if `config.is_decoder=True`.
75 | classifier_dropout (`float`, *optional*):
76 | The dropout ratio for the classification head.
77 |
78 | Examples:
79 |
80 | ```python
81 | >>> from transformers import RobertaConfig, RobertaModel
82 |
83 | >>> # Initializing a RoBERTa configuration
84 | >>> configuration = RobertaConfig()
85 |
86 | >>> # Initializing a model (with random weights) from the configuration
87 | >>> model = RobertaModel(configuration)
88 |
89 | >>> # Accessing the model configuration
90 | >>> configuration = model.config
91 | ```"""
92 | model_type = "memoria_roberta"
93 |
94 | def __init__(
95 | self,
96 | vocab_size=30522,
97 | hidden_size=768,
98 | num_hidden_layers=12,
99 | num_attention_heads=12,
100 | intermediate_size=3072,
101 | hidden_act="gelu",
102 | hidden_dropout_prob=0.1,
103 | attention_probs_dropout_prob=0.1,
104 | max_position_embeddings=512,
105 | type_vocab_size=2,
106 | initializer_range=0.02,
107 | layer_norm_eps=1e-12,
108 | pad_token_id=1,
109 | bos_token_id=0,
110 | eos_token_id=2,
111 | position_embedding_type="absolute",
112 | use_cache=True,
113 | classifier_dropout=None,
114 | memory_layer_index: int = 9,
115 |         memoria_num_memories: int = 64,
116 |         memoria_lifespan_extend_scale: float = 8.0,
117 |         memoria_num_reminded_stm: int = 64,
118 |         memoria_num_reminded_ltm: int = 64,
119 | memoria_stm_capacity: int = 128,
120 | memoria_ltm_search_depth: int = 10,
121 | memoria_initial_lifespan: int = 12,
122 | memoria_reset_period: int = 500,
123 | memoria_device: Optional[str] = None,
124 | **kwargs
125 | ):
126 | super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
127 |
128 | self.vocab_size = vocab_size
129 | self.hidden_size = hidden_size
130 | self.num_hidden_layers = num_hidden_layers
131 | self.num_attention_heads = num_attention_heads
132 | self.hidden_act = hidden_act
133 | self.intermediate_size = intermediate_size
134 | self.hidden_dropout_prob = hidden_dropout_prob
135 | self.attention_probs_dropout_prob = attention_probs_dropout_prob
136 | self.max_position_embeddings = max_position_embeddings
137 | self.type_vocab_size = type_vocab_size
138 | self.initializer_range = initializer_range
139 | self.layer_norm_eps = layer_norm_eps
140 | self.position_embedding_type = position_embedding_type
141 | self.use_cache = use_cache
142 | self.classifier_dropout = classifier_dropout
143 |
144 | self.memory_layer_index: int = memory_layer_index
145 | self.memoria_num_memories: int = memoria_num_memories
146 | self.memoria_lifespan_extend_scale: float = memoria_lifespan_extend_scale
147 | self.memoria_num_reminded_stm: int = memoria_num_reminded_stm
148 | self.memoria_num_reminded_ltm: int = memoria_num_reminded_ltm
149 | self.memoria_stm_capacity: int = memoria_stm_capacity
150 | self.memoria_ltm_search_depth: int = memoria_ltm_search_depth
151 | self.memoria_initial_lifespan: int = memoria_initial_lifespan
152 | self.memoria_reset_period: int = memoria_reset_period
153 | self.memoria_device: Optional[str] = memoria_device
154 |
155 |
156 | MemoriaRobertaConfig.register_for_auto_class()
157 | AutoConfig.register("memoria_roberta", MemoriaRobertaConfig)
158 |
--------------------------------------------------------------------------------
/experiment/longseq_formers/task/__init__.py:
--------------------------------------------------------------------------------
1 | from .classification import Classification
2 | from .language_modeling import LanguageModeling
3 | from .synthetic import Synthetic
4 |
5 | __all__ = ["Classification", "LanguageModeling", "Synthetic"]
6 |
--------------------------------------------------------------------------------
/experiment/longseq_formers/task/classification.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, Literal, Optional
2 |
3 | import pytorch_lightning as pl
4 | import torch
5 | import torch.nn.functional as F
6 | from torchmetrics.classification import Accuracy, MulticlassF1Score
7 | from torchmetrics.collections import MetricCollection
8 | from transformers import AutoConfig, AutoModelForSequenceClassification, get_linear_schedule_with_warmup
9 |
10 |
11 | class Classification(pl.LightningModule):
12 | """Classification
13 |
14 | Attributes:
15 | model: model for classification
16 | num_classes: the number of classes
17 | total_steps: total training steps for lr scheduling
18 | learning_rate: Max LR
19 | warmup_rate: warmup step rate
20 | """
21 |
22 | def __init__(
23 | self,
24 | model: AutoModelForSequenceClassification,
25 | num_classes: int,
26 | total_steps: int,
27 | learning_rate: float,
28 | warmup_rate: float,
29 | segment_size: Optional[int] = None,
30 | aggregate: Literal["mean", "last"] = "mean",
31 | eval_aggregate: Literal["mean", "last"] = "last",
32 | ):
33 | super().__init__()
34 |
35 | self.model = model
36 | self.num_classes = num_classes
37 | self.total_steps = total_steps
38 | self.learning_rate = learning_rate
39 | self.warmup_rate = warmup_rate
40 | self.segment_size = segment_size
41 | self.aggregate = aggregate
42 | self.eval_aggregate = eval_aggregate
43 | self.automatic_optimization = False
44 |
45 | metric_collection = MetricCollection(
46 | {
47 | "acc": Accuracy(task="multiclass", top_k=1, num_classes=self.num_classes),
48 |                 "f1": MulticlassF1Score(num_classes=self.num_classes, average="macro"),
49 | }
50 | )
51 | self.train_metrics = metric_collection.clone(prefix="train/")
52 | self.val_metrics = metric_collection.clone(prefix="val/")
53 | self.test_metrics = metric_collection.clone(prefix="test/")
54 | self.metrics = {"train/": self.train_metrics, "val/": self.val_metrics, "test/": self.test_metrics}
55 |
56 | self.save_hyperparameters(
57 | {
58 | "model": None,
59 | "model_config": model.config.to_dict() if model is not None else None,
60 | "num_classes": num_classes,
61 | "total_steps": total_steps,
62 | "learning_rate": learning_rate,
63 | "warmup_rate": warmup_rate,
64 | "segment_size": segment_size,
65 | "aggregate": aggregate,
66 | "eval_aggregate": eval_aggregate,
67 | }
68 | )
69 |
70 |     def _single_step(self, batch: dict[str, torch.Tensor], batch_idx: int, prefix="") -> tuple[dict[str, torch.Tensor], torch.Tensor, torch.Tensor]:
71 | """Common step function
72 |
73 | Args:
74 | batch: training batch input/label
75 | Returns:
76 |             tuple of (metrics dictionary, detached logits, detached labels) for this step
77 | """
78 | labels = batch.pop("labels")
79 |
80 | outputs = self.model(**batch)
81 | logits = outputs.logits
82 |
83 | ce_loss = F.cross_entropy(logits, labels, reduction="none")
84 | loss = ce_loss
85 | other_metrics = {"ce_loss": ce_loss.mean()}
86 | if self.model.config.model_type == "memoria_bert":
87 | ltm_mask = self.model.bert.encoder.memoria.engrams.longterm_memory_mask
88 | other_metrics["num_ltms_per_batch"] = (
89 | ltm_mask.sum(dim=1).float().mean(dim=0)
90 | if ltm_mask.numel() > 0
91 | else torch.tensor(0.0, device=loss.device)
92 | )
93 | if self.model.config.model_type == "memoria_roberta":
94 | ltm_mask = self.model.roberta.encoder.memoria.engrams.longterm_memory_mask
95 | other_metrics["num_ltms_per_batch"] = (
96 | ltm_mask.sum(dim=1).float().mean(dim=0)
97 | if ltm_mask.numel() > 0
98 | else torch.tensor(0.0, device=loss.device)
99 | )
100 | other_metrics["loss"] = loss
101 |
102 | other_metrics = {prefix + k: v for k, v in other_metrics.items()}
103 | return other_metrics, logits.detach(), labels.detach()
104 |
105 | def _segment_step(
106 | self,
107 | batch: dict[str, torch.Tensor],
108 | batch_idx: int,
109 | aggregate: Literal["mean", "last"],
110 | prefix="",
111 | ) -> dict[str, float]:
112 | batch_size, length = batch["input_ids"].shape
113 | num_valid_segments = batch["attention_mask"][:, :: self.segment_size].sum(dim=1)
114 | all_metrics = []
115 | all_probs = []
116 | indices = list(range(0, length, self.segment_size))
117 | prev_indices = [None] + indices[:-1]
118 | post_indices = indices[1:] + [None]
119 | final_loss = 0.0
120 | for pre_i, i, post_i in zip(prev_indices, indices, post_indices):
121 | segment_batch = {k: v[:, i : i + self.segment_size] if k != "labels" else v for k, v in batch.items()}
122 | pre_batch = (
123 | {k: v[:, pre_i : pre_i + self.segment_size] if k != "labels" else v for k, v in batch.items()}
124 | if pre_i is not None
125 | else None
126 | )
127 | post_batch = (
128 | {k: v[:, post_i : post_i + self.segment_size] if k != "labels" else v for k, v in batch.items()}
129 | if post_i is not None
130 | else None
131 | )
132 |
133 | current_valid = segment_batch["attention_mask"].bool().any(dim=1)
134 |             is_last = current_valid.clone()  # clone so the in-place &= below does not corrupt current_valid
135 | if pre_batch is not None:
136 | pre_valid = pre_batch["attention_mask"].bool().any(dim=1)
137 | is_last &= pre_valid
138 | if post_batch is not None:
139 | post_valid = post_batch["attention_mask"].bool().any(dim=1)
140 | is_last &= ~post_valid
141 |
142 | segment_metrics, logits, labels = self._single_step(segment_batch, batch_idx, prefix)
143 | if aggregate == "last":
144 | loss = segment_metrics[f"{prefix}loss"] / batch_size
145 | loss = loss[is_last].sum()
146 | final_loss += loss.item()
147 |
148 | if logits[is_last].numel():
149 | self.metrics[prefix].update(logits[is_last], labels[is_last])
150 | segment_metrics[f"{prefix}loss"] = loss
151 | elif aggregate == "mean":
152 | loss = segment_metrics[f"{prefix}loss"].mean() / len(indices)
153 | final_loss += loss.item()
154 |
155 | probs = logits.softmax(dim=-1)
156 | probs[~current_valid] = 0.0
157 | all_probs.append(probs)
158 |
159 | segment_metrics[f"{prefix}loss"] = loss
160 | else:
161 | raise ValueError(f"Unknown aggregate method: {aggregate}")
162 |
163 | if prefix == "train/":
164 | self.manual_backward(loss)
165 |
166 | all_metrics.append(segment_metrics)
167 | if aggregate == "mean":
168 | all_metrics = {
169 | k: torch.stack([m[k] for m in all_metrics], dim=0).mean(dim=0) for k in all_metrics[0].keys()
170 | }
171 | mean_logits = torch.stack(all_probs, dim=-1).mean(dim=-1)
172 | self.metrics[prefix].update(mean_logits, labels)
173 | segment_metrics = all_metrics
174 |
175 | segment_metrics.update(self.metrics[prefix].compute())
176 | return segment_metrics
177 |
178 | def training_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> dict[str, float]:
179 | """Train step function"""
180 | opt = self.optimizers()
181 | sch = self.lr_schedulers()
182 | opt.zero_grad()
183 |
184 | if self.segment_size:
185 | metrics = self._segment_step(batch=batch, batch_idx=batch_idx, aggregate=self.aggregate, prefix="train/")
186 | else:
187 | metrics, logits, labels = self._single_step(batch=batch, batch_idx=batch_idx, prefix="train/")
188 | metrics = {k: v.mean() for k, v in metrics.items()}
189 | self.manual_backward(metrics["train/loss"])
190 | metrics.update(self.metrics["train/"](logits, labels))
191 |
192 | opt.step()
193 | if sch is not None:
194 | sch.step()
195 |
196 | self.log_dict(metrics, prog_bar=True, logger=True, on_step=True, sync_dist=True)
197 | return metrics
198 |
199 | def validation_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> dict[str, float]:
200 | """Validation step function"""
201 | if self.segment_size:
202 | metrics = self._segment_step(batch=batch, batch_idx=batch_idx, aggregate=self.eval_aggregate, prefix="val/")
203 | else:
204 | metrics, logits, labels = self._single_step(batch=batch, batch_idx=batch_idx, prefix="val/")
205 | metrics = {k: v.mean() for k, v in metrics.items()}
206 | metrics.update(self.metrics["val/"](logits, labels))
207 | self.log_dict(metrics, prog_bar=True, logger=True, on_step=True, sync_dist=True)
208 | return metrics
209 |
210 | def test_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> dict[str, float]:
211 | """Test step function"""
212 | if self.segment_size:
213 | metrics = self._segment_step(
214 | batch=batch, batch_idx=batch_idx, aggregate=self.eval_aggregate, prefix="test/"
215 | )
216 | else:
217 | metrics, logits, labels = self._single_step(batch=batch, batch_idx=batch_idx, prefix="test/")
218 | metrics = {k: v.mean() for k, v in metrics.items()}
219 | metrics.update(self.metrics["test/"](logits, labels))
220 | self.log_dict(metrics, prog_bar=True, logger=True, on_step=True, sync_dist=True)
221 | return metrics
222 |
223 | def configure_optimizers(self) -> Dict:
224 | optimizer = torch.optim.Adam(params=self.model.parameters(), lr=self.learning_rate)
225 | optimizers = {"optimizer": optimizer}
226 |
227 | if self.warmup_rate is not None:
228 | scheduler = get_linear_schedule_with_warmup(
229 | optimizer,
230 | num_warmup_steps=int(self.total_steps * self.warmup_rate),
231 | num_training_steps=self.total_steps,
232 | )
233 | optimizers["lr_scheduler"] = {"scheduler": scheduler, "interval": "step", "name": "Learning Rate"}
234 |
235 | return optimizers
236 |
237 | def on_save_checkpoint(self, checkpoint: dict[str, Any]):
238 | checkpoint["model_config"] = self.model.config.to_dict()
239 | checkpoint["model_type"] = self.model.config.model_type
240 |
241 | def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
242 | config_dict = checkpoint["model_config"]
243 | config_cls = AutoConfig.for_model(checkpoint["model_type"])
244 | config = config_cls.from_dict(config_dict)
245 | self.model = AutoModelForSequenceClassification.from_config(config)
246 | return super().on_load_checkpoint(checkpoint)
247 |
248 | def on_train_batch_start(self, batch: Any, batch_idx: int) -> None:
249 | self.metrics["train/"].reset()
250 | if self.model.config.model_type == "memoria_bert":
251 | self.model.bert.encoder.memoria.reset_memory()
252 | if self.model.config.model_type == "memoria_roberta":
253 | self.model.roberta.encoder.memoria.reset_memory()
254 |
255 | def on_validation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
256 | if self.model.config.model_type == "memoria_bert":
257 | self.model.bert.encoder.memoria.reset_memory()
258 | if self.model.config.model_type == "memoria_roberta":
259 | self.model.roberta.encoder.memoria.reset_memory()
260 |
261 | def on_test_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
262 | if self.model.config.model_type == "memoria_bert":
263 | self.model.bert.encoder.memoria.reset_memory()
264 | if self.model.config.model_type == "memoria_roberta":
265 | self.model.roberta.encoder.memoria.reset_memory()
266 |
267 | def _epoch_end(self, outputs, prefix: str = "") -> None:
268 | results = self.metrics[prefix].compute()
269 | results = {k + "_final": v for k, v in results.items()}
270 | self.metrics[prefix].reset()
271 | self.log_dict(results, logger=True, sync_dist=True)
272 |
273 | def validation_epoch_end(self, outputs) -> None:
274 | return self._epoch_end(outputs, prefix="val/")
275 |
276 | def test_epoch_end(self, outputs) -> None:
277 | return self._epoch_end(outputs, prefix="test/")
278 |
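The subtlest piece of `_segment_step` above is picking, per example, the segment that holds the final non-padding token (used when aggregating with `"last"`). The standalone sketch below replays that index/mask arithmetic on a toy right-padded attention mask; it is illustrative only and not part of the training code.

```python
import torch

# Toy replay of the is_last computation from Classification._segment_step.
segment_size = 4
attention_mask = torch.tensor([  # example 0 has 6 real tokens, example 1 has 10
    [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
    [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
])

length = attention_mask.size(1)
indices = list(range(0, length, segment_size))
prev_indices = [None] + indices[:-1]
post_indices = indices[1:] + [None]

for pre_i, i, post_i in zip(prev_indices, indices, post_indices):
    current_valid = attention_mask[:, i : i + segment_size].bool().any(dim=1)
    is_last = current_valid.clone()  # clone so &= below does not mutate current_valid
    if pre_i is not None:
        is_last &= attention_mask[:, pre_i : pre_i + segment_size].bool().any(dim=1)
    if post_i is not None:
        is_last &= ~attention_mask[:, post_i : post_i + segment_size].bool().any(dim=1)
    print(i, is_last.tolist())
# -> 0 [False, False]; 4 [True, False]; 8 [False, True]
```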
--------------------------------------------------------------------------------
/experiment/longseq_formers/task/language_modeling.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, Optional
2 |
3 | import pytorch_lightning as pl
4 | import torch
5 | from transformers import AutoConfig, AutoModelForCausalLM, get_linear_schedule_with_warmup
6 |
7 |
8 | class LanguageModeling(pl.LightningModule):
9 | """LanguageModeling
10 |
11 | Attributes:
12 | model: model for language modeling
13 | total_steps: total training steps for lr scheduling
14 | learning_rate: Max LR
15 | warmup_rate: warmup step rate
16 | """
17 |
18 | def __init__(
19 | self, model: Optional[AutoModelForCausalLM], total_steps: int, learning_rate: float, warmup_rate: float
20 | ):
21 | super().__init__()
22 |
23 | self.model = model
24 | self.total_steps = total_steps
25 | self.learning_rate = learning_rate
26 | self.warmup_rate = warmup_rate
27 |
28 | self.save_hyperparameters(
29 | {
30 | "model": None,
31 | "total_steps": total_steps,
32 | "learning_rate": learning_rate,
33 | "warmup_rate": warmup_rate,
34 | "model_config": model.config.to_dict() if model is not None else None,
35 | }
36 | )
37 |
38 | def _step(self, batch: dict[str, torch.Tensor], batch_idx: int, prefix="") -> dict[str, float]:
39 | """Common step function
40 |
41 | Args:
42 | batch: training batch input/label
43 | Returns:
44 | metrics dictionary of this train step
45 | """
46 | is_end = batch.pop("is_end", None)
47 |
48 | if self.model.config.model_type in ["transfo-xl", "memoria-xl"]:
49 | del batch["attention_mask"]
50 | if hasattr(self, "_mems"):
51 | batch["mems"] = self._mems
52 | if hasattr(self, "_cmems"):
53 | batch["cmems"] = self._cmems
54 | outputs = self.model(**batch)
55 | lm_loss = outputs.loss
56 |
57 | if self.model.config.model_type in ["compressive_transformer"]:
58 | lm_loss = outputs.lm_loss
59 | if hasattr(outputs, "mems"):
60 | self._mems = outputs.mems
61 | if hasattr(outputs, "cmems"):
62 | self._cmems = outputs.cmems
63 |
64 | loss = outputs.loss
65 | ppl = lm_loss.detach().exp()
66 | metrics = {"loss": loss, "lm_loss": lm_loss, "ppl": ppl}
67 | if self.model.config.model_type in ["gpt2_with_memoria"]:
68 | ltm_mask = self.model.transformer.memoria.engrams.longterm_memory_mask
69 | metrics["num_ltms_per_batch"] = ltm_mask.sum(dim=1).float().mean(dim=0) if ltm_mask.numel() > 0 else 0.0
70 | metrics = {prefix + k: v for k, v in metrics.items()}
71 | return metrics
72 |
73 | def training_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> dict[str, float]:
74 | """Train step function"""
75 | metrics = self._step(batch=batch, batch_idx=batch_idx, prefix="")
76 | self.log_dict(metrics, prog_bar=True, logger=True, on_step=True, sync_dist=True)
77 | return metrics
78 |
79 | def validation_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> dict[str, float]:
80 | """Validation step function"""
81 | metrics = self._step(batch=batch, batch_idx=batch_idx, prefix="val/")
82 | self.log_dict(metrics, prog_bar=True, logger=True, on_step=True, sync_dist=True)
83 | return metrics
84 |
85 | def test_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> dict[str, float]:
86 | """Test step function"""
87 | metrics = self._step(batch=batch, batch_idx=batch_idx, prefix="test/")
88 | self.log_dict(metrics, prog_bar=True, logger=True, on_step=True, sync_dist=True)
89 | return metrics
90 |
91 | def configure_optimizers(self) -> Dict:
92 | optimizer = torch.optim.Adam(params=self.model.parameters(), lr=self.learning_rate)
93 | optimizers = {"optimizer": optimizer}
94 |
95 | scheduler = get_linear_schedule_with_warmup(
96 | optimizer,
97 | num_warmup_steps=int(self.total_steps * self.warmup_rate) if self.warmup_rate else 0,
98 | num_training_steps=self.total_steps,
99 | )
100 | optimizers["lr_scheduler"] = {"scheduler": scheduler, "interval": "step", "name": "Learning Rate"}
101 |
102 | return optimizers
103 |
104 | def reset_memories(self) -> None:
105 | if self.model.config.model_type in ["gpt2_with_memoria"]:
106 | self.model.transformer.memoria.reset_memory()
107 | self.model.transformer.prev_hidden = None
108 | if self.model.config.model_type in ["transfo-xl"] and hasattr(self, "_mems"):
109 | del self._mems
110 | if self.model.config.model_type == "compressive_transformer":
111 | if hasattr(self, "_mems"):
112 | del self._mems
113 | if hasattr(self, "_cmems"):
114 | del self._cmems
115 |
116 | def on_train_start(self) -> None:
117 | self.reset_memories()
118 |
119 | def on_train_end(self) -> None:
120 | self.reset_memories()
121 |
122 | def on_validation_start(self) -> None:
123 | self.reset_memories()
124 |
125 | def on_validation_end(self) -> None:
126 | self.reset_memories()
127 |
128 | def on_test_start(self) -> None:
129 | self.reset_memories()
130 |
131 | def on_test_end(self) -> None:
132 | self.reset_memories()
133 |
134 | def on_train_batch_start(self, batch: Any, batch_idx: int) -> None:
135 | if self.model.config.model_type in ["gpt2_with_memoria"]:
136 | if batch_idx % self.model.config.memoria_reset_period == 0:
137 | self.model.transformer.memoria.reset_memory()
138 | self.model.transformer.prev_hidden = None
139 |
140 | def on_save_checkpoint(self, checkpoint: dict[str, Any]):
141 | checkpoint["model_config"] = self.model.config.to_dict()
142 | checkpoint["model_type"] = self.model.config.model_type
143 |
144 | def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
145 | config_dict = checkpoint["model_config"]
146 | config_cls = AutoConfig.for_model(checkpoint["model_type"])
147 | config = config_cls.from_dict(config_dict)
148 | self.model = AutoModelForCausalLM.from_config(config)
149 | return super().on_load_checkpoint(checkpoint)
150 |
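For context, the sketch below shows one plausible way to wire the `LanguageModeling` task above into a Lightning `Trainer`. The model name, step count, and trainer settings are placeholders rather than the experiment's actual configuration.

```python
# Minimal sketch: attaching the LanguageModeling task to a Trainer.
# "gpt2" and the hyperparameters below are placeholders.
import pytorch_lightning as pl
from transformers import AutoModelForCausalLM

from longseq_formers.task import LanguageModeling

model = AutoModelForCausalLM.from_pretrained("gpt2")
task = LanguageModeling(model=model, total_steps=10_000, learning_rate=2e-4, warmup_rate=0.06)

trainer = pl.Trainer(max_epochs=1, accelerator="auto", devices=1)
# trainer.fit(task, datamodule=...)  # dataloader/datamodule construction omitted here
```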
--------------------------------------------------------------------------------
/experiment/longseq_formers/task/synthetic.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, Optional
2 |
3 | import pytorch_lightning as pl
4 | import torch
5 | import torch.nn.functional as F
6 | from torchmetrics.classification import Accuracy
7 | from transformers import AutoConfig, AutoModelForCausalLM, get_linear_schedule_with_warmup
8 |
9 |
10 | class Synthetic(pl.LightningModule):
11 | """Synthetic Task
12 |
13 | Attributes:
14 |         model: model for the synthetic task
15 |         max_grad_norm: maximum gradient norm for clipping
16 | total_steps: total training steps for lr scheduling
17 | learning_rate: Max LR
18 | warmup_rate: warmup step rate
19 | segment_size: segment size
20 | vocab_size: vocab size
21 | """
22 |
23 | def __init__(
24 | self,
25 | model: AutoModelForCausalLM,
26 | total_steps: int,
27 | learning_rate: float,
28 | warmup_rate: float,
29 | segment_size: int,
30 | vocab_size: int,
31 | max_grad_norm: Optional[float] = None,
32 | ):
33 | super().__init__()
34 |
35 | self.model = model
36 | self.total_steps = total_steps
37 | self.learning_rate = learning_rate
38 | self.warmup_rate = warmup_rate
39 | self.segment_size = segment_size
40 | self.vocab_size = vocab_size
41 | self.max_grad_norm = max_grad_norm
42 | self.automatic_optimization = False
43 |
44 | self.train_acc = Accuracy(task="multiclass", top_k=1, num_classes=vocab_size, ignore_index=-100)
45 | self.valid_acc = Accuracy(task="multiclass", top_k=1, num_classes=vocab_size, ignore_index=-100)
46 | self.test_acc = Accuracy(task="multiclass", top_k=1, num_classes=vocab_size, ignore_index=-100)
47 | self.accs = {"train": self.train_acc, "val": self.valid_acc, "test": self.test_acc}
48 |
49 | self.save_hyperparameters(
50 | {
51 | "model": None,
52 | "model_config": model.config.to_dict() if model is not None else None,
53 | "total_steps": total_steps,
54 | "learning_rate": learning_rate,
55 | "warmup_rate": warmup_rate,
56 | "segment_size": segment_size,
57 | "vocab_size": vocab_size,
58 | "max_grad_norm": max_grad_norm,
59 | }
60 | )
61 |
62 | def _step(self, batch: dict[str, torch.Tensor], batch_idx: int, prefix: str) -> dict[str, float]:
63 |         """Common step function"""
64 | batch_size, length = batch["input_ids"].size()
65 | num_valid_labels = (batch["labels"] != -100).sum(dim=1)
66 | indices = range(0, length, self.segment_size)
67 | loss_mean = 0.0
68 | acc = self.accs[prefix]
69 | for i in indices:
70 | segment_batch = {k: v[:, i : i + self.segment_size] for k, v in batch.items()}
71 | labels = segment_batch.pop("labels")
72 | if hasattr(self, "_mems"):
73 | segment_batch["mems"] = self._mems
74 | if hasattr(self, "_cmems"):
75 | segment_batch["cmems"] = self._cmems
76 |
77 | use_grad = prefix == "train" and (labels != -100).any().item()
78 | with torch.set_grad_enabled(use_grad):
79 | outputs = self.model(**segment_batch)
80 | if self.model.config.model_type in ["transfo-xl", "memoria-xl"]:
81 | self._mems = outputs.mems
82 |
83 | loss = (
84 | F.cross_entropy(
85 | outputs.logits.view(-1, outputs.logits.size(-1)),
86 | labels.reshape(-1),
87 | ignore_index=-100,
88 | reduction="none",
89 | )
90 | .view(batch_size, -1)
91 | .sum(dim=1)
92 | / num_valid_labels
93 | ).mean()
94 |
95 | if use_grad:
96 | self.manual_backward(loss)
97 |
98 | loss_mean += loss.item()
99 | preds = outputs.logits.argmax(dim=-1)
100 | acc.update(preds=preds, target=labels)
101 |
102 | metrics = {"loss": loss_mean, "acc": acc.compute()}
103 | metrics = {f"{prefix}/{k}": v for k, v in metrics.items()}
104 | return metrics
105 |
106 | def training_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> dict[str, float]:
107 | """Train step function"""
108 | opt = self.optimizers()
109 | sch = self.lr_schedulers()
110 |
111 | metrics = self._step(batch=batch, batch_idx=batch_idx, prefix="train")
112 |
113 | if self.max_grad_norm is not None:
114 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
115 | opt.step()
116 | if sch is not None:
117 | sch.step()
118 | opt.zero_grad()
119 |
120 | self.train_acc.reset()
121 | self.log_dict(metrics, prog_bar=True, logger=True, on_step=True, sync_dist=True)
122 | return metrics
123 |
124 | def validation_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> dict[str, float]:
125 | """Validation step function"""
126 | metrics = self._step(batch=batch, batch_idx=batch_idx, prefix="val")
127 | self.log_dict(metrics, prog_bar=True, logger=True, on_step=True, on_epoch=False, sync_dist=True)
128 | return metrics
129 |
130 | def test_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> dict[str, float]:
131 |         """Test step function"""
132 | metrics = self._step(batch=batch, batch_idx=batch_idx, prefix="test")
133 | self.log_dict(metrics, prog_bar=True, logger=True, on_step=True, on_epoch=False, sync_dist=True)
134 | return metrics
135 |
136 | def validation_epoch_end(self, outputs):
137 | val_acc = self.valid_acc.compute()
138 | self.valid_acc.reset()
139 | self.log("val/acc-final", val_acc, logger=True, on_step=False, sync_dist=True)
140 |
141 | def test_epoch_end(self, outputs):
142 | test_acc = self.test_acc.compute()
143 | self.test_acc.reset()
144 | self.log("test/acc-final", test_acc, logger=True, on_step=False, sync_dist=True)
145 |
146 | def configure_optimizers(self) -> Dict:
147 | optimizer = torch.optim.Adam(params=self.model.parameters(), lr=self.learning_rate)
148 | optimizers = {"optimizer": optimizer}
149 |
150 | if self.warmup_rate is not None:
151 | scheduler = get_linear_schedule_with_warmup(
152 | optimizer,
153 | num_warmup_steps=int(self.total_steps * self.warmup_rate),
154 | num_training_steps=self.total_steps,
155 | )
156 | optimizers["lr_scheduler"] = {"scheduler": scheduler, "interval": "step", "name": "Learning Rate"}
157 |
158 | return optimizers
159 |
160 | def on_save_checkpoint(self, checkpoint: dict[str, Any]):
161 | checkpoint["model_config"] = self.model.config.to_dict()
162 | checkpoint["model_type"] = self.model.config.model_type
163 |
164 | def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
165 | config_dict = checkpoint["model_config"]
166 | config_cls = AutoConfig.for_model(checkpoint["model_type"])
167 | config = config_cls.from_dict(config_dict)
168 | self.model = AutoModelForCausalLM.from_config(config)
169 | return super().on_load_checkpoint(checkpoint)
170 |
171 | def reset_memories(self) -> None:
172 | if self.model.config.model_type in ["gpt2_with_memoria", "memoria-xl"]:
173 | self.model.transformer.memoria.reset_memory()
174 | self.model.transformer.prev_hidden = None
175 | if self.model.config.model_type in ["transfo-xl", "memoria-xl"] and hasattr(self, "_mems"):
176 | del self._mems
177 | if self.model.config.model_type == "compressive_transformer":
178 | if hasattr(self, "_mems"):
179 | del self._mems
180 | if hasattr(self, "_cmems"):
181 | del self._cmems
182 | if self.model.config.model_type == "infinity_gpt2":
183 | self.model.reset_memories()
184 |
185 | def on_train_batch_start(self, batch, batch_idx) -> None:
186 | self.reset_memories()
187 |
188 | def on_validation_batch_start(self, batch, batch_idx, dataloader_idx) -> None:
189 | self.reset_memories()
190 |
191 | def on_test_batch_start(self, batch, batch_idx, dataloader_idx) -> None:
192 | self.reset_memories()
193 |
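One detail worth spelling out in `Synthetic._step` above: each segment's summed cross-entropy is divided by the example's total number of valid labels, so summing the per-segment losses recovers the mean cross-entropy over all valid tokens of the example. The toy check below verifies that equivalence in isolation (the shapes, seed, and masking are arbitrary).

```python
import torch
import torch.nn.functional as F

# Toy check of the loss normalization in Synthetic._step: per-segment summed
# cross-entropy divided by the example's valid-label count adds up, over
# segments, to the mean cross-entropy over all valid tokens.
torch.manual_seed(0)
batch_size, length, vocab_size, segment_size = 2, 8, 5, 4
logits = torch.randn(batch_size, length, vocab_size)
labels = torch.randint(0, vocab_size, (batch_size, length))
labels[:, :3] = -100  # mask some positions with the ignore index used by the task

num_valid_labels = (labels != -100).sum(dim=1)

total = 0.0
for i in range(0, length, segment_size):
    seg_ce = F.cross_entropy(
        logits[:, i : i + segment_size].reshape(-1, vocab_size),
        labels[:, i : i + segment_size].reshape(-1),
        ignore_index=-100,
        reduction="none",
    ).view(batch_size, -1)
    total += (seg_ce.sum(dim=1) / num_valid_labels).mean()

full_ce = F.cross_entropy(
    logits.reshape(-1, vocab_size), labels.reshape(-1), ignore_index=-100, reduction="none"
).view(batch_size, -1)
assert torch.isclose(total, (full_ce.sum(dim=1) / num_valid_labels).mean())
```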
--------------------------------------------------------------------------------
/experiment/longseq_formers/utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import sys
3 |
4 | import pytorch_lightning as pl
5 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
6 | from torch.utils.data.distributed import DistributedSampler
7 |
8 | from .dataset.language_modeling import LanguageModelingDataset
9 |
10 |
11 | def get_logger(name: str) -> logging.Logger:
12 | """Return logger for logging
13 |
14 | Args:
15 | name: logger name
16 | """
17 | logger = logging.getLogger(name)
18 | logger.propagate = False
19 | logger.setLevel(logging.DEBUG)
20 | if not logger.handlers:
21 | handler = logging.StreamHandler(sys.stdout)
22 | handler.setFormatter(logging.Formatter("[%(asctime)s] %(message)s"))
23 | logger.addHandler(handler)
24 | return logger
25 |
26 |
27 | class BatchedDataModule(pl.LightningDataModule):
28 | def __init__(
29 | self,
30 | train_dataset: LanguageModelingDataset,
31 | valid_dataset: LanguageModelingDataset,
32 | shuffle: bool,
33 | distributed: bool = True,
34 | ) -> None:
35 | super().__init__()
36 |
37 | self.train_dataset = train_dataset
38 | self.valid_dataset = valid_dataset
39 | self.shuffle = shuffle
40 | self.distributed = distributed
41 |
42 | def train_dataloader(self):
43 | # Use batch size as 1 because already batched
44 | if self.distributed:
45 | sampler = DistributedSampler(self.train_dataset, shuffle=self.shuffle)
46 | elif self.shuffle:
47 | sampler = RandomSampler(self.train_dataset)
48 | else:
49 | sampler = SequentialSampler(self.train_dataset)
50 | return DataLoader(self.train_dataset, batch_size=1, sampler=sampler, collate_fn=self.train_dataset.collate_fn)
51 |
52 | def val_dataloader(self):
53 | if self.distributed:
54 | sampler = DistributedSampler(self.valid_dataset, shuffle=False)
55 | else:
56 | sampler = SequentialSampler(self.valid_dataset)
57 | return DataLoader(self.valid_dataset, batch_size=1, sampler=sampler, collate_fn=self.valid_dataset.collate_fn)
58 |
--------------------------------------------------------------------------------
/experiment/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.13.1
2 | pytorch-lightning==1.8.6
3 | transformers==4.25.1
4 | mogrifier # for compressive transformer
5 |
6 | datasets
7 | scikit-learn
8 | bs4
9 | nltk
10 |
11 | wandb
12 |
13 | memoria-pytorch==1.0.0
14 |
--------------------------------------------------------------------------------
/experiment/train_classification.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import tempfile
4 | from typing import Dict
5 |
6 | import pytorch_lightning as pl
7 | import torch
8 | import wandb
9 | from longseq_formers.data import CLASSIFICATION_DATASETS, load_hyperpartisan_data
10 | from longseq_formers.dataset import ClassificationDataset
11 | from longseq_formers.task import Classification
12 | from longseq_formers.utils import get_logger
13 | from pytorch_lightning.callbacks import LearningRateMonitor
14 | from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
15 | from torch.utils.data import DataLoader
16 | from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer
17 |
18 | # fmt: off
19 | parser = argparse.ArgumentParser(prog="train_classification", description="Train & Test Long Sequence Classification")
20 |
21 | g = parser.add_argument_group("Train Parameter")
22 | g.add_argument("--model", type=str, required=True, help="huggingface model")
23 | g.add_argument("--model-type", type=str, help="specific model type")
24 | g.add_argument("--tokenizer", type=str, help="huggingface tokenizer")
25 | g.add_argument("--dataset", type=str, default="hyperpartisan", choices=CLASSIFICATION_DATASETS, help="dataset name")
26 | g.add_argument("--batch-size", type=int, default=8, help="global training batch size")
27 | g.add_argument("--valid-batch-size", type=int, default=32, help="validation batch size")
28 | g.add_argument("--accumulate-grad-batches", type=int, default=1, help="the number of gradient accumulation steps")
29 | g.add_argument("--max-length", type=int, default=512, help="max sequence length")
30 | g.add_argument("--memory-length", type=int, default=512, help="memory length for a single BERT inference on infinity former")
31 | g.add_argument("--epochs", type=int, default=20, help="the number of training epochs")
32 | g.add_argument("--learning-rate", type=float, default=3e-5, help="learning rate")
33 | g.add_argument("--warmup-rate", type=float, help="warmup step rate")
34 | g.add_argument("--seed", type=int, default=42, help="random seed")
35 | g.add_argument("--test-ckpt", type=str, default="last", choices=["best", "last"], help="checkpoint type for testing")
36 | g.add_argument("--not-truncate", action="store_false", dest="truncation", help="do not truncate sequences")
37 | g.add_argument("--segment-size", type=int, help="segment size for infinity former")
38 |
39 | g = parser.add_argument_group("Personal Options")
40 | g.add_argument("--output-dir", type=str, help="output directory path to save artifacts")
41 | g.add_argument("--gpus", type=int, help="the number of gpus, use all devices by default")
42 | g.add_argument("--logging-interval", type=int, default=10, help="logging interval")
43 |
44 | g = parser.add_argument_group("Wandb Options")
45 | g.add_argument("--wandb-run-name", type=str, help="WandB run name")
46 | g.add_argument("--wandb-entity", type=str, help="WandB entity name")
47 | g.add_argument("--wandb-project", type=str, help="WandB project name")
48 | # fmt: on
49 |
50 |
51 | def main(args: argparse.Namespace) -> dict[str, float]:
52 | logger = get_logger("train_classification")
53 |
54 | if args.output_dir:
55 | os.makedirs(args.output_dir)
56 | logger.info(f'[+] Save output to "{args.output_dir}"')
57 |
58 |     logger.info(" ====== Arguments ======")
59 | for k, v in vars(args).items():
60 | logger.info(f"{k:25}: {v}")
61 |
62 | logger.info(f"[+] Set Random Seed to {args.seed}")
63 | pl.seed_everything(args.seed, workers=True)
64 |
65 | logger.info(f"[+] GPU: {args.gpus}")
66 |
67 | if args.tokenizer is None:
68 | if args.model:
69 | logger.info(f"[+] Use tokenizer same as model: {args.model}")
70 | args.tokenizer = args.model
71 | else:
72 |             raise ValueError("you should set `--tokenizer` when using `--model-config`!")
73 | logger.info(f'[+] Load Tokenizer: "{args.tokenizer}"')
74 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
75 |
76 | logger.info(f'[+] Use Dataset: "{args.dataset}"')
77 | if args.dataset == "hyperpartisan":
78 | datasets = load_hyperpartisan_data()
79 | num_classes = 2
80 |
81 | train_dataset = ClassificationDataset(
82 | datasets["train"], tokenizer=tokenizer, max_length=args.max_length, truncation=args.truncation
83 | )
84 | valid_dataset = ClassificationDataset(
85 | datasets["dev"], tokenizer=tokenizer, max_length=args.max_length, truncation=args.truncation
86 | )
87 | test_dataset = ClassificationDataset(
88 | datasets["test"], tokenizer=tokenizer, max_length=args.max_length, truncation=args.truncation
89 | )
90 |
91 | logger.info(f"[+] # of train examples: {len(train_dataset)}")
92 | logger.info(f"[+] # of valid examples: {len(valid_dataset)}")
93 | logger.info(f"[+] # of test examples: {len(test_dataset)}")
94 |
95 | logger.info(f'[+] Load Model: "{args.model}"')
96 | if args.model_type:
97 | model_cls = type(AutoModelForSequenceClassification.from_config(AutoConfig.for_model(args.model_type)))
98 | else:
99 | model_cls = AutoModelForSequenceClassification
100 | model = model_cls.from_pretrained(args.model, num_labels=num_classes)
101 |
102 | if args.gpus is None:
103 | args.gpus = torch.cuda.device_count()
104 | num_parallels = max(args.gpus, 1)
105 | distributed = num_parallels > 1
106 | batch_size_per_device = max(args.batch_size // num_parallels, 1)
107 |     global_batch_size = batch_size_per_device * num_parallels
108 | valid_batch_size_per_device = max(args.valid_batch_size // num_parallels, 1)
109 | global_valid_batch_size = valid_batch_size_per_device * num_parallels
110 | if args.batch_size != global_batch_size:
111 |         logger.warning(f"[-] Batch size {args.batch_size} isn't divisible by {args.gpus}!")
112 | logger.warning(f"[-] Use batch size as {batch_size_per_device} per device, {global_batch_size} global")
113 | if args.valid_batch_size != global_valid_batch_size:
114 |         logger.warning(f"[-] Valid Batch size {args.valid_batch_size} isn't divisible by {args.gpus}!")
115 | logger.warning(
116 | f"[-] Use batch size as {valid_batch_size_per_device} per device, {global_valid_batch_size} global"
117 | )
118 |
119 | collate_fn = ClassificationDataset.pad_collate_fn if not args.truncation else None
120 | train_dataloader = DataLoader(
121 | train_dataset,
122 | shuffle=True,
123 | batch_size=batch_size_per_device,
124 | num_workers=os.cpu_count() // 2,
125 | pin_memory=True,
126 | collate_fn=collate_fn,
127 | )
128 | valid_dataloader = DataLoader(
129 | valid_dataset,
130 | batch_size=args.valid_batch_size // num_parallels,
131 | collate_fn=collate_fn,
132 | )
133 | test_dataloader = DataLoader(
134 | test_dataset,
135 | batch_size=args.valid_batch_size // num_parallels,
136 | collate_fn=collate_fn,
137 | )
138 |
139 | total_steps = len(train_dataloader) * args.epochs
140 |
141 | classification = Classification(
142 | model=model,
143 | total_steps=total_steps,
144 | learning_rate=args.learning_rate,
145 | warmup_rate=args.warmup_rate,
146 | segment_size=args.segment_size,
147 | num_classes=num_classes,
148 | )
149 |
150 | if args.output_dir:
151 | train_loggers = [TensorBoardLogger(args.output_dir, "", "logs")]
152 | model_dir = os.path.join(args.output_dir, "checkpoint")
153 | else:
154 | train_loggers = []
155 | tmp_dir = tempfile.TemporaryDirectory()
156 | model_dir = tmp_dir.name
157 |
158 | logger.info(f"[+] Start Training")
159 | if args.wandb_project and (args.wandb_run_name or args.output_dir):
160 | wandb_logger = WandbLogger(
161 | name=args.wandb_run_name or os.path.basename(args.output_dir),
162 | project=args.wandb_project,
163 | entity=args.wandb_entity,
164 | save_dir=args.output_dir if args.output_dir else None,
165 | )
166 | wandb_logger.log_hyperparams({"train_arguments": vars(args)})
167 | train_loggers.append(wandb_logger)
168 |
169 | model_checkpoint_callback = pl.callbacks.ModelCheckpoint(
170 | model_dir, mode="max", monitor="val/f1_final", save_last=True, auto_insert_metric_name=True
171 | )
172 | callbacks = [model_checkpoint_callback]
173 |
174 | if train_loggers:
175 | callbacks.append(LearningRateMonitor(logging_interval="step"))
176 | trainer = pl.Trainer(
177 | logger=train_loggers,
178 | max_epochs=args.epochs,
179 | log_every_n_steps=args.logging_interval,
180 | accumulate_grad_batches=args.accumulate_grad_batches,
181 | callbacks=callbacks,
182 | strategy="ddp_fork" if distributed else None,
183 | accelerator="gpu" if args.gpus else None,
184 | devices=num_parallels,
185 | )
186 | trainer.fit(classification, train_dataloader, valid_dataloader)
187 |
188 |     # Use separately initialized trainer (https://github.com/Lightning-AI/lightning/issues/8375)
189 | tester = pl.Trainer(
190 | logger=train_loggers,
191 | callbacks=callbacks,
192 | accelerator="gpu" if args.gpus else None,
193 | devices=1,
194 | )
195 | result = tester.test(classification, test_dataloader, ckpt_path=args.test_ckpt)[0]
196 |
197 | wandb.finish()
198 |
199 | if not args.output_dir:
200 | tmp_dir.cleanup()
201 |
202 | return result
203 |
204 |
205 | if __name__ == "__main__":
206 | main(parser.parse_args())
207 | exit(0)
208 |
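Although the script above is meant to be launched from the command line, `main` and `parser` are importable, so a run can also be kicked off programmatically as sketched below. The argument values are placeholders, and a GPU plus the hyperpartisan data download are assumed to be available.

```python
# Hypothetical programmatic invocation of train_classification (placeholder values).
from train_classification import main, parser

args = parser.parse_args([
    "--model", "bert-base-uncased",
    "--dataset", "hyperpartisan",
    "--batch-size", "8",
    "--epochs", "20",
    "--output-dir", "outputs/hyperpartisan-bert",
])
result = main(args)  # returns the test metrics dict from tester.test(...)
print(result)
```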
--------------------------------------------------------------------------------
/experiment/train_language_modeling.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import tempfile
4 | from typing import Dict
5 |
6 | import pytorch_lightning as pl
7 | import torch
8 | import wandb
9 | from longseq_formers.data import (
10 | LANGUAGE_MODELING_DATASETS,
11 | enwik8_tokenize,
12 | load_enwik8_data,
13 | load_pg19_data,
14 | load_wikitext103_data,
15 | )
16 | from longseq_formers.dataset import LanguageModelingDataset, text_to_tokens
17 | from longseq_formers.task import LanguageModeling
18 | from longseq_formers.utils import BatchedDataModule, get_logger
19 | from pytorch_lightning.callbacks import LearningRateMonitor
20 | from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
21 | from torch.utils.data import DataLoader
22 | from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
23 |
24 | # fmt: off
25 | parser = argparse.ArgumentParser(prog="train", description="Train & Test Language Modeling")
26 |
27 | g = parser.add_argument_group("Train Parameter")
28 | g.add_argument("--model-config", type=str, help="huggingface model config")
29 | g.add_argument("--model", type=str, help="huggingface model")
30 | g.add_argument("--model-type", type=str, help="specific model type")
31 | g.add_argument("--tokenizer", type=str, help="huggingface tokenizer")
32 | g.add_argument("--dataset", type=str, default="wikitext103", choices=LANGUAGE_MODELING_DATASETS, help="dataset name")
33 | g.add_argument("--batch-size", type=int, default=8, help="global training batch size")
34 | g.add_argument("--valid-batch-size", type=int, default=1, help="validation batch size")
35 | g.add_argument("--accumulate-grad-batches", type=int, default=1, help="the number of gradient accumulation steps")
36 | g.add_argument("--max-length", type=int, default=150, help="max sequence length")
37 | g.add_argument("--epochs", type=int, default=6, help="the number of training epochs")
38 | g.add_argument("--learning-rate", type=float, default=2e-4, help="learning rate")
39 | g.add_argument("--warmup-rate", type=float, default=0.06, help="warmup step rate")
40 | g.add_argument("--max-grad-norm", type=float, default=1.0, help="maximum gradient norm")
41 | g.add_argument("--seed", type=int, default=42, help="random seed")
42 | g.add_argument("--shuffle", action="store_true", help="shuffle data order")
43 | g.add_argument("--test-ckpt", type=str, default="last", choices=["best", "last"], help="checkpoint type for testing")
44 |
45 | g = parser.add_argument_group("Personal Options")
46 | g.add_argument("--output-dir", type=str, help="output directory path to save artifacts")
47 | g.add_argument("--gpus", type=int, help="the number of gpus, use all devices by default")
48 | g.add_argument("--logging-interval", type=int, default=100, help="logging interval")
49 | g.add_argument("--valid-interval", type=float, default=1.0, help="validation interval rate")
50 |
51 | g = parser.add_argument_group("Wandb Options")
52 | g.add_argument("--wandb-run-name", type=str, help="WandB run name")
53 | g.add_argument("--wandb-entity", type=str, help="WandB entity name")
54 | g.add_argument("--wandb-project", type=str, help="WandB project name")
55 | # fmt: on
56 |
57 |
58 | def main(args: argparse.Namespace) -> dict[str, float]:
59 | logger = get_logger("train_language_modeling")
60 |
61 | if args.output_dir:
62 | os.makedirs(args.output_dir)
63 | logger.info(f'[+] Save output to "{args.output_dir}"')
64 |
65 |     logger.info(" ====== Arguments ======")
66 | for k, v in vars(args).items():
67 | logger.info(f"{k:25}: {v}")
68 |
69 | logger.info(f"[+] Set Random Seed to {args.seed}")
70 | pl.seed_everything(args.seed, workers=True)
71 |
72 | logger.info(f"[+] GPU: {args.gpus}")
73 |
74 | if args.tokenizer is None and args.dataset != "enwik8":
75 | if args.model:
76 | logger.info(f"[+] Use tokenizer same as model: {args.model}")
77 | args.tokenizer = args.model
78 | else:
79 |             raise ValueError("you should set `--tokenizer` when using `--model-config`!")
80 | logger.info(f'[+] Load Tokenizer: "{args.tokenizer}"')
81 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer) if args.tokenizer else enwik8_tokenize
82 |
83 | logger.info(f'[+] Use Dataset: "{args.dataset}"')
84 | if args.dataset == "wikitext103":
85 | data = load_wikitext103_data()
86 | elif args.dataset == "pg19":
87 | data = load_pg19_data()
88 | elif args.dataset == "enwik8":
89 | data = load_enwik8_data()
90 | else:
91 | raise ValueError(f"dataset `{args.dataset}` is not valid!")
92 |
93 | if args.gpus is None:
94 | args.gpus = torch.cuda.device_count()
95 | num_parallels = max(args.gpus, 1)
96 | distributed = num_parallels > 1
97 | batch_size_per_device = max(args.batch_size // num_parallels, 1)
98 | global_batch_size = batch_size_per_device * num_parallels
99 | valid_batch_size_per_device = max(args.valid_batch_size // num_parallels, 1)
100 | global_valid_batch_size = valid_batch_size_per_device * num_parallels
101 | if args.batch_size != global_batch_size:
102 |         logger.warning(f"[-] Batch size {args.batch_size} isn't divisible by {args.gpus}!")
103 | logger.warning(f"[-] Use batch size as {batch_size_per_device} per device, {global_batch_size} global")
104 | if args.valid_batch_size != global_valid_batch_size:
105 |         logger.warning(f"[-] Valid Batch size {args.valid_batch_size} isn't divisible by {args.gpus}!")
106 | logger.warning(
107 | f"[-] Use batch size as {valid_batch_size_per_device} per device, {global_valid_batch_size} global"
108 | )
109 |
110 | train_tokens = text_to_tokens(data["train"], tokenizer, global_batch_size, args.max_length, batch_size_per_device)
111 | dev_tokens = text_to_tokens(
112 | data["dev"], tokenizer, global_valid_batch_size, args.max_length, valid_batch_size_per_device
113 | )
114 | test_tokens = text_to_tokens(data["test"], tokenizer, args.valid_batch_size, args.max_length)
115 |
116 | train_dataset = LanguageModelingDataset(train_tokens)
117 | valid_dataset = LanguageModelingDataset(dev_tokens)
118 | test_dataset = LanguageModelingDataset(test_tokens)
119 |
120 | logger.info(f"[+] # of batched train examples: {len(train_dataset)}")
121 | logger.info(f"[+] # of batched valid examples: {len(valid_dataset)}")
122 | logger.info(f"[+] # of batched test examples: {len(test_dataset)}")
123 |
124 | if args.model:
125 | logger.info(f'[+] Load Model: "{args.model}"')
126 | if args.model_type:
127 | model_cls = type(AutoModelForCausalLM.from_config(AutoConfig.for_model(args.model_type)))
128 | logger.info(f"[+] Use model type: {args.model_type}")
129 | else:
130 | model_cls = AutoModelForCausalLM
131 | model = model_cls.from_pretrained(args.model)
132 | elif args.model_config:
133 | logger.info(f'[+] Initialize Model with Config: "{args.model_config}"')
134 | config = AutoConfig.from_pretrained(args.model_config, trust_remote_code=True)
135 | model = AutoModelForCausalLM.from_config(config)
136 | else:
137 | raise ValueError("you should set `--model` or `--model-config` argument!")
138 |
139 | total_steps = len(train_tokens["input_ids"]) // num_parallels * args.epochs
140 |
141 | language_modeling = LanguageModeling(
142 | model=model,
143 | total_steps=total_steps,
144 | learning_rate=args.learning_rate,
145 | warmup_rate=args.warmup_rate,
146 | )
147 |
148 | if args.output_dir:
149 | train_loggers = [TensorBoardLogger(args.output_dir, "", "logs")]
150 | model_dir = os.path.join(args.output_dir, "checkpoint")
151 | else:
152 | train_loggers = []
153 | tmp_dir = tempfile.TemporaryDirectory()
154 | model_dir = tmp_dir.name
155 |
156 | logger.info(f"[+] Start Training")
157 | if args.wandb_project and (args.wandb_run_name or args.output_dir):
158 | wandb_logger = WandbLogger(
159 | name=args.wandb_run_name or os.path.basename(args.output_dir),
160 | project=args.wandb_project,
161 | entity=args.wandb_entity,
162 | save_dir=args.output_dir if args.output_dir else None,
163 | )
164 | wandb_logger.log_hyperparams({"train_arguments": vars(args)})
165 | train_loggers.append(wandb_logger)
166 |
167 | model_checkpoint_callback = pl.callbacks.ModelCheckpoint(
168 | model_dir, mode="min", monitor="val/ppl", save_last=True, auto_insert_metric_name=True
169 | )
170 | callbacks = [model_checkpoint_callback]
171 |
172 | if train_loggers:
173 | callbacks.append(LearningRateMonitor(logging_interval="step"))
174 | trainer = pl.Trainer(
175 | logger=train_loggers,
176 | max_epochs=args.epochs,
177 | log_every_n_steps=args.logging_interval,
178 | val_check_interval=args.valid_interval,
179 | accumulate_grad_batches=args.accumulate_grad_batches,
180 | gradient_clip_val=args.max_grad_norm,
181 | callbacks=callbacks,
182 | strategy="ddp_fork" if distributed else None,
183 | accelerator="gpu" if args.gpus else None,
184 | devices=num_parallels,
185 | replace_sampler_ddp=False,
186 | )
187 | trainer.fit(
188 | language_modeling,
189 | datamodule=BatchedDataModule(train_dataset, valid_dataset, args.shuffle, distributed),
190 | )
191 |
192 | # Use a separately initialized trainer (https://github.com/Lightning-AI/lightning/issues/8375)
193 | # Use a batch size of 1 because the dataset is already batched
194 | test_dataloader = DataLoader(test_dataset, batch_size=1, collate_fn=test_dataset.collate_fn)
195 | tester = pl.Trainer(
196 | logger=train_loggers,
197 | callbacks=callbacks,
198 | accelerator="gpu" if args.gpus else None,
199 | devices=1,
200 | )
201 | result = tester.test(language_modeling, test_dataloader, ckpt_path=args.test_ckpt)[0]
202 |
203 | wandb.finish()
204 |
205 | if not args.output_dir:
206 | tmp_dir.cleanup()
207 |
208 | return result
209 |
210 |
211 | if __name__ == "__main__":
212 | main(parser.parse_args())
213 | exit(0)
214 |
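For concreteness, the batch-size adjustment near the top of `main` above derives per-device and global sizes as in the following sketch; the numbers are purely illustrative and not taken from the repository:

# Illustrative arithmetic only (not part of the script): how the per-device and
# global batch sizes are derived when the requested size is not divisible.
batch_size, gpus = 100, 8
num_parallels = max(gpus, 1)
batch_size_per_device = max(batch_size // num_parallels, 1)   # 12
global_batch_size = batch_size_per_device * num_parallels     # 96
assert (batch_size_per_device, global_batch_size) == (12, 96)
# Since 100 != 96, the script logs a warning and trains with the global size 96.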
--------------------------------------------------------------------------------
/experiment/train_synthetic.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import tempfile
4 | from typing import Dict
5 |
6 | import pytorch_lightning as pl
7 | import torch
8 | import wandb
9 | from longseq_formers.dataset.synthetic import SyntheticDataset, parse_syntetic_data
10 | from longseq_formers.task import Synthetic
11 | from longseq_formers.utils import get_logger
12 | from pytorch_lightning.callbacks import LearningRateMonitor
13 | from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
14 | from torch.utils.data import DataLoader
15 | from transformers import AutoConfig, AutoModelForCausalLM
16 |
17 | # fmt: off
18 | parser = argparse.ArgumentParser(prog="train_synthetic", description="Train & Test Synthetic Task")
19 |
20 | g = parser.add_argument_group("Train Parameter")
21 | g.add_argument("--model-config", type=str, required=True, help="huggingface model config")
22 | g.add_argument("--dataset", type=str, required=True, help="dataset name")
23 | g.add_argument("--batch-size", type=int, default=32, help="global training batch size")
24 | g.add_argument("--valid-batch-size", type=int, default=1, help="validation batch size")
25 | g.add_argument("--accumulate-grad-batches", type=int, default=1, help="the number of gradient accumulation steps")
26 | g.add_argument("--epochs", type=int, default=1, help="the number of training epochs")
27 | g.add_argument("--learning-rate", type=float, default=2e-4, help="learning rate")
28 | g.add_argument("--warmup-rate", type=float, default=0.06, help="warmup step rate")
29 | g.add_argument("--max-grad-norm", type=float, default=1.0, help="maximum gradient norm")
30 | g.add_argument("--seed", type=int, default=42, help="random seed")
31 | g.add_argument("--test-ckpt", type=str, default="last", choices=["best", "last"], help="checkpoint type for testing")
32 | g.add_argument("--segment-size", type=int, required=True, help="segment size for infinity former")
33 |
34 | g = parser.add_argument_group("Personal Options")
35 | g.add_argument("--output-dir", type=str, help="output directory path to save artifacts")
36 | g.add_argument("--gpus", type=int, help="the number of gpus, use all devices by default")
37 | g.add_argument("--logging-interval", type=int, default=100, help="logging interval")
38 | g.add_argument("--valid-interval", type=float, default=1.0, help="validation interval rate")
39 |
40 | g = parser.add_argument_group("Wandb Options")
41 | g.add_argument("--wandb-run-name", type=str, help="WandB run name")
42 | g.add_argument("--wandb-entity", type=str, help="WandB entity name")
43 | g.add_argument("--wandb-project", type=str, help="WandB project name")
44 | # fmt: on
45 |
46 |
47 | def main(args: argparse.Namespace) -> dict[str, float]:
48 | logger = get_logger("train_synthetic_task")
49 |
50 | if args.output_dir:
51 | os.makedirs(args.output_dir)
52 | logger.info(f'[+] Save output to "{args.output_dir}"')
53 |
54 | logger.info(" ====== Arguments ======")
55 | for k, v in vars(args).items():
56 | logger.info(f"{k:25}: {v}")
57 |
58 | logger.info(f"[+] Set Random Seed to {args.seed}")
59 | pl.seed_everything(args.seed, workers=True)
60 |
61 | logger.info(f"[+] GPU: {args.gpus}")
62 |
63 | logger.info(f'[+] Use Dataset: "{args.dataset}"')
64 | _, vocab_size, train_examples, dev_examples, test_examples = parse_syntetic_data(args.dataset)
65 |
66 | if args.gpus is None:
67 | args.gpus = torch.cuda.device_count()
68 | num_parallels = max(args.gpus, 1)
69 | distributed = num_parallels > 1
70 | batch_size_per_device = max(args.batch_size // num_parallels, 1)
71 | global_batch_size = batch_size_per_device * num_parallels
72 | valid_batch_size_per_device = max(args.valid_batch_size // num_parallels, 1)
73 | global_valid_batch_size = valid_batch_size_per_device * num_parallels
74 | if args.batch_size != global_batch_size:
75 | logger.warning(f"[-] Batch size {args.batch_size} isn't divisible by {args.gpus}!")
76 | logger.warning(f"[-] Use batch size as {batch_size_per_device} per device, {global_batch_size} global")
77 | if args.valid_batch_size != global_valid_batch_size:
78 | logger.warning(f"[-] Valid batch size {args.valid_batch_size} isn't divisible by {args.gpus}!")
79 | logger.warning(
80 | f"[-] Use batch size as {valid_batch_size_per_device} per device, {global_valid_batch_size} global"
81 | )
82 |
83 | train_dataset = SyntheticDataset(train_examples)
84 | valid_dataset = SyntheticDataset(dev_examples)
85 | test_dataset = SyntheticDataset(test_examples)
86 |
87 | logger.info(f"[+] # of batched train examples: {len(train_dataset)}")
88 | logger.info(f"[+] # of batched valid examples: {len(valid_dataset)}")
89 | logger.info(f"[+] # of batched test examples: {len(test_dataset)}")
90 |
91 | logger.info(f'[+] Initialize Model with Config: "{args.model_config}"')
92 | config = AutoConfig.from_pretrained(args.model_config, trust_remote_code=True, vocab_size=vocab_size)
93 | model = AutoModelForCausalLM.from_config(config)
94 |
95 | train_dataloader = DataLoader(
96 | train_dataset, shuffle=True, batch_size=batch_size_per_device, num_workers=os.cpu_count() // 2, pin_memory=True
97 | )
98 | valid_dataloader = DataLoader(valid_dataset, batch_size=valid_batch_size_per_device)
99 | test_dataloader = DataLoader(test_dataset, batch_size=valid_batch_size_per_device)
100 | total_steps = len(train_dataloader) * args.epochs
101 |
102 | synthetic_task = Synthetic(
103 | model=model,
104 | total_steps=total_steps,
105 | learning_rate=args.learning_rate,
106 | warmup_rate=args.warmup_rate,
107 | segment_size=args.segment_size,
108 | vocab_size=vocab_size,
109 | max_grad_norm=args.max_grad_norm,
110 | )
111 |
112 | if args.output_dir:
113 | train_loggers = [TensorBoardLogger(args.output_dir, "", "logs")]
114 | model_dir = os.path.join(args.output_dir, "checkpoint")
115 | else:
116 | train_loggers = []
117 | tmp_dir = tempfile.TemporaryDirectory()
118 | model_dir = tmp_dir.name
119 |
120 | logger.info(f"[+] Start Training")
121 | if args.wandb_project and (args.wandb_run_name or args.output_dir):
122 | wandb_logger = WandbLogger(
123 | name=args.wandb_run_name or os.path.basename(args.output_dir),
124 | project=args.wandb_project,
125 | entity=args.wandb_entity,
126 | save_dir=args.output_dir if args.output_dir else None,
127 | )
128 | wandb_logger.log_hyperparams({"train_arguments": vars(args)})
129 | train_loggers.append(wandb_logger)
130 |
131 | model_checkpoint_callback = pl.callbacks.ModelCheckpoint(
132 | model_dir, mode="max", monitor="val/acc-final", save_last=True, auto_insert_metric_name=True
133 | )
134 | callbacks = [model_checkpoint_callback]
135 |
136 | if train_loggers:
137 | callbacks.append(LearningRateMonitor(logging_interval="step"))
138 | trainer = pl.Trainer(
139 | logger=train_loggers,
140 | max_epochs=args.epochs,
141 | log_every_n_steps=args.logging_interval,
142 | val_check_interval=args.valid_interval,
143 | accumulate_grad_batches=args.accumulate_grad_batches,
144 | callbacks=callbacks,
145 | strategy="ddp_fork" if distributed else None,
146 | accelerator="gpu" if args.gpus else None,
147 | devices=num_parallels,
148 | replace_sampler_ddp=False,
149 | )
150 | trainer.fit(synthetic_task, train_dataloader, valid_dataloader)
151 |
152 | # Use a separately initialized trainer (https://github.com/Lightning-AI/lightning/issues/8375)
153 | # Test on a single device
154 | tester = pl.Trainer(
155 | logger=train_loggers,
156 | callbacks=callbacks,
157 | accelerator="gpu" if args.gpus else None,
158 | devices=1,
159 | )
160 | result = tester.test(synthetic_task, test_dataloader, ckpt_path=args.test_ckpt)[0]
161 |
162 | wandb.finish()
163 |
164 | if not args.output_dir:
165 | tmp_dir.cleanup()
166 |
167 | return result
168 |
169 |
170 | if __name__ == "__main__":
171 | main(parser.parse_args())
172 | exit(0)
173 |
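A minimal sketch of invoking this script programmatically (equivalent to running it from the `experiment/` directory with the same flags); the config path, dataset name, and segment size are placeholder choices, not values confirmed by the repository:

# Hypothetical invocation sketch: the flag names come from the argparse definitions above,
# but the concrete values (config path, dataset name, segment size) are assumptions.
from train_synthetic import main, parser

args = parser.parse_args([
    "--model-config", "configs/memoria-gpt2-sort.json",  # a config shipped under experiment/configs
    "--dataset", "sort",                                  # placeholder dataset name
    "--segment-size", "8",                                # required flag; value is illustrative
    "--epochs", "1",
])
result = main(args)  # returns the test metrics dict reported by pl.Trainer.test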
--------------------------------------------------------------------------------
/images/Memoria-Engrams.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cosmoquester/memoria/e4ba6e2e13410e01fba896dd7800e66520e9d716/images/Memoria-Engrams.gif
--------------------------------------------------------------------------------
/memoria/__init__.py:
--------------------------------------------------------------------------------
1 | from . import utils
2 | from .abstractor import Abstractor
3 | from .engram import Engrams, EngramType
4 | from .history_manager import HistoryManager
5 | from .memoria import Memoria
6 | from .sparse_tensor import SparseTensor
7 |
8 | __all__ = ["utils", "Abstractor", "Engrams", "EngramType", "HistoryManager", "Memoria", "SparseTensor"]
9 | __version__ = "1.0.0"
10 |
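The public surface above is small; a minimal usage sketch (constructor arguments and tensor shapes are illustrative, mirroring the tests further below) is:

import torch
from memoria import Memoria

memoria = Memoria(
    num_reminded_stm=10, ltm_search_depth=3, stm_capacity=100,
    initial_lifespan=100, num_final_ltms=100,
)
hidden = torch.randn(2, 16, 32)            # [Batch, SeqLen, HiddenDim], illustrative sizes
memoria.add_working_memory(hidden)         # store the current segment as working memory
reminded, indices = memoria.remind()       # retrieve related engrams
# Placeholder weights: in a model these would reflect how strongly each engram was attended.
memoria.adjust_lifespan_and_memories(indices, torch.ones_like(indices, dtype=torch.float))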
--------------------------------------------------------------------------------
/memoria/abstractor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class Abstractor(nn.Module):
6 | """Abstractor module that summarizes inputs into a fixed number of memories"""
7 |
8 | def __init__(self, num_memories: int, hidden_dim: int, feedforward_dim: int) -> None:
9 | """
10 | Args:
11 | num_memories (int): Number of memories to be created
12 | hidden_dim (int): Hidden dimension of the model
13 | feedforward_dim (int): Feedforward dimension of the model
14 | """
15 | super().__init__()
16 |
17 | w = torch.empty(1, num_memories, hidden_dim)
18 | nn.init.normal_(w, std=0.02)
19 | self.query_embeddings = nn.Parameter(w)
20 | self.key_transform = nn.Linear(hidden_dim, hidden_dim)
21 | self.value_transform = nn.Linear(hidden_dim, hidden_dim)
22 | self.feedforward = nn.Linear(hidden_dim, feedforward_dim)
23 | self.output = nn.Linear(feedforward_dim, hidden_dim)
24 |
25 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
26 | """
27 |
28 | Args:
29 | hidden_states (torch.Tensor): [Batch, N, HiddenDim]
30 | Returns:
31 | torch.Tensor: [Batch, NumMemories, HiddenDim]
32 | """
33 | query = self.query_embeddings
34 | key = self.key_transform(hidden_states)
35 | # [Batch, N, HiddenDim]
36 | value = self.value_transform(hidden_states)
37 | # [Batch, NumMemories, HiddenDim] x [Batch, N, HiddenDim] -> [Batch, NumMemories, N]
38 | attn = query @ key.transpose(-2, -1)
39 | attn = attn.softmax(dim=-1)
40 | attn = attn @ value
41 | attn = self.feedforward(attn)
42 | attn = self.output(attn)
43 | return attn
44 |
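A quick shape check of the module above (sizes are illustrative):

import torch
from memoria import Abstractor

abstractor = Abstractor(num_memories=8, hidden_dim=32, feedforward_dim=64)
hidden_states = torch.randn(2, 50, 32)        # [Batch, N, HiddenDim]
summary = abstractor(hidden_states)
assert summary.shape == (2, 8, 32)            # [Batch, NumMemories, HiddenDim]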
--------------------------------------------------------------------------------
/memoria/history_manager.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import pickle
3 | from collections import defaultdict
4 |
5 | from .types import EngramHistory, EngramInfo, EngramsInfo, Firing
6 |
7 |
8 | class HistoryManager:
9 | """Manages the history of engram summaries.
10 |
11 | Attributes:
12 | timestep: Current timestep.
13 | summaries: List of engram summaries.
14 | engram_creation_times: Dictionary of engram creation times.
15 | engram_deletion_times: Dictionary of engram deletion times.
16 | engram_durations: Dictionary of engram durations.
17 | engram_firing_times: Dictionary of engram firing times.
18 | engram_firings: Dictionary of engram firings.
19 | firings_per_time: List of firings per timestep.
20 | engram_fire_counts: Dictionary of engram fire counts.
21 | engram_ids: List of engram ids.
22 | alive_engram_ids: List of alive engram ids.
23 | deleted_engram_ids: List of deleted engram ids.
24 | """
25 |
26 | def __init__(self):
27 | self.summaries: list[EngramsInfo] = []
28 | self.engram_creation_times: dict[int, int] = {}
29 | self.engram_deletion_times: dict[int, int] = {}
30 | self.engram_durations: dict[int, int] = {}
31 | self.engram_firing_times: dict[int, list[int]] = defaultdict(list)
32 | self.engram_firings: dict[int, list[Firing]] = defaultdict(list)
33 | self.firings_per_time: list[list[Firing]] = []
34 |
35 | self.alive_engram_ids: list[int] = []
36 | self.deleted_engram_ids: list[int] = []
37 |
38 | def __len__(self) -> int:
39 | return len(self.summaries)
40 |
41 | def __getitem__(self, index: int) -> EngramsInfo:
42 | return self.summaries[index]
43 |
44 | def save(self, path: str) -> None:
45 | """Save the history manager to a compressed data file.
46 |
47 | Args:
48 | path: Path to save the history manager.
49 | """
50 | with gzip.open(path, "wb") as f:
51 | pickle.dump(self, f)
52 |
53 | @classmethod
54 | def load(cls, path: str) -> "HistoryManager":
55 | """Load the history manager from a compressed data file.
56 |
57 | Args:
58 | path: Path to load the history manager.
59 | Returns:
60 | HistoryManager: Loaded history manager.
61 | """
62 | with gzip.open(path, "rb") as f:
63 | return pickle.load(f)
64 |
65 | @property
66 | def timestep(self) -> int:
67 | """Get the current timestep."""
68 | return len(self)
69 |
70 | @property
71 | def engram_ids(self) -> list[int]:
72 | """Get the list of engram IDs."""
73 | return list(self.engram_creation_times.keys())
74 |
75 | @property
76 | def engram_fire_counts(self) -> dict[int, int]:
77 | """Get the fire counts of the engrams."""
78 | return {engram_id: len(firings) for engram_id, firings in self.engram_firings.items()}
79 |
80 | @property
81 | def engram_lastest_alive_timestep(self) -> dict[int, int]:
82 | """Get the latest alive timestep of the engrams."""
83 | return {
84 | engram_id: creation_time + self.engram_durations[engram_id] - 1
85 | for engram_id, creation_time in self.engram_creation_times.items()
86 | }
87 |
88 | @property
89 | def latest_engram_infos(self) -> dict[int, EngramInfo]:
90 | """Get the latest engram information before dying."""
91 | last_timestep = self.engram_lastest_alive_timestep
92 | return {engram_id: self.summaries[last_timestep[engram_id]].engrams[engram_id] for engram_id in self.engram_ids}
93 |
94 | def add_summary(self, summary: EngramsInfo) -> None:
95 | firings = []
96 | for engram_id, engram in summary.engrams.items():
97 | if engram_id not in self.alive_engram_ids:
98 | self.engram_creation_times[engram_id] = self.timestep
99 | self.alive_engram_ids.append(engram_id)
100 | elif (
101 | engram_id in self.summaries[-1].engrams
102 | and engram.fire_count > self.summaries[-1].engrams[engram_id].fire_count
103 | ):
104 | self.engram_firing_times[engram_id].append(self.timestep)
105 | firing = Firing(
106 | timestep=self.timestep,
107 | engram_id=engram_id,
108 | lifespan_gain=engram.lifespan - self.summaries[-1].engrams[engram_id].lifespan + 1.0,
109 | )
110 | self.engram_firings[engram_id].append(firing)
111 | firings.append(firing)
112 | self.firings_per_time.append(firings)
113 |
114 | for engram_id in self.alive_engram_ids:
115 | if engram_id not in summary.engrams:
116 | self.engram_deletion_times[engram_id] = self.timestep
117 | self.deleted_engram_ids.append(engram_id)
118 | self.alive_engram_ids = list(summary.engrams.keys())
119 | for engram_id in self.alive_engram_ids:
120 | self.engram_durations[engram_id] = self.timestep - self.engram_creation_times[engram_id] + 1
121 |
122 | self.summaries.append(summary)
123 |
124 | def inspect(self, engram_id: int) -> EngramHistory:
125 | """Inspect the history of an engram.
126 |
127 | Args:
128 | engram_id: Engram ID to inspect.
129 | Returns:
130 | EngramHistory: Historical information of the engram.
131 | """
132 | creation_time = self.engram_creation_times[engram_id]
133 | deletion_time = self.engram_deletion_times.get(engram_id)
134 | duration = self.engram_durations.get(engram_id)
135 | firing_times = self.engram_firing_times.get(engram_id, [])
136 | firings = self.engram_firings.get(engram_id, [])
137 | related_summaries = self.summaries[creation_time:deletion_time]
138 |
139 | return EngramHistory(
140 | id=engram_id,
141 | creation_timestep=creation_time,
142 | deletion_timestep=deletion_time,
143 | duration=duration,
144 | firing_times=firing_times,
145 | firings=firings,
146 | summaries=related_summaries,
147 | )
148 |
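A minimal sketch of the intended workflow, using a hand-built single-engram summary (in practice the `EngramsInfo` objects come from Memoria's tracking, and the file name is an arbitrary choice):

from memoria import HistoryManager
from memoria.types import EngramInfo, EngramsInfo

history = HistoryManager()
summary = EngramsInfo(
    engrams={1: EngramInfo(id=1, type="WORKING", lifespan=4, age=0, fire_count=0, outgoings=(), incoming=())},
    edges={}, working=(1,), shortterm=(), longterm=(),
)
history.add_summary(summary)                      # call once per timestep

history.save("history.pkl.gz")                    # gzip-compressed pickle
restored = HistoryManager.load("history.pkl.gz")
record = restored.inspect(1)                      # creation/deletion/firing history of engram 1
print(record.creation_timestep, record.firing_times)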
--------------------------------------------------------------------------------
/memoria/types.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Literal, Optional, Tuple
3 |
4 |
5 | @dataclass(slots=True, frozen=True)
6 | class EngramConnection:
7 | """Data structure for engram connections."""
8 |
9 | #: Source engram ID.
10 | source_id: int
11 | #: Target engram ID.
12 | target_id: int
13 | #: Connection weight (Probability).
14 | weight: float
15 | #: Cofire count.
16 | cofire_count: int
17 |
18 |
19 | @dataclass(slots=True, frozen=True)
20 | class EngramInfo:
21 | """Data structure for engram information."""
22 |
23 | #: Engram ID.
24 | id: int
25 | #: Engram type.
26 | type: Literal["WORKING", "SHORTTERM", "LONGTERM"]
27 | #: Lifetime of the engram.
28 | lifespan: int
29 | #: The age of the engram.
30 | age: Optional[int]
31 | #: Fire count of the engram.
32 | fire_count: int
33 | #: The outgoing edges of the engram.
34 | outgoings: Tuple[EngramConnection, ...]
35 | #: The incoming edges of the engram.
36 | incoming: Tuple[EngramConnection, ...]
37 |
38 | @property
39 | def cofire_counts(self) -> dict[int, int]:
40 | """Get the cofire counts of the engram."""
41 | return {edge.target_id: edge.cofire_count for edge in self.outgoings}
42 |
43 |
44 | @dataclass(slots=True, frozen=True)
45 | class EngramsInfo:
46 | """Data structure for engrams information."""
47 |
48 | #: Engram ID to EngramInfo mapping.
49 | engrams: dict[int, EngramInfo]
50 | #: All engram connections mapping from source and target engram IDs.
51 | edges: dict[Tuple[int, int], EngramConnection]
52 | #: Working memory engram IDs.
53 | working: Tuple[int, ...]
54 | #: Short-term memory engram IDs.
55 | shortterm: Tuple[int, ...]
56 | #: Long-term memory engram IDs.
57 | longterm: Tuple[int, ...]
58 |
59 |
60 | @dataclass(slots=True, frozen=True)
61 | class Firing:
62 | """Data structure for firing information."""
63 |
64 | #: Firing timestep.
65 | timestep: int
66 | #: Engram ID.
67 | engram_id: int
68 | #: Lifespan Gain.
69 | lifespan_gain: float
70 |
71 |
72 | @dataclass(slots=True, frozen=True)
73 | class EngramHistory:
74 | """Historical information of an engram."""
75 |
76 | #: Engram ID.
77 | id: int
78 | #: Creation time of the engram.
79 | creation_timestep: int
80 | #: Deletion time of the engram.
81 | deletion_timestep: Optional[int]
82 | #: Duration of the engram.
83 | duration: Optional[int]
84 | #: Firing times of the engram.
85 | firing_times: list[int]
86 | #: Firing information of the engram.
87 | firings: list[Firing]
88 | #: Summaries of the engram.
89 | summaries: list[EngramsInfo]
90 |
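These dataclasses form a small engram graph; for example (values are illustrative):

from memoria.types import EngramConnection, EngramInfo

edge = EngramConnection(source_id=1, target_id=2, weight=0.8, cofire_count=3)
engram = EngramInfo(id=1, type="LONGTERM", lifespan=10, age=2, fire_count=5,
                    outgoings=(edge,), incoming=())
print(engram.cofire_counts)   # {2: 3}: target engram id -> cofire count over outgoing edges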
--------------------------------------------------------------------------------
/memoria/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def super_unique(t: torch.Tensor, dim: int) -> torch.Tensor:
5 | if t.numel() == 0:
6 | return t
7 |
8 | min_value = t.min()
9 | t = t - min_value
10 |
11 | max_value = t.max()
12 | new_shape = list(t.shape)
13 | new_shape[dim] = max_value + 1
14 | unique_t_mask = torch.zeros(new_shape, dtype=torch.bool, device=t.device)
15 | unique_t_mask.scatter_(dim, t.long(), True)
16 |
17 | k = min(t.size(dim), unique_t_mask.sum(dim).max().item())
18 | validity, unique_t = unique_t_mask.int().topk(k, dim=dim)
19 | unique_t += min_value
20 | unique_t.masked_fill_(~validity.bool(), -1)
21 | return unique_t
22 |
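`super_unique` has no docstring; in effect it keeps the unique values of each slice along `dim` and fills the remaining slots with -1. A small illustration (the exact ordering of the kept values can vary):

import torch
from memoria.utils import super_unique

x = torch.tensor([[3, 1, 3, 2],
                  [5, 5, 5, 5]], dtype=torch.int32)
print(super_unique(x, dim=1))
# Row-wise unique values; the output width is the largest number of uniques in any row (here 3),
# with missing slots filled by -1, e.g. [[1, 2, 3], [5, -1, -1]].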
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.black]
2 | line-length = 120
3 | include = '\.pyi?$'
4 |
5 | [tool.isort]
6 | multi_line_output = 3
7 | line_length = 120
8 |
9 | [tool.pyright]
10 | reportUnknownVariableType = false
11 | reportUnknownMemberType = false
12 | reportUnusedImport = true
13 | reportUnusedVariable = true
14 | reportUnusedClass = true
15 | reportUnusedFunction = true
16 | reportImportCycles = true
17 | reportTypeshedErrors = true
18 | reportOptionalMemberAccess = true
19 | reportUntypedBaseClass = true
20 | reportPrivateUsage = true
21 | reportConstantRedefinition = true
22 | reportInvalidStringEscapeSequence = true
23 | reportUnnecessaryIsInstance = true
24 | reportUnnecessaryCast = true
25 | reportAssertAlwaysTrue = true
26 | reportSelfClsParameterName = true
27 |
--------------------------------------------------------------------------------
/requirements-dev.txt:
--------------------------------------------------------------------------------
1 | black
2 | isort
3 | pytest
4 | pytest-cov
5 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | numpy<2
3 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import find_packages, setup
2 |
3 | with open("README.md", "r") as f:
4 | long_description = f.read()
5 |
6 | setup(
7 | name="memoria-pytorch",
8 | version="1.1.0",
9 | description="Memoria is a human-inspired memory architecture for neural networks.",
10 | long_description=long_description,
11 | long_description_content_type="text/markdown",
12 | python_requires=">=3.10",
13 | install_requires=["torch"],
14 | url="https://github.com/cosmoquester/memoria.git",
15 | author="Park Sangjun",
16 | keywords=["memoria", "hebbian", "memory", "transformer"],
17 | classifiers=[
18 | "Programming Language :: Python :: 3",
19 | "License :: OSI Approved :: MIT License",
20 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
21 | ],
22 | packages=find_packages(exclude=["tests", "experiment"]),
23 | )
24 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cosmoquester/memoria/e4ba6e2e13410e01fba896dd7800e66520e9d716/tests/__init__.py
--------------------------------------------------------------------------------
/tests/test_abstractor.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from memoria.abstractor import Abstractor
4 |
5 |
6 | def test_abstractor():
7 | abstractor = Abstractor(num_memories=3, hidden_dim=4, feedforward_dim=5)
8 | hidden_states = torch.randn(2, 3, 4)
9 | output = abstractor(hidden_states)
10 | assert output.shape == (2, 3, 4)
11 |
--------------------------------------------------------------------------------
/tests/test_history_manager.py:
--------------------------------------------------------------------------------
1 | from memoria.history_manager import HistoryManager
2 | from memoria.types import EngramInfo, EngramsInfo, Firing
3 |
4 |
5 | def test_history_manager():
6 | history_manager = HistoryManager()
7 | assert len(history_manager) == 0
8 | assert history_manager.timestep == 0
9 | assert history_manager.engram_ids == []
10 | assert history_manager.alive_engram_ids == []
11 | assert history_manager.deleted_engram_ids == []
12 |
13 | engrams = {
14 | 1: EngramInfo(id=1, type="WORKING", lifespan=4, age=0, fire_count=0, outgoings=[], incoming=[]),
15 | 2: EngramInfo(id=2, type="SHORTTERM", lifespan=5, age=0, fire_count=0, outgoings=[], incoming=[]),
16 | 3: EngramInfo(id=3, type="LONGTERM", lifespan=6, age=0, fire_count=0, outgoings=[], incoming=[]),
17 | }
18 | history_manager.add_summary(EngramsInfo(engrams=engrams, edges={}, working=[1], shortterm=[2], longterm=[3]))
19 | assert len(history_manager) == 1
20 | assert history_manager.timestep == 1
21 | assert history_manager.engram_ids == [1, 2, 3]
22 | assert history_manager.alive_engram_ids == [1, 2, 3]
23 | assert history_manager.deleted_engram_ids == []
24 | assert history_manager.engram_fire_counts == {}
25 | assert history_manager.engram_lastest_alive_timestep == {1: 0, 2: 0, 3: 0}
26 | assert history_manager.latest_engram_infos == {1: engrams[1], 2: engrams[2], 3: engrams[3]}
27 |
28 | engrams2 = {
29 | 2: EngramInfo(id=2, type="SHORTTERM", lifespan=4, age=0, fire_count=0, outgoings=[], incoming=[]),
30 | 3: EngramInfo(id=3, type="LONGTERM", lifespan=8, age=0, fire_count=1, outgoings=[], incoming=[]),
31 | 4: EngramInfo(id=4, type="LONGTERM", lifespan=6, age=0, fire_count=0, outgoings=[], incoming=[]),
32 | }
33 | history_manager.add_summary(EngramsInfo(engrams=engrams2, edges={}, working=[], shortterm=[2], longterm=[3, 4]))
34 | assert len(history_manager) == 2
35 | assert history_manager.timestep == 2
36 | assert history_manager.engram_ids == [1, 2, 3, 4]
37 | assert history_manager.alive_engram_ids == [2, 3, 4]
38 | assert history_manager.deleted_engram_ids == [1]
39 | assert history_manager.engram_firing_times == {3: [1]}
40 | assert history_manager.engram_firings == {3: [Firing(timestep=1, engram_id=3, lifespan_gain=3.0)]}
41 | assert history_manager.engram_fire_counts == {3: 1}
42 | assert history_manager.engram_lastest_alive_timestep == {1: 0, 2: 1, 3: 1, 4: 1}
43 | assert history_manager.latest_engram_infos == {1: engrams[1], 2: engrams2[2], 3: engrams2[3], 4: engrams2[4]}
44 |
45 | engram_history = history_manager.inspect(1)
46 | assert engram_history.id == 1
47 | assert engram_history.creation_timestep == 0
48 | assert engram_history.deletion_timestep == 1
49 | assert engram_history.duration == 1
50 | assert engram_history.firing_times == []
51 | assert engram_history.firings == []
52 |
--------------------------------------------------------------------------------
/tests/test_memoria.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from memoria.engram import Engrams, EngramType
4 | from memoria.memoria import Memoria
5 |
6 |
7 | def test_add_working_memory():
8 | memoria = Memoria(
9 | num_reminded_stm=10,
10 | ltm_search_depth=3,
11 | stm_capacity=100,
12 | initial_lifespan=100,
13 | num_final_ltms=100,
14 | )
15 | memoria.add_working_memory(torch.randn(3, 10, 32))
16 | assert len(memoria.engrams) == 30
17 |
18 |
19 | def test_calculate_wm_stm_weight():
20 | memoria = Memoria(
21 | num_reminded_stm=10,
22 | ltm_search_depth=3,
23 | stm_capacity=100,
24 | initial_lifespan=100,
25 | num_final_ltms=100,
26 | )
27 | memoria.add_working_memory(torch.randn(3, 10, 32))
28 |
29 | wm = Engrams(torch.randn(3, 10, 32))
30 | stm = Engrams(torch.randn(3, 20, 32), engrams_types=EngramType.SHORTTERM)
31 | weight = memoria._calculate_memory_weight(wm, stm)
32 | assert weight.shape == torch.Size([3, 10, 20])
33 |
34 |
35 | def test_remind_shortterm_memory():
36 | memoria = Memoria(
37 | num_reminded_stm=2,
38 | ltm_search_depth=3,
39 | stm_capacity=100,
40 | initial_lifespan=100,
41 | num_final_ltms=100,
42 | )
43 |
44 | weight = torch.tensor([[[0.51, 0.2, 0.2, 0.8]]])
45 | shortterm_memory_indices = torch.tensor([[1, 2, 3, 4]])
46 | reminded = memoria._remind_shortterm_memory(weight, shortterm_memory_indices)
47 | assert (reminded == torch.tensor([[1, -1, -1, 4]])).all()
48 |
49 |
50 | def test_find_initial_ltm():
51 | num_stm = 5
52 | num_ltm = 4
53 | memoria = Memoria(
54 | num_reminded_stm=10,
55 | ltm_search_depth=3,
56 | stm_capacity=100,
57 | initial_lifespan=100,
58 | num_final_ltms=100,
59 | )
60 |
61 | stm = Engrams(torch.randn(1, num_stm, 32), engrams_types=EngramType.SHORTTERM)
62 | ltm = Engrams(torch.randn(1, num_ltm, 32), engrams_types=EngramType.LONGTERM)
63 | engrams = stm + ltm
64 | memoria.engrams = engrams
65 | memoria.engrams.induce_counts[:, :num_stm, num_stm:] = torch.tensor(
66 | [
67 | [
68 | [1, 1, 2, 1],
69 | [1, 1, 1, 2],
70 | [1, 10, 1, 1],
71 | [1, 1, 2, 1],
72 | [1, 5, 1, 1],
73 | ]
74 | ]
75 | )
76 | memoria.engrams.induce_counts[:, :num_stm, :num_stm] = 999
77 | nearest_stm_indices = torch.tensor([[0, 2, 3]])
78 |
79 | initial_ltm_indices = memoria._find_initial_longterm_memory(nearest_stm_indices)
80 | assert (initial_ltm_indices == torch.tensor([[6, 7]])).all()
81 |
82 |
83 | def test_search_longterm_memories_with_initials():
84 | num_stm = 5
85 | num_ltm = 4
86 | ltm_search_depth = 3
87 | memoria = Memoria(
88 | num_reminded_stm=10,
89 | ltm_search_depth=ltm_search_depth,
90 | stm_capacity=100,
91 | initial_lifespan=100,
92 | num_final_ltms=100,
93 | )
94 |
95 | stm = Engrams(torch.randn(1, num_stm, 32), engrams_types=EngramType.SHORTTERM)
96 | ltm = Engrams(torch.randn(1, num_ltm, 32), engrams_types=EngramType.LONGTERM)
97 | memoria.engrams = stm + ltm
98 |
99 | initial_ltm_indices = torch.tensor([[5, 7]])
100 | searched_ltm_indices = memoria._search_longterm_memories_with_initials(initial_ltm_indices, ltm)
101 |
102 | assert (searched_ltm_indices == torch.tensor([[1, 3, 2, -1, 0]])).all()
103 |
104 |
105 | def test_memorize_working_memory_as_shortterm_memory():
106 | batch_size = 3
107 | num_wm = 5
108 | num_stm = 4
109 | num_ltm = 2
110 | ltm_search_depth = 3
111 | memoria = Memoria(
112 | num_reminded_stm=10,
113 | ltm_search_depth=ltm_search_depth,
114 | stm_capacity=100,
115 | initial_lifespan=100,
116 | num_final_ltms=100,
117 | )
118 |
119 | wm = Engrams(torch.randn(batch_size, num_wm, 32), engrams_types=EngramType.WORKING)
120 | stm = Engrams(torch.randn(batch_size, num_stm, 32), engrams_types=EngramType.SHORTTERM)
121 | ltm = Engrams(torch.randn(batch_size, num_ltm, 32), engrams_types=EngramType.LONGTERM)
122 | memoria.engrams = wm + stm + ltm
123 |
124 | memoria._memorize_working_memory_as_shortterm_memory()
125 |
126 | assert memoria.engrams.get_shortterm_memory()[0].data.shape == torch.Size([batch_size, num_wm + num_stm, 32])
127 |
128 |
129 | def test_memorize_shortterm_memory_as_longterm_memory_or_drop():
130 | batch_size = 1
131 | num_stm = 5
132 | num_ltm = 3
133 | ltm_search_depth = 3
134 | memoria = Memoria(
135 | num_reminded_stm=10,
136 | ltm_search_depth=ltm_search_depth,
137 | stm_capacity=2,
138 | initial_lifespan=100,
139 | num_final_ltms=100,
140 | )
141 |
142 | fire_count = torch.tensor([[0, 1, 2, 3, 0]], dtype=torch.int32)
143 | stm = Engrams(torch.randn(batch_size, num_stm, 32), engrams_types=EngramType.SHORTTERM)
144 | stm.fire_count = fire_count
145 | ltm = Engrams(torch.randn(batch_size, num_ltm, 32), engrams_types=EngramType.LONGTERM)
146 | memoria.engrams = stm + ltm
147 |
148 | memoria._memorize_shortterm_memory_as_longterm_memory()
149 |
150 | assert len(memoria.engrams) == batch_size * (num_stm + num_ltm)
151 |
152 |
153 | def test_remind():
154 | num_reminded_stm = 2
155 | ltm_search_depth = 3
156 | stm_capacity = 100
157 | memoria = Memoria(
158 | num_reminded_stm=num_reminded_stm,
159 | ltm_search_depth=ltm_search_depth,
160 | stm_capacity=stm_capacity,
161 | initial_lifespan=100,
162 | num_final_ltms=100,
163 | )
164 |
165 | batch_size = 3
166 | memory_length = 50
167 | hidden_dim = 32
168 | working_memory = torch.randn(batch_size, memory_length, hidden_dim)
169 | memoria.add_working_memory(working_memory)
170 | outputs, indices = memoria.remind()
171 | memoria.adjust_lifespan_and_memories(indices, torch.ones_like(indices, dtype=float))
172 | assert len(memoria.engrams.get_shortterm_memory()[0]) == batch_size * memory_length
173 | assert outputs.size(1) == 0
174 |
175 | working_memory = torch.randn(batch_size, memory_length, hidden_dim)
176 | memoria.add_working_memory(working_memory)
177 | outputs, indices = memoria.remind()
178 | memoria.adjust_lifespan_and_memories(indices, torch.ones_like(indices, dtype=float))
179 | assert len(memoria.engrams.get_shortterm_memory()[0]) == batch_size * memory_length * 2
180 | assert outputs.size(1) > 0
181 |
182 | working_memory = torch.randn(batch_size, memory_length, hidden_dim)
183 | memoria.add_working_memory(working_memory)
184 | outputs, indices = memoria.remind()
185 | memoria.adjust_lifespan_and_memories(indices, torch.ones_like(indices, dtype=float))
186 | assert len(memoria.engrams.get_shortterm_memory()[0]) == batch_size * memory_length * 2
187 | assert outputs.size(1) > 0
188 |
189 |
190 | def test_reset_memory():
191 | memoria = Memoria(
192 | num_reminded_stm=10,
193 | ltm_search_depth=3,
194 | stm_capacity=100,
195 | initial_lifespan=100,
196 | num_final_ltms=100,
197 | )
198 | memoria.add_working_memory(torch.randn(3, 10, 32))
199 | memoria.reset_memory()
200 | assert memoria.engrams == Engrams.empty()
201 |
--------------------------------------------------------------------------------
/tests/test_sparse_tensor.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from memoria.sparse_tensor import SparseTensor
4 |
5 |
6 | def test_from_tensor():
7 | tensor = torch.tensor([[1, 0, 0], [0, 2, 0], [0, 0, 3]], dtype=torch.int32)
8 | sparse_tensor = SparseTensor.from_tensor(tensor)
9 | assert sparse_tensor.indices.tolist() == [[0, 0], [1, 1], [2, 2]]
10 | assert sparse_tensor.values.tolist() == [1, 2, 3]
11 |
12 |
13 | def test_get_item():
14 | tensor = torch.tensor([[1, 0, 0], [0, 2, 0], [0, 0, 3]], dtype=torch.int32)
15 | sparse_tensor = SparseTensor.from_tensor(tensor)
16 |
17 | assert sparse_tensor.indices.tolist() == [[0, 0], [1, 1], [2, 2]]
18 | assert sparse_tensor.values.tolist() == [1, 2, 3]
19 |
20 | selected = sparse_tensor[0, 0]
21 | assert isinstance(selected, torch.Tensor)
22 | assert selected.item() == 1
23 |
24 | selected = sparse_tensor[0, 2]
25 | assert isinstance(selected, torch.Tensor)
26 | assert selected.item() == 0
27 |
28 | selected = sparse_tensor[1]
29 | assert selected.shape == (3,)
30 | assert selected.indices.tolist() == [[1]]
31 | assert selected.values.tolist() == [2]
32 | assert selected.tolist() == [0, 2, 0]
33 |
34 | selected = sparse_tensor[:, 2]
35 | assert selected.shape == (3,)
36 | assert selected.indices.tolist() == [[2]]
37 | assert selected.values.tolist() == [3]
38 | assert selected.tolist() == [0, 0, 3]
39 |
40 | selected = sparse_tensor[torch.tensor([0, 2])]
41 | assert selected.shape == (2, 3)
42 | assert selected.indices.tolist() == [[0, 0], [1, 2]]
43 | assert selected.values.tolist() == [1, 3]
44 | assert selected.tolist() == [[1, 0, 0], [0, 0, 3]]
45 |
46 | selected = sparse_tensor[torch.tensor([0, 2]), 2]
47 | assert selected.shape == (2,)
48 | assert selected.indices.tolist() == [[1]]
49 | assert selected.values.tolist() == [3]
50 | assert selected.tolist() == [0, 3]
51 |
52 | selected = sparse_tensor[torch.tensor([[[0, 2]]])]
53 | assert selected.shape == (1, 1, 2, 3)
54 | assert selected.indices.tolist() == [[0, 0, 0, 0], [0, 0, 1, 2]]
55 | assert selected.values.tolist() == [1, 3]
56 | assert selected.tolist() == [[[[1, 0, 0], [0, 0, 3]]]]
57 |
58 | selected = sparse_tensor[torch.tensor([[0, 1], [2, 0]]), torch.tensor([[1, 2], [2, 0]])]
59 | assert selected.shape == (2, 2)
60 | assert selected.indices.tolist() == [[1, 0], [1, 1]]
61 | assert selected.values.tolist() == [3, 1]
62 | assert selected.tolist() == [[0, 0], [3, 1]]
63 |
64 | selected = sparse_tensor[0:2]
65 | assert selected.shape == (2, 3)
66 | assert selected.indices.tolist() == [[0, 0], [1, 1]]
67 | assert selected.values.tolist() == [1, 2]
68 | assert selected.tolist() == [[1, 0, 0], [0, 2, 0]]
69 |
70 | selected = sparse_tensor[0:1, 1:3]
71 | assert selected.shape == (1, 2)
72 | assert selected.indices.tolist() == []
73 | assert selected.values.tolist() == []
74 | assert selected.tolist() == [[0, 0]]
75 |
76 | selected = sparse_tensor[0:2, 1:3]
77 | assert selected.shape == (2, 2)
78 | assert selected.indices.tolist() == [[1, 0]]
79 | assert selected.values.tolist() == [2]
80 | assert selected.tolist() == [[0, 0], [2, 0]]
81 |
82 | selected = sparse_tensor[torch.tensor([0, 2]), 1:3]
83 | assert selected.shape == (2, 2)
84 | assert selected.indices.tolist() == [[1, 1]]
85 | assert selected.values.tolist() == [3]
86 | assert selected.tolist() == [[0, 0], [0, 3]]
87 |
88 |
89 | def test_set_item():
90 | tensor = torch.tensor([[1, 0, 0], [0, 2, 0], [0, 0, 3]], dtype=torch.int32)
91 | sparse_tensor = SparseTensor.from_tensor(tensor)
92 |
93 | sparse_tensor[0, 0] = torch.tensor(90)
94 | assert sparse_tensor.indices.tolist() == [[1, 1], [2, 2], [0, 0]]
95 | assert sparse_tensor.values.tolist() == [2, 3, 90]
96 | assert sparse_tensor.to_dense().tolist() == [[90, 0, 0], [0, 2, 0], [0, 0, 3]]
97 |
98 | sparse_tensor[0, 0] += 10
99 | assert sparse_tensor.indices.tolist() == [[1, 1], [2, 2], [0, 0]]
100 | assert sparse_tensor.values.tolist() == [2, 3, 100]
101 | assert sparse_tensor.to_dense().tolist() == [[100, 0, 0], [0, 2, 0], [0, 0, 3]]
102 |
103 | sparse_tensor = SparseTensor.from_tensor(tensor)
104 | sparse_tensor[0, 0] = 10
105 | assert sparse_tensor.indices.tolist() == [[1, 1], [2, 2], [0, 0]]
106 | assert sparse_tensor.values.tolist() == [2, 3, 10]
107 | assert sparse_tensor.to_dense().tolist() == [[10, 0, 0], [0, 2, 0], [0, 0, 3]]
108 |
109 | sparse_tensor[1] = 20
110 | assert sparse_tensor.indices.tolist() == [[2, 2], [0, 0], [1, 0], [1, 1], [1, 2]]
111 | assert sparse_tensor.values.tolist() == [3, 10, 20, 20, 20]
112 | assert sparse_tensor.to_dense().tolist() == [[10, 0, 0], [20, 20, 20], [0, 0, 3]]
113 |
114 | sparse_tensor[:, 2] = 30
115 | assert sparse_tensor.indices.tolist() == [[0, 0], [1, 0], [1, 1], [0, 2], [1, 2], [2, 2]]
116 | assert sparse_tensor.values.tolist() == [10, 20, 20, 30, 30, 30]
117 | assert sparse_tensor.to_dense().tolist() == [[10, 0, 30], [20, 20, 30], [0, 0, 30]]
118 |
119 | sparse_tensor = SparseTensor.from_tensor(tensor)
120 | sparse_tensor[torch.tensor([0, 2])] = 40
121 | assert sparse_tensor.indices.tolist() == [[1, 1], [0, 0], [0, 1], [0, 2], [2, 0], [2, 1], [2, 2]]
122 | assert sparse_tensor.values.tolist() == [2, 40, 40, 40, 40, 40, 40]
123 | assert sparse_tensor.to_dense().tolist() == [[40, 40, 40], [0, 2, 0], [40, 40, 40]]
124 |
125 | sparse_tensor = SparseTensor.from_tensor(tensor)
126 | sparse_tensor[torch.tensor([0, 2]), 2] = 50
127 | assert sparse_tensor.indices.tolist() == [[0, 0], [1, 1], [0, 2], [2, 2]]
128 | assert sparse_tensor.values.tolist() == [1, 2, 50, 50]
129 | assert sparse_tensor.to_dense().tolist() == [[1, 0, 50], [0, 2, 0], [0, 0, 50]]
130 |
131 | sparse_tensor = SparseTensor.from_tensor(tensor)
132 | sparse_tensor[torch.tensor([[[0, 2]]])] = 60
133 | assert sparse_tensor.indices.tolist() == [[1, 1], [0, 0], [0, 1], [0, 2], [2, 0], [2, 1], [2, 2]]
134 | assert sparse_tensor.values.tolist() == [2, 60, 60, 60, 60, 60, 60]
135 | assert sparse_tensor.to_dense().tolist() == [[60, 60, 60], [0, 2, 0], [60, 60, 60]]
136 |
137 | sparse_tensor = SparseTensor.from_tensor(tensor)
138 | sparse_tensor[torch.tensor([[0, 1], [2, 0]]), torch.tensor([[1, 2], [2, 0]])] = 70
139 | assert sparse_tensor.indices.tolist() == [[1, 1], [0, 1], [1, 2], [2, 2], [0, 0]]
140 | assert sparse_tensor.values.tolist() == [2, 70, 70, 70, 70]
141 | assert sparse_tensor.to_dense().tolist() == [[70, 70, 0], [0, 2, 70], [0, 0, 70]]
142 |
143 | sparse_tensor = SparseTensor.from_tensor(tensor)
144 | sparse_tensor[torch.tensor([[0, 1], [2, 0]]), torch.tensor([[1, 2], [2, 0]])] = torch.tensor([[70, 80], [90, 100]])
145 | assert sparse_tensor.indices.tolist() == [[1, 1], [0, 1], [1, 2], [2, 2], [0, 0]]
146 | assert sparse_tensor.values.tolist() == [2, 70, 80, 90, 100]
147 | assert sparse_tensor.to_dense().tolist() == [[100, 70, 0], [0, 2, 80], [0, 0, 90]]
148 |
149 | sparse_tensor = SparseTensor.from_tensor(tensor)
150 | sparse_tensor[0:2] = 80
151 | assert sparse_tensor.indices.tolist() == [[2, 2], [0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]]
152 | assert sparse_tensor.values.tolist() == [3, 80, 80, 80, 80, 80, 80]
153 | assert sparse_tensor.to_dense().tolist() == [[80, 80, 80], [80, 80, 80], [0, 0, 3]]
154 |
155 |
156 | def test_diagonal():
157 | tensor = torch.randn(2, 5, 3, 5)
158 | sparse_tensor = SparseTensor.from_tensor(tensor)
159 |
160 | assert (tensor.diagonal(dim1=1, dim2=3) == sparse_tensor.diagonal(dim1=1, dim2=3).to_dense()).all()
161 |
162 |
163 | def test_equals():
164 | tensor = torch.randn(2, 5, 3, 5)
165 | sparse_tensor = SparseTensor.from_tensor(tensor)
166 |
167 | assert tensor == sparse_tensor
168 | assert (tensor == sparse_tensor.to_dense()).all()
169 | assert sparse_tensor != SparseTensor.from_tensor(torch.randn(2, 5, 3, 5))
170 |
171 |
172 | def test_add():
173 | tensor = torch.randint(0, 5, [2, 5, 3, 5])
174 | tensor2 = torch.randint(0, 5, [2, 5, 3, 5])
175 | sparse_tensor = SparseTensor.from_tensor(tensor)
176 | sparse_tensor2 = SparseTensor.from_tensor(tensor2)
177 |
178 | assert (tensor + 1 == (sparse_tensor + 1).to_dense()).all()
179 | assert (tensor + tensor == (sparse_tensor + sparse_tensor).to_dense()).all()
180 | assert (tensor + tensor2 == (sparse_tensor + sparse_tensor2).to_dense()).all()
181 |
182 |
183 | def test_unsqueeze():
184 | tensor = torch.randn(2, 5, 3, 5)
185 | sparse_tensor = SparseTensor.from_tensor(tensor)
186 |
187 | assert sparse_tensor.unsqueeze(0).shape == (1, 2, 5, 3, 5)
188 | assert sparse_tensor.unsqueeze(1).shape == (2, 1, 5, 3, 5)
189 | assert sparse_tensor.unsqueeze(2).shape == (2, 5, 1, 3, 5)
190 | assert sparse_tensor.unsqueeze(3).shape == (2, 5, 3, 1, 5)
191 | assert sparse_tensor.unsqueeze(4).shape == (2, 5, 3, 5, 1)
192 |
193 |
194 | def test_to():
195 | tensor = torch.randn(2, 5, 3, 5)
196 | sparse_tensor = SparseTensor.from_tensor(tensor)
197 |
198 | assert sparse_tensor.to(torch.device("cpu")) == sparse_tensor
199 | assert (sparse_tensor.to(torch.device("cpu")).to_dense() == tensor).all()
200 |
--------------------------------------------------------------------------------
/tests/test_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from memoria.utils import super_unique
4 |
5 |
6 | def test_super_unique():
7 | x = torch.tensor(
8 | [
9 | [[0, 1, 0, 2, 1], [1, 4, 1, 2, 1], [3, 4, 1, 4, 3], [2, 0, 0, 3, 2]],
10 | [[0, 1, 4, 2, 4], [1, 4, 3, 2, 0], [2, 3, 4, 2, 0], [4, 4, 4, 3, 0]],
11 | [[2, 3, 2, 3, 3], [2, 0, 1, 3, 0], [1, 3, 2, 0, 0], [1, 2, 3, 0, 0]],
12 | ],
13 | dtype=torch.int32,
14 | )
15 | assert (
16 | super_unique(x, dim=1)
17 | == torch.tensor(
18 | [
19 | [[1, 1, 1, 2, 1], [3, 4, 0, 4, 3], [2, 0, -1, 3, 2], [0, -1, -1, -1, -1]],
20 | [[2, 1, 4, 2, 4], [4, 4, 3, 3, 0], [0, 3, -1, -1, -1], [1, -1, -1, -1, -1]],
21 | [[1, 2, 1, 0, 0], [2, 3, 3, 3, 3], [-1, 0, 2, -1, -1], [-1, -1, -1, -1, -1]],
22 | ]
23 | )
24 | ).all()
25 |
26 | x = torch.tensor([[2, 3, 4, 3, 0], [3, 1, 3, 1, 0], [4, 3, 2, 2, 4], [2, 2, 2, 0, 3]], dtype=torch.int32)
27 | assert (
28 | super_unique(x, dim=0) == torch.tensor([[2, 1, 2, 1, 0], [4, 3, 4, 3, 3], [3, 2, 3, 2, 4], [-1, -1, -1, 0, -1]])
29 | ).all()
30 |
--------------------------------------------------------------------------------