├── README.md ├── env1.yaml ├── env2.yaml ├── length_extrapolation.sh ├── network.png ├── requirements_lra.txt ├── requirements_tnn.txt ├── run_task.sh ├── script_alm.sh ├── script_blm.sh ├── script_im.sh ├── script_lra.py ├── train_alm.sh ├── train_blm.sh ├── train_im.sh └── train_lra.sh /README.md: -------------------------------------------------------------------------------- 1 | # TNN 2 | 3 | Official implementation of the Toeplitz Neural Network from our ICLR 2023 paper - [Toeplitz Neural Network for Sequence Modeling](https://openreview.net/forum?id=IxmWsm4xrua). This repo does not contain the model code itself; it provides only the scripts and instructions needed to reproduce the results of the paper. The contents are organized as follows: 4 | 5 | 6 | - [TNN](#tnn) 7 | - [Network Architecture](#network-architecture) 8 | - [Experiments](#experiments) 9 | - [Environments Preparation](#environments-preparation) 10 | - [Env1](#env1) 11 | - [Env2](#env2) 12 | - [Autoregressive language model](#autoregressive-language-model) 13 | - [1) Preprocess the data](#1-preprocess-the-data) 14 | - [2) Train the autoregressive language model](#2-train-the-autoregressive-language-model) 15 | - [3) Length extrapolation](#3-length-extrapolation) 16 | - [Bidirectional language model](#bidirectional-language-model) 17 | - [1) Preprocess the data](#1-preprocess-the-data-1) 18 | - [2) Train the bidirectional language model](#2-train-the-bidirectional-language-model) 19 | - [3) Finetuning](#3-finetuning) 20 | - [Image modeling](#image-modeling) 21 | - [1) Preparation](#1-preparation) 22 | - [2) Training](#2-training) 23 | - [LRA](#lra) 24 | - [1) Preparation](#1-preparation-1) 25 | - [2) Training](#2-training-1) 26 | - [Speed test](#speed-test) 27 | - [Standalone code](#standalone-code) 28 | - [Citation](#citation) 29 | - [WIP](#wip) 30 | 31 | 32 | 33 | ## Network Architecture 34 | 35 | The overall network architecture is as follows: 36 | 37 | ![](./network.png) 38 | 39 | 40 | 41 | ## Experiments 42 | 43 | ### Environments Preparation 44 | 45 | Our experiments use two conda environments: autoregressive language modeling, bidirectional language modeling, and image modeling use the environment described in the Env1 part, while LRA uses the environment described in the Env2 part. 
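For convenience, both yaml files name their environments (`tnn` for Env1 and `lra` for Env2), so after creating them as described below you can switch between experiments by activating the corresponding environment, e.g.:

```
# Env1: autoregressive/bidirectional language modeling and image modeling
conda activate tnn
# Env2: LRA
conda activate lra
```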
46 | 47 | #### Env1 48 | 49 | First, build the conda environment based on the yaml file: 50 | 51 | ``` 52 | conda env create --file env1.yaml 53 | ``` 54 | 55 | If you hit an error while installing torch, remove torch and torchvision from the yaml file, rerun the above command, and then run the commands below: 56 | 57 | ``` 58 | pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html 59 | pip install -r requirements_tnn.txt 60 | ``` 61 | 62 | Finally, install our version of fairseq: 63 | 64 | ``` 65 | git clone https://github.com/OpenNLPLab/fairseq-evo.git 66 | cd fairseq-evo 67 | pip install --editable ./ 68 | ``` 69 | 70 | 71 | 72 | #### Env2 73 | 74 | Build the conda environment based on the yaml file: 75 | 76 | ``` 77 | conda env create --file env2.yaml 78 | ``` 79 | If you run into trouble setting up the environment, you can create the conda environment first and then install the pip packages with the following commands: 80 | ``` 81 | pip install torch==1.10.0+cu111 torchvision==0.11.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html 82 | pip install -r requirements_lra.txt 83 | ``` 84 | 85 | 86 | ### Autoregressive language model 87 | 88 | #### 1) Preprocess the data 89 | 90 | First, download the [WikiText-103 dataset](https://www.salesforce.com/products/einstein/ai-research/the-wikitext-dependency-language-modeling-dataset/): 91 | 92 | ``` 93 | wget https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-raw-v1.zip 94 | unzip wikitext-103-raw-v1.zip 95 | ``` 96 | 97 | Next, encode it with the GPT-2 BPE: 98 | 99 | ``` 100 | mkdir -p gpt2_bpe 101 | wget -O gpt2_bpe/encoder.json https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json 102 | wget -O gpt2_bpe/vocab.bpe https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe 103 | for SPLIT in train valid test; do \ 104 | python -m examples.roberta.multiprocessing_bpe_encoder \ 105 | --encoder-json gpt2_bpe/encoder.json \ 106 | --vocab-bpe gpt2_bpe/vocab.bpe \ 107 | --inputs wikitext-103-raw/wiki.${SPLIT}.raw \ 108 | --outputs wikitext-103-raw/wiki.${SPLIT}.bpe \ 109 | --keep-empty \ 110 | --workers 60; \ 111 | done 112 | ``` 113 | 114 | Finally, preprocess/binarize the data using the GPT-2 fairseq dictionary: 115 | 116 | ``` 117 | wget -O gpt2_bpe/dict.txt https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt 118 | fairseq-preprocess \ 119 | --only-source \ 120 | --srcdict gpt2_bpe/dict.txt \ 121 | --trainpref wikitext-103-raw/wiki.train.bpe \ 122 | --validpref wikitext-103-raw/wiki.valid.bpe \ 123 | --testpref wikitext-103-raw/wiki.test.bpe \ 124 | --destdir data-bin/wikitext-103 \ 125 | --workers 60 126 | ``` 127 | 128 | This step comes from [fairseq](https://github.com/facebookresearch/fairseq/blob/main/examples/roberta/README.pretraining.md). 129 | 130 | 131 | 132 | #### 2) Train the autoregressive language model 133 | 134 | Use the following command to train the autoregressive language model: 135 | 136 | ``` 137 | bash script_alm.sh 138 | ``` 139 | 140 | You should change `data_dir` in the script to your preprocessed data directory. If you are using a slurm cluster, please add `--distributed-port $PORT` to the fairseq-train arguments. 141 | 142 | 143 | 144 | #### 3) Length extrapolation 145 | 146 | After training, you can run a length extrapolation test with the following command, where `length` is the evaluation length (e.g. 512, 1024, ...): 147 | 148 | ``` 149 | bash length_extrapolation.sh tnn_v2_decay_99_pre length 150 | ``` 
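For instance, a minimal sketch for sweeping several evaluation lengths with the trained `tnn_v2_decay_99_pre` checkpoint (adjust the length list, and the `data_dir`/`ckpt` paths inside `length_extrapolation.sh`, to your setup):

```
# evaluate the same checkpoint at increasing sequence lengths
for length in 512 1024 2048 4096 8192; do
    bash length_extrapolation.sh tnn_v2_decay_99_pre $length
done
```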
151 | 152 | 153 | 154 | 155 | 156 | ### Bidirectional language model 157 | 158 | #### 1) Preprocess the data 159 | 160 | The same as the autoregressive language model part. 161 | 162 | 163 | 164 | #### 2) Train the bidirectional language model 165 | 166 | Use the following command to train the bidirectional language model: 167 | 168 | ``` 169 | bash script_blm.sh 170 | ``` 171 | 172 | You should change `data_dir` in the script to your preprocessed data directory. If you are using a slurm cluster, please add `--distributed-port $PORT` to the fairseq-train arguments. 173 | 174 | 175 | 176 | #### 3) Finetuning 177 | 178 | Please refer to the [official Fairseq script](https://github.com/facebookresearch/fairseq/blob/main/examples/roberta/README.glue.md). 179 | 180 | 181 | 182 | ### Image modeling 183 | 184 | #### 1) Preparation 185 | 186 | Download the codebase: 187 | 188 | ``` 189 | git clone https://github.com/OpenNLPLab/im.git 190 | ``` 191 | 192 | 193 | 194 | #### 2) Training 195 | 196 | Use the following command for training: 197 | 198 | ``` 199 | bash script_im.sh 200 | ``` 201 | 202 | 203 | 204 | ### LRA 205 | 206 | #### 1) Preparation 207 | 208 | Download the codebase: 209 | 210 | ``` 211 | git clone https://github.com/OpenNLPLab/lra.git 212 | ``` 213 | 214 | Download the data: 215 | 216 | ``` 217 | wget https://storage.googleapis.com/long-range-arena/lra_release.gz 218 | mv lra_release.gz lra_release.tar.gz 219 | tar -xvf lra_release.tar.gz 220 | ``` 221 | 222 | 223 | 224 | #### 2) Training 225 | 226 | Use the following script to run the experiments. Before running it, change `PREFIX` to your lra path and set `tasks` to the task you want to run; for aan, imdb, and listops the `archs` should be `tno`, while for the other tasks the `archs` should be `tno2d`: 227 | 228 | ``` 229 | python script_lra.py 230 | ``` 231 | 232 | ### Speed test 233 | 234 | For Figure 1, we used the imdb task from the LRA benchmark for testing, with the config tno-lra-imdb.yaml; other models can adjust their model size according to this configuration. (We are cleaning up the code.) 235 | 236 | ## Standalone code 237 | 238 | If you want to use TNN in your own projects, you can install tnn-pytorch: 239 | 240 | ``` 241 | $ pip install tnn-pytorch 242 | ``` 243 | 244 | The code base is at the following address; you can adapt it as needed: 245 | 246 | - [https://github.com/Doraemonzzz/tnn-pytorch](https://github.com/Doraemonzzz/tnn-pytorch) 247 | 248 | 249 | 250 | ## Citation 251 | 252 | ``` 253 | @inproceedings{ 254 | qin2023toeplitz, 255 | title={Toeplitz Neural Network for Sequence Modeling}, 256 | author={Zhen Qin and Xiaodong Han and Weixuan Sun and Bowen He and Dong Li and Dongxu Li and Yuchao Dai and Lingpeng Kong and Yiran Zhong}, 257 | booktitle={The Eleventh International Conference on Learning Representations}, 258 | year={2023}, 259 | url={https://openreview.net/forum?id=IxmWsm4xrua} 260 | } 261 | ``` 262 | 263 | 264 | 265 | ## WIP 266 | 267 | - [ ] Check the training script. 268 | - [ ] Update tnn-pytorch. 
269 | 270 | 271 | 272 | -------------------------------------------------------------------------------- /env1.yaml: -------------------------------------------------------------------------------- 1 | name: tnn 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - _openmp_mutex=5.1=1_gnu 7 | - ca-certificates=2022.07.19=h06a4308_0 8 | - certifi=2021.5.30=py36h06a4308_0 9 | - libedit=3.1.20210910=h7f8727e_0 10 | - libffi=3.2.1=hf484d3e_1007 11 | - libgcc-ng=11.2.0=h1234567_1 12 | - libgomp=11.2.0=h1234567_1 13 | - libstdcxx-ng=11.2.0=h1234567_1 14 | - ncurses=6.3=h5eee18b_3 15 | - openssl=1.1.1q=h7f8727e_0 16 | - pip=21.2.2=py36h06a4308_0 17 | - python=3.6.9=h265db76_0 18 | - readline=7.0=h7b6447c_5 19 | - setuptools=58.0.4=py36h06a4308_0 20 | - sqlite=3.33.0=h62c20be_0 21 | - tk=8.6.12=h1ccaba5_0 22 | - wheel=0.37.1=pyhd3eb1b0_0 23 | - xz=5.2.6=h5eee18b_0 24 | - zlib=1.2.12=h5eee18b_3 25 | - pip: 26 | - antlr4-python3-runtime==4.8 27 | - cffi==1.15.1 28 | - charset-normalizer==2.0.12 29 | - colorama==0.4.5 30 | - cython==0.29.32 31 | - dataclasses==0.8 32 | - einops==0.4.1 33 | - filelock==3.4.1 34 | - huggingface-hub==0.4.0 35 | - hydra-core==1.0.7 36 | - idna==3.4 37 | - importlib-metadata==4.8.3 38 | - importlib-resources==5.4.0 39 | - jieba==0.42.1 40 | - lxml==4.9.1 41 | - numpy==1.19.5 42 | - omegaconf==2.0.6 43 | - opencv-python==4.6.0.66 44 | - opt-einsum==3.3.0 45 | - packaging==21.3 46 | - pillow==8.4.0 47 | - portalocker==2.5.1 48 | - pycparser==2.21 49 | - pyparsing==3.0.9 50 | - pyyaml==6.0 51 | - regex==2022.9.13 52 | - requests==2.27.1 53 | - sacrebleu==2.2.1 54 | - scipy==1.5.4 55 | - sentencepiece==0.1.97 56 | - tabulate==0.8.10 57 | - timm==0.4.12 58 | - torch==1.8.1+cu111 59 | - torchvision==0.9.1 60 | - tqdm==4.64.1 61 | - typing-extensions==4.1.1 62 | - urllib3==1.26.12 63 | - zipp==3.6.0 -------------------------------------------------------------------------------- /env2.yaml: -------------------------------------------------------------------------------- 1 | name: lra 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - _openmp_mutex=5.1=1_gnu 7 | - ca-certificates=2022.4.26=h06a4308_0 8 | - certifi=2022.5.18.1=py38h06a4308_0 9 | - ld_impl_linux-64=2.38=h1181459_1 10 | - libffi=3.3=he6710b0_2 11 | - libgcc-ng=11.2.0=h1234567_1 12 | - libgomp=11.2.0=h1234567_1 13 | - libstdcxx-ng=11.2.0=h1234567_1 14 | - ncurses=6.3=h7f8727e_2 15 | - openssl=1.1.1o=h7f8727e_0 16 | - pip=21.2.4=py38h06a4308_0 17 | - python=3.8.13=h12debd9_0 18 | - readline=8.1.2=h7f8727e_1 19 | - sqlite=3.38.3=hc218d9a_0 20 | - tk=8.6.12=h1ccaba5_0 21 | - wheel=0.37.1=pyhd3eb1b0_0 22 | - xz=5.2.5=h7f8727e_1 23 | - zlib=1.2.12=h7f8727e_2 24 | - pip: 25 | - absl-py==1.2.0 26 | - aiohttp==3.8.1 27 | - aiosignal==1.2.0 28 | - antlr4-python3-runtime==4.9.3 29 | - async-timeout==4.0.2 30 | - attrs==22.1.0 31 | - cachetools==5.2.0 32 | - charset-normalizer==2.1.0 33 | - click==8.1.3 34 | - cmake==3.24.0 35 | - cycler==0.11.0 36 | - datasets==2.4.0 37 | - deprecated==1.2.13 38 | - dill==0.3.5.1 39 | - docker-pycreds==0.4.0 40 | - einops==0.4.1 41 | - filelock==3.8.0 42 | - fonttools==4.34.4 43 | - frozenlist==1.3.1 44 | - fsspec==2022.7.1 45 | - future==0.18.2 46 | - gitdb==4.0.9 47 | - gitpython==3.1.27 48 | - google-auth==2.10.0 49 | - google-auth-oauthlib==0.4.6 50 | - grpcio==1.48.0 51 | - huggingface-hub==0.8.1 52 | - hydra-core==1.2.0 53 | - idna==3.3 54 | - importlib-metadata==4.12.0 55 | - importlib-resources==5.9.0 56 | - joblib==1.1.0 57 | 
- kiwisolver==1.4.4 58 | - llvmlite==0.39.0 59 | - markdown==3.4.1 60 | - matplotlib==3.5.3 61 | - multidict==6.0.2 62 | - multiprocess==0.70.13 63 | - munch==2.5.0 64 | - numba==0.56.0 65 | - numpy==1.22.4 66 | - oauthlib==3.2.0 67 | - omegaconf==2.2.2 68 | - opt-einsum==3.3.0 69 | - packaging==21.3 70 | - pandas==1.4.3 71 | - pathtools==0.1.2 72 | - patsy==0.5.2 73 | - pillow==9.2.0 74 | - promise==2.3 75 | - protobuf==3.19.4 76 | - psutil==5.9.1 77 | - pyarrow==9.0.0 78 | - pyasn1==0.4.8 79 | - pyasn1-modules==0.2.8 80 | - pydeprecate==0.3.1 81 | - pygments==2.12.0 82 | - pynvml==11.4.1 83 | - pyparsing==3.0.9 84 | - python-dateutil==2.8.2 85 | - pytorch-lightning==1.5.10 86 | - pytz==2022.1 87 | - pyyaml==6.0 88 | - regex==2022.7.25 89 | - requests==2.28.1 90 | - requests-oauthlib==1.3.1 91 | - responses==0.18.0 92 | - rich==12.6.0 93 | - rsa==4.9 94 | - scikit-learn==1.1.2 95 | - scipy==1.8.1 96 | - sentry-sdk==1.9.3 97 | - setproctitle==1.3.2 98 | - setuptools==59.5.0 99 | - shortuuid==1.0.9 100 | - six==1.16.0 101 | - sklearn==0.0 102 | - sktime==0.13.1 103 | - smmap==5.0.0 104 | - statsmodels==0.13.2 105 | - tensorboard==2.10.0 106 | - tensorboard-data-server==0.6.1 107 | - tensorboard-plugin-wit==1.8.1 108 | - threadpoolctl==3.1.0 109 | - timm==0.6.7 110 | - tokenizers==0.12.1 111 | - torch==1.10.0+cu111 112 | - torch-tb-profiler==0.4.0 113 | - torchaudio==0.10.0 114 | - torchmetrics==0.9.3 115 | - torchtext==0.11.0 116 | - torchvision==0.11.1 117 | - tqdm==4.64.0 118 | - transformers==4.21.1 119 | - typing-extensions==4.3.0 120 | - urllib3==1.26.11 121 | - wandb==0.13.1 122 | - werkzeug==2.2.2 123 | - wrapt==1.14.1 124 | - xxhash==3.0.0 125 | - yarl==1.8.1 126 | - zipp==3.8.1 -------------------------------------------------------------------------------- /length_extrapolation.sh: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/bash 2 | 3 | batch_size=2 4 | data_dir=path_to_bin_data 5 | ARCH=$1 6 | ckpt=your_ckpt_dir/$ARCH/checkpoint_best.pt 7 | l=$2 8 | 9 | fairseq-eval-lm \ 10 | $data_dir \ 11 | --sample-break-mode none \ 12 | --path $ckpt \ 13 | --max-sentences 1 \ 14 | --model-overrides "{'max_tokens':$l,'tokens_per_sample':$l,'max_target_positions':$l}" \ 15 | --max-tokens $l \ 16 | --tokens-per-sample $l \ 17 | --max-target-positions $l \ 18 | --context-window 0 -------------------------------------------------------------------------------- /network.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenNLPLab/Tnn/0b66c46dae5ad02b679dae5dca0c17853c7e7109/network.png -------------------------------------------------------------------------------- /requirements_lra.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.2.0 2 | aiohttp==3.8.1 3 | aiosignal==1.2.0 4 | antlr4-python3-runtime==4.9.3 5 | async-timeout==4.0.2 6 | attrs==22.1.0 7 | cachetools==5.2.0 8 | charset-normalizer==2.1.0 9 | click==8.1.3 10 | cmake==3.24.0 11 | cycler==0.11.0 12 | datasets==2.4.0 13 | deprecated==1.2.13 14 | dill==0.3.5.1 15 | docker-pycreds==0.4.0 16 | einops==0.4.1 17 | filelock==3.8.0 18 | fonttools==4.34.4 19 | frozenlist==1.3.1 20 | fsspec==2022.7.1 21 | future==0.18.2 22 | gitdb==4.0.9 23 | gitpython==3.1.27 24 | google-auth==2.10.0 25 | google-auth-oauthlib==0.4.6 26 | grpcio==1.48.0 27 | huggingface-hub==0.8.1 28 | hydra-core==1.2.0 29 | idna==3.3 30 | importlib-metadata==4.12.0 31 | importlib-resources==5.9.0 32 | joblib==1.1.0 33 | kiwisolver==1.4.4 34 | llvmlite==0.39.0 35 | markdown==3.4.1 36 | matplotlib==3.5.3 37 | multidict==6.0.2 38 | multiprocess==0.70.13 39 | munch==2.5.0 40 | numba==0.56.0 41 | numpy==1.22.4 42 | oauthlib==3.2.0 43 | omegaconf==2.2.2 44 | opt-einsum==3.3.0 45 | packaging==21.3 46 | pandas==1.4.3 47 | pathtools==0.1.2 48 | patsy==0.5.2 49 | pillow==9.2.0 50 | promise==2.3 51 | protobuf==3.19.4 52 | psutil==5.9.1 53 | pyarrow==9.0.0 54 | pyasn1==0.4.8 55 | pyasn1-modules==0.2.8 56 | pydeprecate==0.3.1 57 | pygments==2.12.0 58 | pynvml==11.4.1 59 | pyparsing==3.0.9 60 | python-dateutil==2.8.2 61 | pytorch-lightning==1.5.10 62 | pytz==2022.1 63 | pyyaml==6.0 64 | regex==2022.7.25 65 | requests==2.28.1 66 | requests-oauthlib==1.3.1 67 | responses==0.18.0 68 | rich==12.6.0 69 | rsa==4.9 70 | scikit-learn==1.1.2 71 | scipy==1.8.1 72 | sentry-sdk==1.9.3 73 | setproctitle==1.3.2 74 | setuptools==59.5.0 75 | shortuuid==1.0.9 76 | six==1.16.0 77 | sklearn==0.0 78 | sktime==0.13.1 79 | smmap==5.0.0 80 | statsmodels==0.13.2 81 | tensorboard==2.10.0 82 | tensorboard-data-server==0.6.1 83 | tensorboard-plugin-wit==1.8.1 84 | threadpoolctl==3.1.0 85 | timm==0.6.7 86 | tokenizers==0.12.1 87 | torch==1.10.0+cu111 88 | torch-tb-profiler==0.4.0 89 | torchaudio==0.10.0 90 | torchmetrics==0.9.3 91 | torchtext==0.11.0 92 | torchvision==0.11.1 93 | tqdm==4.64.0 94 | transformers==4.21.1 95 | typing-extensions==4.3.0 96 | urllib3==1.26.11 97 | wandb==0.13.1 98 | werkzeug==2.2.2 99 | wrapt==1.14.1 100 | xxhash==3.0.0 101 | yarl==1.8.1 102 | zipp==3.8.1 -------------------------------------------------------------------------------- /requirements_tnn.txt: -------------------------------------------------------------------------------- 1 | antlr4-python3-runtime==4.8 2 | cffi==1.15.1 3 | charset-normalizer==2.0.12 4 | colorama==0.4.5 5 | cython==0.29.32 6 | dataclasses==0.8 7 | einops==0.4.1 
8 | filelock==3.4.1 9 | huggingface-hub==0.4.0 10 | hydra-core==1.0.7 11 | idna==3.4 12 | importlib-metadata==4.8.3 13 | importlib-resources==5.4.0 14 | jieba==0.42.1 15 | lxml==4.9.1 16 | numpy==1.19.5 17 | omegaconf==2.0.6 18 | opencv-python==4.6.0.66 19 | opt-einsum==3.3.0 20 | packaging==21.3 21 | pillow==8.4.0 22 | portalocker==2.5.1 23 | pycparser==2.21 24 | pyparsing==3.0.9 25 | pyyaml==6.0 26 | regex==2022.9.13 27 | requests==2.27.1 28 | sacrebleu==2.2.1 29 | scipy==1.5.4 30 | sentencepiece==0.1.97 31 | tabulate==0.8.10 32 | timm==0.4.12 33 | torch==1.8.1+cu111 34 | torchvision==0.9.1 35 | tqdm==4.64.1 36 | typing-extensions==4.1.1 37 | urllib3==1.26.12 38 | zipp==3.6.0 -------------------------------------------------------------------------------- /run_task.sh: -------------------------------------------------------------------------------- 1 | export HYDRA_FULL_ERROR=1 2 | export DATA_PATH=path_to_your_lra_dir 3 | 4 | program_path=path_to_your_lra_dir 5 | 6 | TASK=$1 7 | ARCH=$2 8 | BS=$3 9 | N_LAYERS=$4 10 | D_MODEL=$5 11 | GTU_DPB_DIM=$6 12 | NORM=$7 13 | EXPAND_RATIO_GTU=$8 14 | EXPAND_RATIO_GLU=$9 15 | GTU_USE_DECAY=${10} 16 | GTU_GAMMA=${11} 17 | lr=${12} 18 | wd=${13} 19 | cards=${14} 20 | n_works=${15} 21 | dropout=${16} 22 | n_works=4 23 | dpb_type=${17} 24 | dpb_layers=${18} 25 | PRENORM=${19} 26 | warmup_steps=${20} 27 | 28 | python ${program_path}/train.py wandb=null experiment=${ARCH}-lra-${TASK} \ 29 | trainer.gpus=$cards \ 30 | loader.batch_size=${BS} \ 31 | loader.num_workers=${n_works} \ 32 | scheduler.num_warmup_steps=${warmup_steps} \ 33 | optimizer.lr=${lr} optimizer.weight_decay=${wd} \ 34 | model.n_layers=${N_LAYERS} model.d_model=${D_MODEL} \ 35 | model.norm=${NORM} model.prenorm=${PRENORM} train.seed=2222 \ 36 | model.gtu_rpe_dim=${GTU_DPB_DIM} \ 37 | model.expand_ratio_gtu=${EXPAND_RATIO_GTU} model.expand_ratio_glu=${EXPAND_RATIO_GLU} \ 38 | model.gtu_use_decay=True model.gtu_gamma=${GTU_GAMMA} \ 39 | model.dropout=${dropout} 40 | -------------------------------------------------------------------------------- /script_alm.sh: -------------------------------------------------------------------------------- 1 | #! /usr/bin/bash 2 | 3 | arch=tnn_v2_decay_99_pre 4 | # change to your data dir 5 | data_dir=path_to_bin_data 6 | 7 | bash train_alm.sh 8 $arch $data_dir -------------------------------------------------------------------------------- /script_blm.sh: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/bash 2 | 3 | arch=roberta_tnn_v2_decay_99 4 | # change to your data dir 5 | data_dir=path_to_bin_data 6 | 7 | bash train_blm.sh 8 $arch $data_dir -------------------------------------------------------------------------------- /script_im.sh: -------------------------------------------------------------------------------- 1 | GPUS=16 2 | BATCH=128 3 | # you can also choose tno_vit_e3g1_small_rpe_l1_90_prenorm 4 | arch=tnn_vit_e3g1_tiny_rpe_l1_90_prenorm 5 | LR=0.0005 6 | DATASET=IMNET 7 | decay=0.05 8 | warmup=10 9 | clip_grad=5 10 | 11 | bash train_im.sh $GPUS $BATCH $arch \ 12 | $arch $LR $DATASET \ 13 | $decay $warmup $clip_grad -------------------------------------------------------------------------------- /script_lra.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | 5 | # change 6 | gtu_use_decay = True 7 | PREFIX = "your path to lra" 8 | 9 | batches = { 10 | "cifar": 200, 11 | "imdb": 32, 12 | "listops": 128, 13 | "pathfinder": 128, 14 | "pathfinderx": 64, 15 | "aan": 64, 16 | } 17 | 18 | gpus = { 19 | "cifar": 4, 20 | "imdb": 2, 21 | "listops": 2, 22 | "pathfinder": 4, 23 | "pathfinderx": 8, 24 | "aan": 2, 25 | } 26 | 27 | d_model_dict = { 28 | "cifar": 128, 29 | "imdb": 128, 30 | "listops": 128, 31 | "pathfinder": 64, 32 | "pathfinderx": 64, 33 | "aan": 64, 34 | } 35 | 36 | n_layers_dict = { 37 | "cifar": 12, 38 | "imdb": 4, 39 | "listops": 4, 40 | "pathfinder": 6, 41 | "pathfinderx": 6, 42 | "aan": 2, 43 | } 44 | 45 | expand_ratio_gtu_dict = { 46 | "cifar": 2, 47 | "imdb": 1, 48 | "listops": 1, 49 | "pathfinder": 3, 50 | "pathfinderx": 3, 51 | "aan": 1, 52 | } 53 | 54 | expand_ratio_glu_dict = { 55 | "cifar": 3, 56 | "imdb": 2, 57 | "listops": 1, 58 | "pathfinder": 2.5, 59 | "pathfinderx": 2.5, 60 | "aan": 1.5, 61 | } 62 | 63 | gtu_gamma = { 64 | "cifar": 0.7, 65 | "imdb": 0.9, 66 | "listops": 1, 67 | "pathfinder": [1], 68 | "pathfinderx": [0.999], 69 | "aan": [0.6], 70 | } 71 | 72 | norm_dict = { 73 | "cifar": "synbatch", 74 | "imdb": "synbatch", 75 | "listops": "synbatch", 76 | "pathfinder": "synbatch", 77 | "pathfinderx": "synbatch", 78 | "aan": "synbatch", 79 | } 80 | 81 | lr_dict = { 82 | "cifar": 0.007, 83 | "imdb": [0.00105], 84 | "listops": 0.0005, 85 | "pathfinder": [0.001], 86 | "pathfinderx": [0.0001], 87 | "aan": [0.01], 88 | } 89 | 90 | wd_dict = { 91 | "cifar": [0.001], 92 | "imdb": 0.001, 93 | "listops": 0.1, 94 | "pathfinder": 0, 95 | "pathfinderx": 0, 96 | "aan": [0.01], 97 | } 98 | 99 | dropout_dict = { 100 | "cifar": [0.1], 101 | "imdb": 0.1, 102 | "listops": 0, 103 | "pathfinder": 0, 104 | "pathfinderx": 0, 105 | "aan": 0, 106 | } 107 | 108 | dpb_type_dict = { 109 | "cifar": 1, 110 | "imdb": 1, 111 | "listops": 1, 112 | "pathfinder": 1, 113 | "pathfinderx": 1, 114 | "aan": 1, 115 | } 116 | 117 | dpb_layers_dict = { 118 | "cifar": 0, 119 | "imdb": 3, 120 | "listops": 3, 121 | "pathfinder": 0, 122 | "pathfinderx": 0, 123 | "aan": 3, 124 | } 125 | 126 | gtu_dpb_dim_dict = { 127 | "cifar": 16, 128 | "imdb": 32, 129 | "listops": 32, 130 | "pathfinder": 16, 131 | "pathfinderx": 16, 132 | "aan": 16, 133 | } 134 | 135 | prenorm_dict = { 136 | "cifar": True, 137 | "imdb": True, 138 | "listops": True, 139 | "pathfinder": True, 140 | "pathfinderx": True, 141 | "aan": True, 142 | } 143 | 144 | warmup_steps_dict = { 145 | "cifar": [30000], 146 | "imdb": 3000, 147 | "listops": [1000], 148 | "pathfinder": [5000], 149 | "pathfinderx": 312, 150 | "aan": [25000], 151 | } 152 | 153 | 154 
| # for these tasks, you should use tno 155 | tasks = ["aan"] 156 | tasks = ["imdb"] 157 | tasks = ["listops"] 158 | archs = ["tno"] 159 | 160 | # for these tasks, you should use tno2d 161 | tasks = ["cifar"] 162 | tasks = ["pathfinder"] 163 | tasks = ["pathfinderx"] 164 | archs = ["tno2d"] 165 | 166 | 167 | def to_iter(*args): 168 | n = len(args) 169 | new_args = [] 170 | for i in range(n): 171 | if not isinstance(args[i], list): 172 | arg = [args[i]] 173 | else: 174 | arg = args[i] 175 | new_args.append(arg) 176 | 177 | return helper(*new_args) 178 | 179 | 180 | def helper(*args): 181 | n = len(args) 182 | if n == 1: 183 | res = [[arg] for arg in args[0]] 184 | return res 185 | else: 186 | arr = helper(*args[1:]) 187 | res = [] 188 | for par in args[0]: 189 | for data in arr: 190 | res.append([par] + list(data)) 191 | return res 192 | 193 | 194 | for i, task in enumerate(tasks): 195 | pars = to_iter( 196 | archs, 197 | n_layers_dict[task], 198 | expand_ratio_gtu_dict[task], 199 | expand_ratio_glu_dict[task], 200 | d_model_dict[task], 201 | gtu_gamma[task], 202 | batches[task], 203 | norm_dict[task], 204 | lr_dict[task], 205 | wd_dict[task], 206 | dropout_dict[task], 207 | dpb_type_dict[task], 208 | dpb_layers_dict[task], 209 | gtu_dpb_dim_dict[task], 210 | prenorm_dict[task], 211 | warmup_steps_dict[task], 212 | ) 213 | print(pars) 214 | print(task) 215 | print(len(pars)) 216 | time.sleep(10) 217 | for ( 218 | arch, 219 | n_layers, 220 | expand_ratio_gtu, 221 | expand_ratio_glu, 222 | d_model, 223 | gamma, 224 | total_batch, 225 | norm, 226 | lr, 227 | wd, 228 | dropout, 229 | dpb_type, 230 | dpb_layers, 231 | gtu_dpb_dim, 232 | prenorm, 233 | warmup_steps, 234 | ) in pars: 235 | if task == "imdb": 236 | seq_len = 4096 237 | if not gtu_use_decay: 238 | os.system( 239 | f"sh {PREFIX}/run_task.sh {task} {arch} {batch} {n_layers} {d_model} {gtu_dpb_dim} {norm} {expand_ratio_gtu} {expand_ratio_glu} {gtu_use_decay} 0 {lr} {wd} {gpu} {workers} {dropout} {dpb_type} {dpb_layers} {prenorm} {warmup_steps}" 240 | ) 241 | sys.exit(0) 242 | else: 243 | gpu = gpus[task] 244 | batch = total_batch // gpu 245 | workers = gpu * 20 246 | for i in range(1): 247 | print("imdb lr: ", lr) 248 | time.sleep(10) 249 | pid = os.fork() 250 | if pid == 0: 251 | os.system( 252 | f"sh {PREFIX}/run_task.sh {task} {arch} {batch} {n_layers} {d_model} {gtu_dpb_dim} {norm} {expand_ratio_gtu} {expand_ratio_glu} {gtu_use_decay} {gamma} {lr} {wd} {gpu} {workers} {dropout} {dpb_type} {dpb_layers} {prenorm} {warmup_steps}" 253 | ) 254 | sys.exit(0) 255 | elif task == "cifar": 256 | seq_len = 1024 257 | if not gtu_use_decay: 258 | os.system( 259 | f"sh {PREFIX}/run_task.sh {task} {arch} {batch} {n_layers} {d_model} {gtu_dpb_dim} {norm} {expand_ratio_gtu} {expand_ratio_glu} {gtu_use_decay} 0 {lr} {wd} {gpu}" 260 | ) 261 | sys.exit(0) 262 | else: 263 | gpu = gpus[task] 264 | batch = total_batch // gpu 265 | workers = gpu * 20 266 | print("cifar lr: ", lr) 267 | time.sleep(10) 268 | pid = os.fork() 269 | if pid == 0: 270 | os.system( 271 | f"sh {PREFIX}/run_task.sh {task} {arch} {batch} {n_layers} {d_model} {gtu_dpb_dim} {norm} {expand_ratio_gtu} {expand_ratio_glu} {gtu_use_decay} {gamma} {lr} {wd} {gpu} {workers} {dropout} {dpb_type} {dpb_layers} {prenorm} {warmup_steps}" 272 | ) 273 | sys.exit(0) 274 | elif task == "listops": 275 | seq_len = 2048 276 | if not gtu_use_decay: 277 | os.system( 278 | f"sh {PREFIX}/run_task.sh {task} {arch} {batch} {n_layers} {d_model} {gtu_dpb_dim} {norm} {expand_ratio_gtu} {expand_ratio_glu} 
{gtu_use_decay} 0 {lr} {wd} {gpu}" 279 | ) 280 | sys.exit(0) 281 | else: 282 | gpu = gpus[task] 283 | batch = total_batch // gpu 284 | workers = gpu * 20 285 | print("listops lr: ", lr) 286 | time.sleep(10) 287 | pid = os.fork() 288 | if pid == 0: 289 | os.system( 290 | f"sh {PREFIX}/run_task.sh {task} {arch} {batch} {n_layers} {d_model} {gtu_dpb_dim} {norm} {expand_ratio_gtu} {expand_ratio_glu} {gtu_use_decay} {gamma} {lr} {wd} {gpu} {workers} {dropout} {dpb_type} {dpb_layers} {prenorm} {warmup_steps}" 291 | ) 292 | sys.exit(0) 293 | elif task == "pathfinder": 294 | seq_len = 1024 295 | if not gtu_use_decay: 296 | print("pathfinder lr: ", lr) 297 | time.sleep(10) 298 | pid = os.fork() 299 | if pid == 0: 300 | os.system( 301 | f"sh {PREFIX}/run_task.sh {task} {arch} {batch} {n_layers} {d_model} {gtu_dpb_dim} {norm} {expand_ratio_gtu} {expand_ratio_glu} {gtu_use_decay} 0 {lr} {wd} {gpu} {workers} {dropout} {dpb_type} {dpb_layers} {prenorm} {warmup_steps}" 302 | ) 303 | sys.exit(0) 304 | else: 305 | gpu = gpus[task] 306 | batch = total_batch // gpu 307 | workers = gpu * 20 308 | print("pathfinder lr: ", lr) 309 | time.sleep(10) 310 | pid = os.fork() 311 | if pid == 0: 312 | os.system( 313 | f"sh {PREFIX}/run_task.sh {task} {arch} {batch} {n_layers} {d_model} {gtu_dpb_dim} {norm} {expand_ratio_gtu} {expand_ratio_glu} {gtu_use_decay} {gamma} {lr} {wd} {gpu} {workers} {dropout} {dpb_type} {dpb_layers} {prenorm} {warmup_steps}" 314 | ) 315 | sys.exit(0) 316 | elif task == "pathfinderx": 317 | seq_len = 128 * 128 318 | if not gtu_use_decay: 319 | print("pathfinderx lr: ", lr) 320 | time.sleep(10) 321 | pid = os.fork() 322 | if pid == 0: 323 | os.system( 324 | f"sh {PREFIX}/run_task.sh {task} {arch} {batch} {n_layers} {d_model} {gtu_dpb_dim} {norm} {expand_ratio_gtu} {expand_ratio_glu} {gtu_use_decay} 0 {lr} {wd} {gpu}" 325 | ) 326 | sys.exit(0) 327 | else: 328 | gpu = gpus[task] 329 | batch = total_batch // gpu 330 | workers = gpu * 20 331 | print("pathfinderx lr: ", lr) 332 | time.sleep(10) 333 | pid = os.fork() 334 | if pid == 0: 335 | os.system( 336 | f"sh {PREFIX}/run_task.sh {task} {arch} {batch} {n_layers} {d_model} {gtu_dpb_dim} {norm} {expand_ratio_gtu} {expand_ratio_glu} {gtu_use_decay} {gamma} {lr} {wd} {gpu} {workers} {dropout} {dpb_type} {dpb_layers} {prenorm} {warmup_steps}" 337 | ) 338 | sys.exit(0) 339 | elif task == "aan": 340 | seq_len = 4000 341 | if not gtu_use_decay: 342 | os.system( 343 | f"sh {PREFIX}/run_task.sh {task} {arch} {batch} {n_layers} {d_model} {gtu_dpb_dim} {norm} {expand_ratio_gtu} {expand_ratio_glu} {gtu_use_decay} 0 {lr} {wd} {gpu} {workers} {dropout} {dpb_type} {dpb_layers} {prenorm} {warmup_steps}" 344 | ) 345 | sys.exit(0) 346 | else: 347 | gpu = gpus[task] 348 | batch = total_batch // gpu 349 | workers = gpu * 20 350 | print("aan lr: ", lr) 351 | time.sleep(10) 352 | pid = os.fork() 353 | if pid == 0: 354 | os.system( 355 | f"sh {PREFIX}/run_task.sh {task} {arch} {batch} {n_layers} {d_model} {gtu_dpb_dim} {norm} {expand_ratio_gtu} {expand_ratio_glu} {gtu_use_decay} {gamma} {lr} {wd} {gpu} {workers} {dropout} {dpb_type} {dpb_layers} {prenorm} {warmup_steps}" 356 | ) 357 | sys.exit(0) 358 | -------------------------------------------------------------------------------- /train_alm.sh: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/bash 2 | 3 | BATCH_SIZE=8 4 | TOKENS_PER_SAMPLE=512 5 | MAX_TOKEN=$((TOKENS_PER_SAMPLE*BATCH_SIZE)) 6 | DATA_DIR=$3 7 | ARCH=$2 8 | prefix=lm 9 | MAX_UPDATE=50000 10 | WARM_UP=4000 11 | UPDATE_FREQ=$(( 128 / $BATCH_SIZE / $1 )) 12 | PORT=$(( $RANDOM + 2000 )) 13 | echo $PORT 14 | LR=0.0005 15 | CLIP_NORM=1.0 16 | decay=0.2 17 | 18 | fairseq-train --task language_modeling \ 19 | $DATA_DIR \ 20 | --save-dir checkpoints/$prefix/${ARCH} \ 21 | --distributed-world-size $1 \ 22 | --arch $ARCH --share-decoder-input-output-embed \ 23 | --dropout 0.1 \ 24 | --optimizer adam --adam-betas '(0.9, 0.98)' --weight-decay $decay --clip-norm $CLIP_NORM \ 25 | --lr $LR --lr-scheduler inverse_sqrt --warmup-updates $WARM_UP --warmup-init-lr 1e-07 \ 26 | --tokens-per-sample $TOKENS_PER_SAMPLE --sample-break-mode none \ 27 | --max-tokens $MAX_TOKEN --update-freq $UPDATE_FREQ \ 28 | --ddp-backend=legacy_ddp \ 29 | --batch-size $BATCH_SIZE \ 30 | --max-update $MAX_UPDATE --log-interval 10 2>&1 | tee $ARCH.log -------------------------------------------------------------------------------- /train_blm.sh: -------------------------------------------------------------------------------- 1 | #! /usr/bin/bash 2 | 3 | TOTAL_UPDATES=50000 # Total number of training steps 4 | WARMUP_UPDATES=3000 # Warmup the learning rate over this many updates 5 | TOKENS_PER_SAMPLE=512 # Max sequence length 6 | MAX_POSITIONS=512 # Num. positional embeddings (usually same as above) 7 | MAX_SENTENCES=4 8 | PEAK_LR=0.0005 # Peak learning rate, adjust as needed 9 | CLIP_NORM=1.0 10 | PORT=$(( $RANDOM + 2000 )) 11 | prefix=roberta 12 | GPUS=$1 13 | ARCH=$2 14 | DATA_DIR=$3 15 | UPDATE_FREQ=$(( 512 / $MAX_SENTENCES / $GPUS )) 16 | 17 | fairseq-train $DATA_DIR \ 18 | --task masked_lm --criterion masked_lm \ 19 | --distributed-world-size $1 \ 20 | --save-dir checkpoints/$prefix/$ARCH \ 21 | --arch $ARCH --sample-break-mode complete --tokens-per-sample $TOKENS_PER_SAMPLE \ 22 | --optimizer adam --adam-betas '(0.9,0.98)' --adam-eps 1e-6 --clip-norm $CLIP_NORM \ 23 | --lr-scheduler polynomial_decay --lr $PEAK_LR --warmup-updates $WARMUP_UPDATES --total-num-update $TOTAL_UPDATES \ 24 | --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.2 \ 25 | --batch-size $MAX_SENTENCES --update-freq $UPDATE_FREQ \ 26 | --ddp-backend=legacy_ddp \ 27 | --find-unused-parameters \ 28 | --max-update $TOTAL_UPDATES --log-format simple --log-interval 1 2>&1 | tee $ARCH.log -------------------------------------------------------------------------------- /train_im.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | export NCCL_LL_THRESHOLD=0 3 | 4 | ARCH=$3 5 | GPUS=$1 6 | batch_size=$2 7 | PORT=$(( $RANDOM + 2000 )) 8 | export MASTER_PORT=${MASTER_PORT:-$PORT} 9 | LR=$5 10 | DATASET=$6 11 | 12 | OUTPUT_DIR=./checkpoints/$4 13 | RESUME=./checkpoints/$4/checkpoint_best.pth 14 | 15 | PROG=path_to_im/classification/main.py 16 | DATA=path_to_data 17 | 18 | python $PROG \ 19 | --data-set $DATASET --data-path $DATA \ 20 | --batch-size $batch_size --dist-eval --output_dir $OUTPUT_DIR \ 21 | --resume $RESUME --model $ARCH --epochs 300 --fp32-resume --lr $LR \ 22 | --weight-decay $7 \ 23 | --warmup-epochs $8 \ 24 | --clip-grad ${9} \ 25 | --broadcast_buffers 26 | -------------------------------------------------------------------------------- /train_lra.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | export HYDRA_FULL_ERROR=1 4 | export 
DATA_PATH=PATH_TO_LRA_DATA 5 | program_path=PATH_TO_LRA_DIR 6 | 7 | TASK=$1 8 | ARCH=$2 9 | BS=$3 10 | N_LAYERS=$4 11 | D_MODEL=$5 12 | GTU_DPB_DIM=$6 13 | NORM=$7 14 | EXPAND_RATIO_GTU=$8 15 | EXPAND_RATIO_GLU=$9 16 | GTU_USE_DECAY=${10} 17 | GTU_GAMMA=${11} 18 | lr=${12} 19 | wd=${13} 20 | cards=${14} 21 | n_works=${15} 22 | dropout=${16} 23 | n_works=4 24 | dpb_type=${17} 25 | dpb_layers=${18} 26 | PRENORM=${19} 27 | warmup_steps=${20} 28 | 29 | python ${program_path}/train.py wandb=null experiment=${ARCH}-lra-${TASK} \ 30 | trainer.gpus=$cards \ 31 | loader.batch_size=${BS} \ 32 | loader.num_workers=${n_works} \ 33 | scheduler.num_warmup_steps=${warmup_steps} \ 34 | optimizer.lr=${lr} optimizer.weight_decay=${wd} \ 35 | model.n_layers=${N_LAYERS} model.d_model=${D_MODEL} model.tno_dpb_dim=${GTU_DPB_DIM} \ 36 | model.norm=${NORM} model.prenorm=${PRENORM} train.seed=2222 \ 37 | model.expand_ratio_tno=${EXPAND_RATIO_GTU} model.expand_ratio_glu=${EXPAND_RATIO_GLU} \ 38 | model.tno_use_decay=True model.tno_gamma=${GTU_GAMMA} \ 39 | model.dropout=${dropout} model.dpb_type=${dpb_type} model.dpb_layers=${dpb_layers} 40 | --------------------------------------------------------------------------------