├── experiments ├── groove │ ├── src │ │ ├── spe_music │ │ │ ├── __init__.py │ │ │ ├── model │ │ │ │ ├── __init__.py │ │ │ │ └── music_performer.py │ │ │ └── style_eval │ │ │ │ ├── __init__.py │ │ │ │ └── note_features.py │ │ └── setup.py │ ├── exp │ │ ├── clone.sh │ │ ├── style_eval_drums_config.yaml │ │ ├── style_eval_config.yaml │ │ ├── trio_performer_softmax_l512_v01 │ │ │ └── config.yaml │ │ ├── trio_performer_softmax_sinespe_l512_v01 │ │ │ └── config.yaml │ │ └── trio_performer_softmax_convspe_l512_v01 │ │ │ └── config.yaml │ ├── requirements.txt │ └── README.md ├── lra │ ├── fast_attention │ │ ├── fast_self_attention │ │ │ └── __init__.py │ │ ├── setup.py │ │ └── README.md │ ├── models │ │ └── gpu_16g │ │ │ ├── performer_softmax │ │ │ ├── cifar10 │ │ │ │ ├── r1 │ │ │ │ │ ├── results.json │ │ │ │ │ └── config.py │ │ │ │ ├── r2 │ │ │ │ │ ├── results.json │ │ │ │ │ └── config.py │ │ │ │ ├── tune01 │ │ │ │ │ ├── r1 │ │ │ │ │ │ ├── results.json │ │ │ │ │ │ └── config.py │ │ │ │ │ ├── r2 │ │ │ │ │ │ ├── results.json │ │ │ │ │ │ └── config.py │ │ │ │ │ └── r3 │ │ │ │ │ │ ├── results.json │ │ │ │ │ │ └── config.py │ │ │ │ └── r3 │ │ │ │ │ └── config.py │ │ │ ├── aan │ │ │ │ ├── r1 │ │ │ │ │ ├── results.json │ │ │ │ │ └── config.py │ │ │ │ ├── r2 │ │ │ │ │ ├── results.json │ │ │ │ │ └── config.py │ │ │ │ └── r3 │ │ │ │ │ ├── results.json │ │ │ │ │ └── config.py │ │ │ ├── tc │ │ │ │ ├── r1 │ │ │ │ │ ├── results.json │ │ │ │ │ └── config.py │ │ │ │ ├── r2 │ │ │ │ │ ├── results.json │ │ │ │ │ └── config.py │ │ │ │ └── r3 │ │ │ │ │ ├── results.json │ │ │ │ │ └── config.py │ │ │ └── listops │ │ │ │ ├── r1 │ │ │ │ ├── results.json │ │ │ │ └── config.py │ │ │ │ ├── r2 │ │ │ │ ├── results.json │ │ │ │ └── config.py │ │ │ │ └── r3 │ │ │ │ ├── results.json │ │ │ │ └── config.py │ │ │ ├── performer_softmax_sinespe │ │ │ ├── cifar10 │ │ │ │ ├── r2 │ │ │ │ │ ├── results.json │ │ │ │ │ └── config.py │ │ │ │ ├── r1 │ │ │ │ │ ├── results.json │ │ │ │ │ └── config.py │ │ │ │ └── r3 │ │ │ │ │ ├── results.json │ │ │ │ │ └── config.py │ │ │ ├── aan │ │ │ │ ├── r1 │ │ │ │ │ ├── results.json │ │ │ │ │ └── config.py │ │ │ │ ├── r2 │ │ │ │ │ ├── results.json │ │ │ │ │ └── config.py │ │ │ │ └── r3 │ │ │ │ │ ├── results.json │ │ │ │ │ └── config.py │ │ │ ├── listops │ │ │ │ ├── r2 │ │ │ │ │ ├── results.json │ │ │ │ │ └── config.py │ │ │ │ ├── r3 │ │ │ │ │ ├── results.json │ │ │ │ │ └── config.py │ │ │ │ └── r1 │ │ │ │ │ ├── results.json │ │ │ │ │ └── config.py │ │ │ └── tc │ │ │ │ ├── r1 │ │ │ │ ├── results.json │ │ │ │ └── config.py │ │ │ │ ├── r2 │ │ │ │ ├── results.json │ │ │ │ └── config.py │ │ │ │ └── r3 │ │ │ │ ├── results.json │ │ │ │ └── config.py │ │ │ ├── linear_transformer_relu │ │ │ ├── cifar10 │ │ │ │ ├── r1 │ │ │ │ │ ├── results.json │ │ │ │ │ └── config.py │ │ │ │ ├── r2 │ │ │ │ │ ├── results.json │ │ │ │ │ └── config.py │ │ │ │ └── r3 │ │ │ │ │ ├── results.json │ │ │ │ │ └── config.py │ │ │ ├── aan │ │ │ │ ├── r1 │ │ │ │ │ ├── results.json │ │ │ │ │ └── config.py │ │ │ │ ├── r2 │ │ │ │ │ ├── results.json │ │ │ │ │ └── config.py │ │ │ │ └── r3 │ │ │ │ │ ├── results.json │ │ │ │ │ └── config.py │ │ │ ├── tc │ │ │ │ ├── r1 │ │ │ │ │ ├── results.json │ │ │ │ │ └── config.py │ │ │ │ ├── r2 │ │ │ │ │ ├── results.json │ │ │ │ │ └── config.py │ │ │ │ └── r3 │ │ │ │ │ ├── results.json │ │ │ │ │ └── config.py │ │ │ └── listops │ │ │ │ ├── r1 │ │ │ │ ├── results.json │ │ │ │ └── config.py │ │ │ │ ├── r2 │ │ │ │ ├── results.json │ │ │ │ └── config.py │ │ │ │ └── r3 │ │ │ │ ├── results.json │ │ │ │ └── config.py │ 
│ │ ├── linear_transformer_relu_sinespe │ │ │ ├── cifar10 │ │ │ │ ├── r1 │ │ │ │ │ ├── results.json │ │ │ │ │ └── config.py │ │ │ │ ├── r2 │ │ │ │ │ ├── results.json │ │ │ │ │ └── config.py │ │ │ │ └── r3 │ │ │ │ │ ├── results.json │ │ │ │ │ └── config.py │ │ │ ├── aan │ │ │ │ ├── r1 │ │ │ │ │ ├── results.json │ │ │ │ │ └── config.py │ │ │ │ ├── r2 │ │ │ │ │ ├── results.json │ │ │ │ │ └── config.py │ │ │ │ └── r3 │ │ │ │ │ ├── results.json │ │ │ │ │ └── config.py │ │ │ ├── listops │ │ │ │ ├── r3 │ │ │ │ │ ├── results.json │ │ │ │ │ └── config.py │ │ │ │ ├── r1 │ │ │ │ │ ├── results.json │ │ │ │ │ └── config.py │ │ │ │ └── r2 │ │ │ │ │ ├── results.json │ │ │ │ │ └── config.py │ │ │ └── tc │ │ │ │ ├── r1 │ │ │ │ ├── results.json │ │ │ │ └── config.py │ │ │ │ ├── r2 │ │ │ │ ├── results.json │ │ │ │ └── config.py │ │ │ │ └── r3 │ │ │ │ ├── results.json │ │ │ │ └── config.py │ │ │ ├── performer_softmax_convspe_k128_shr │ │ │ ├── cifar10 │ │ │ │ ├── r1 │ │ │ │ │ ├── results.json │ │ │ │ │ └── config.py │ │ │ │ ├── r2 │ │ │ │ │ └── config.py │ │ │ │ └── r3 │ │ │ │ │ └── config.py │ │ │ ├── aan │ │ │ │ ├── r1 │ │ │ │ │ ├── results.json │ │ │ │ │ └── config.py │ │ │ │ ├── r2 │ │ │ │ │ └── config.py │ │ │ │ └── r3 │ │ │ │ │ └── config.py │ │ │ ├── listops │ │ │ │ ├── r1 │ │ │ │ │ └── config.py │ │ │ │ ├── r2 │ │ │ │ │ └── config.py │ │ │ │ └── r3 │ │ │ │ │ └── config.py │ │ │ └── tc │ │ │ │ ├── r1 │ │ │ │ └── config.py │ │ │ │ ├── r3 │ │ │ │ └── config.py │ │ │ │ └── r2 │ │ │ │ └── config.py │ │ │ └── linear_transformer_relu_convspe_k128_shr │ │ │ ├── cifar10 │ │ │ ├── r1 │ │ │ │ ├── results.json │ │ │ │ └── config.py │ │ │ ├── r3 │ │ │ │ ├── results.json │ │ │ │ └── config.py │ │ │ └── r2 │ │ │ │ └── config.py │ │ │ ├── aan │ │ │ ├── r1 │ │ │ │ ├── results.json │ │ │ │ └── config.py │ │ │ ├── r2 │ │ │ │ └── config.py │ │ │ └── r3 │ │ │ │ └── config.py │ │ │ ├── listops │ │ │ ├── r1 │ │ │ │ ├── results.json │ │ │ │ └── config.py │ │ │ ├── r2 │ │ │ │ ├── results.json │ │ │ │ └── config.py │ │ │ └── r3 │ │ │ │ └── config.py │ │ │ └── tc │ │ │ ├── r1 │ │ │ └── config.py │ │ │ ├── r2 │ │ │ └── config.py │ │ │ └── r3 │ │ │ └── config.py │ ├── requirements.txt │ ├── run_cifar10.sh │ ├── run_path32.sh │ ├── run_tc.sh │ ├── run_listops.sh │ ├── run_aan.sh │ └── README.md └── pop_piano │ ├── pickles │ ├── remi_vocab.pkl │ ├── test_pieces.pkl │ ├── val_pieces.pkl │ └── train_pieces.pkl │ ├── download_dataset.sh │ ├── configs │ ├── inference │ │ └── default.yaml │ └── train │ │ ├── ape_default.yaml │ │ ├── convspe_default.yaml │ │ └── sinespe_default.yaml │ ├── README.md │ ├── utils.py │ ├── models │ ├── music_performer_ape.py │ ├── ape_fast_transformer_decoder.py │ ├── transformer_helpers.py │ └── music_performer_spe.py │ └── inference.py ├── src ├── pytorch │ ├── spe │ │ └── __init__.py │ ├── README.md │ ├── setup.py │ └── LICENSE └── jax │ ├── jax_spe │ └── __init__.py │ ├── README.md │ ├── setup.py │ └── LICENSE ├── .gitmodules ├── README.md └── .gitignore /experiments/groove/src/spe_music/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /experiments/groove/src/spe_music/model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /experiments/groove/src/spe_music/style_eval/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/pytorch/spe/__init__.py: -------------------------------------------------------------------------------- 1 | from .spe import ConvSPE, SineSPE, SPEFilter 2 | -------------------------------------------------------------------------------- /src/jax/jax_spe/__init__.py: -------------------------------------------------------------------------------- 1 | from .spe import SineSPE, ConvSPE, SPEGate, apply_spe 2 | -------------------------------------------------------------------------------- /experiments/lra/fast_attention/fast_self_attention/__init__.py: -------------------------------------------------------------------------------- 1 | from .fast_self_attention import * 2 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax/cifar10/r1/results.json: -------------------------------------------------------------------------------- 1 | {"accuracy": 0.41436299681663513, "loss": 1.7024986743927002} -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax/cifar10/r2/results.json: -------------------------------------------------------------------------------- 1 | {"accuracy": 0.4311898946762085, "loss": 1.6787774562835693} -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax_sinespe/cifar10/r2/results.json: -------------------------------------------------------------------------------- 1 | {"accuracy": 0.40234375, "loss": 1.7382149696350098} -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu/cifar10/r1/results.json: -------------------------------------------------------------------------------- 1 | {"accuracy": 0.4225761294364929, "loss": 1.6958818435668945} -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu/cifar10/r2/results.json: -------------------------------------------------------------------------------- 1 | {"accuracy": 0.4224759638309479, "loss": 1.7094354629516602} -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax/cifar10/tune01/r1/results.json: -------------------------------------------------------------------------------- 1 | {"accuracy": 0.31089743971824646, "loss": 1.88819420337677} -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax/cifar10/tune01/r2/results.json: -------------------------------------------------------------------------------- 1 | {"accuracy": 0.317307710647583, "loss": 1.9050369262695312} -------------------------------------------------------------------------------- /experiments/pop_piano/pickles/remi_vocab.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliutkus/spe/HEAD/experiments/pop_piano/pickles/remi_vocab.pkl -------------------------------------------------------------------------------- /experiments/pop_piano/pickles/test_pieces.pkl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/aliutkus/spe/HEAD/experiments/pop_piano/pickles/test_pieces.pkl -------------------------------------------------------------------------------- /experiments/pop_piano/pickles/val_pieces.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliutkus/spe/HEAD/experiments/pop_piano/pickles/val_pieces.pkl -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu/cifar10/r3/results.json: -------------------------------------------------------------------------------- 1 | {"accuracy": 0.42237579822540283, "loss": 1.6746834516525269} -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax/cifar10/tune01/r3/results.json: -------------------------------------------------------------------------------- 1 | {"accuracy": 0.3149038553237915, "loss": 1.9058215618133545} -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax_sinespe/cifar10/r1/results.json: -------------------------------------------------------------------------------- 1 | {"accuracy": 0.4307892620563507, "loss": 1.6581761837005615} -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax_sinespe/cifar10/r3/results.json: -------------------------------------------------------------------------------- 1 | {"accuracy": 0.40034055709838867, "loss": 1.7466264963150024} -------------------------------------------------------------------------------- /experiments/pop_piano/pickles/train_pieces.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliutkus/spe/HEAD/experiments/pop_piano/pickles/train_pieces.pkl -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu_sinespe/cifar10/r1/results.json: -------------------------------------------------------------------------------- 1 | {"accuracy": 0.4205729365348816, "loss": 1.6797353029251099} -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu_sinespe/cifar10/r2/results.json: -------------------------------------------------------------------------------- 1 | {"accuracy": 0.4170673191547394, "loss": 1.7022228240966797} -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu_sinespe/cifar10/r3/results.json: -------------------------------------------------------------------------------- 1 | {"accuracy": 0.39853766560554504, "loss": 1.7336770296096802} -------------------------------------------------------------------------------- /experiments/groove/exp/clone.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -ex 3 | mkdir "$2" 4 | cp "$1/config.yaml" "$2/config.yaml" 5 | sensible-editor "$2/config.yaml" 6 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax_convspe_k128_shr/cifar10/r1/results.json: -------------------------------------------------------------------------------- 1 | {"accuracy": 0.4006410241127014, "loss": 1.7773103713989258} 
-------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu_convspe_k128_shr/cifar10/r1/results.json: -------------------------------------------------------------------------------- 1 | {"accuracy": 0.3845152258872986, "loss": 1.797257423400879} -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu_convspe_k128_shr/cifar10/r3/results.json: -------------------------------------------------------------------------------- 1 | {"accuracy": 0.40815305709838867, "loss": 1.7343062162399292} -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax/aan/r1/results.json: -------------------------------------------------------------------------------- 1 | {"accuracy": 0.5822625160217285, "loss": 0.6590374112129211, "perplexity": 1.932930827140808} -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax/aan/r2/results.json: -------------------------------------------------------------------------------- 1 | {"accuracy": 0.6106585264205933, "loss": 0.6642841100692749, "perplexity": 1.9430989027023315} -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax/aan/r3/results.json: -------------------------------------------------------------------------------- 1 | {"accuracy": 0.6022257804870605, "loss": 0.6531969904899597, "perplexity": 1.9216746091842651} -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax/tc/r1/results.json: -------------------------------------------------------------------------------- 1 | {"accuracy": 0.6273199915885925, "loss": 0.6473246216773987, "perplexity": 1.910422921180725} -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax/tc/r2/results.json: -------------------------------------------------------------------------------- 1 | {"accuracy": 0.6268399953842163, "loss": 0.6484074592590332, "perplexity": 1.9124926328659058} -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax/tc/r3/results.json: -------------------------------------------------------------------------------- 1 | {"accuracy": 0.6233199834823608, "loss": 0.6503739953041077, "perplexity": 1.916257381439209} -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu/aan/r1/results.json: -------------------------------------------------------------------------------- 1 | {"accuracy": 0.5855897068977356, "loss": 0.6629791855812073, "perplexity": 1.9405648708343506} -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu/aan/r2/results.json: -------------------------------------------------------------------------------- 1 | {"accuracy": 0.5979233384132385, "loss": 0.6598353385925293, "perplexity": 1.9344737529754639} -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu/aan/r3/results.json: -------------------------------------------------------------------------------- 1 | 
{"accuracy": 0.579795777797699, "loss": 0.6649137139320374, "perplexity": 1.9443225860595703} -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu/tc/r1/results.json: -------------------------------------------------------------------------------- 1 | {"accuracy": 0.6394799947738647, "loss": 0.6382478475570679, "perplexity": 1.8931608200073242} -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu/tc/r2/results.json: -------------------------------------------------------------------------------- 1 | {"accuracy": 0.6403200030326843, "loss": 0.638923168182373, "perplexity": 1.8944398164749146} -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu/tc/r3/results.json: -------------------------------------------------------------------------------- 1 | {"accuracy": 0.639519989490509, "loss": 0.6495220065116882, "perplexity": 1.9146254062652588} -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax/listops/r1/results.json: -------------------------------------------------------------------------------- 1 | {"accuracy": 0.17800000309944153, "loss": 2.246126413345337, "perplexity": 9.451055526733398} -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax/listops/r2/results.json: -------------------------------------------------------------------------------- 1 | {"accuracy": 0.17800000309944153, "loss": 2.2406880855560303, "perplexity": 9.399796485900879} -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax/listops/r3/results.json: -------------------------------------------------------------------------------- 1 | {"accuracy": 0.17800000309944153, "loss": 2.240405321121216, "perplexity": 9.397139549255371} -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu/listops/r1/results.json: -------------------------------------------------------------------------------- 1 | {"accuracy": 0.1730000078678131, "loss": 2.6976065635681152, "perplexity": 14.844160079956055} -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu/listops/r2/results.json: -------------------------------------------------------------------------------- 1 | {"accuracy": 0.16750000417232513, "loss": 2.5761983394622803, "perplexity": 13.147062301635742} -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu/listops/r3/results.json: -------------------------------------------------------------------------------- 1 | {"accuracy": 0.18700000643730164, "loss": 2.239375591278076, "perplexity": 9.387467384338379} -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax_sinespe/aan/r1/results.json: -------------------------------------------------------------------------------- 1 | {"accuracy": 0.5920146703720093, "loss": 0.6763992309570312, "perplexity": 1.966783046722412} -------------------------------------------------------------------------------- 
/experiments/lra/models/gpu_16g/performer_softmax_sinespe/aan/r2/results.json: -------------------------------------------------------------------------------- 1 | {"accuracy": 0.6138136982917786, "loss": 0.6445307731628418, "perplexity": 1.9050928354263306} -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax_sinespe/aan/r3/results.json: -------------------------------------------------------------------------------- 1 | {"accuracy": 0.5940224528312683, "loss": 0.6604791879653931, "perplexity": 1.935719609260559} -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax_sinespe/listops/r2/results.json: -------------------------------------------------------------------------------- 1 | {"accuracy": 0.17250001430511475, "loss": 2.245021104812622, "perplexity": 9.440614700317383} -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax_sinespe/listops/r3/results.json: -------------------------------------------------------------------------------- 1 | {"accuracy": 0.17800000309944153, "loss": 2.241204261779785, "perplexity": 9.40464973449707} -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax_sinespe/tc/r1/results.json: -------------------------------------------------------------------------------- 1 | {"accuracy": 0.6220799684524536, "loss": 0.6553142070770264, "perplexity": 1.9257475137710571} -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax_sinespe/tc/r2/results.json: -------------------------------------------------------------------------------- 1 | {"accuracy": 0.6244399547576904, "loss": 0.6501898169517517, "perplexity": 1.9159045219421387} -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax_sinespe/tc/r3/results.json: -------------------------------------------------------------------------------- 1 | {"accuracy": 0.6315999627113342, "loss": 0.6539342999458313, "perplexity": 1.9230918884277344} -------------------------------------------------------------------------------- /experiments/lra/requirements.txt: -------------------------------------------------------------------------------- 1 | flax==0.2.2 2 | gin-config==0.4.0 3 | ml-collections==0.1.0 4 | tensorboard==2.4.0 5 | tensorflow==2.3.1 6 | tensorflow-datasets==4.1.0 7 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu_sinespe/aan/r1/results.json: -------------------------------------------------------------------------------- 1 | {"accuracy": 0.6174851059913635, "loss": 0.6383522152900696, "perplexity": 1.8933584690093994} -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu_sinespe/aan/r2/results.json: -------------------------------------------------------------------------------- 1 | {"accuracy": 0.6290729641914368, "loss": 0.6321158409118652, "perplexity": 1.8815875053405762} -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu_sinespe/aan/r3/results.json: 
-------------------------------------------------------------------------------- 1 | {"accuracy": 0.6252868175506592, "loss": 0.6331627368927002, "perplexity": 1.8835583925247192} -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu_sinespe/listops/r3/results.json: -------------------------------------------------------------------------------- 1 | {"accuracy": 0.17800000309944153, "loss": 2.239715814590454, "perplexity": 9.39066219329834} -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu_sinespe/tc/r1/results.json: -------------------------------------------------------------------------------- 1 | {"accuracy": 0.6352800130844116, "loss": 0.6422345042228699, "perplexity": 1.9007232189178467} -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu_sinespe/tc/r2/results.json: -------------------------------------------------------------------------------- 1 | {"accuracy": 0.6475600004196167, "loss": 0.6308760643005371, "perplexity": 1.879256248474121} -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu_sinespe/tc/r3/results.json: -------------------------------------------------------------------------------- 1 | {"accuracy": 0.6397199630737305, "loss": 0.6403006911277771, "perplexity": 1.897051215171814} -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax_sinespe/listops/r1/results.json: -------------------------------------------------------------------------------- 1 | {"accuracy": 0.17250001430511475, "loss": 2.2420830726623535, "perplexity": 9.412919044494629} -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu_sinespe/listops/r1/results.json: -------------------------------------------------------------------------------- 1 | {"accuracy": 0.17800000309944153, "loss": 2.241586446762085, "perplexity": 9.408245086669922} -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu_sinespe/listops/r2/results.json: -------------------------------------------------------------------------------- 1 | {"accuracy": 0.17800000309944153, "loss": 2.2449517250061035, "perplexity": 9.439959526062012} -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax_convspe_k128_shr/aan/r1/results.json: -------------------------------------------------------------------------------- 1 | {"accuracy": 0.5722234845161438, "loss": 0.6704565286636353, "perplexity": 1.9551297426223755} -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu_convspe_k128_shr/aan/r1/results.json: -------------------------------------------------------------------------------- 1 | {"accuracy": 0.6112895607948303, "loss": 0.656229555606842, "perplexity": 1.9275110960006714} -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu_convspe_k128_shr/listops/r1/results.json: 
-------------------------------------------------------------------------------- 1 | {"accuracy": 0.08150000125169754, "loss": 2.466508388519287, "perplexity": 11.78123950958252} -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu_convspe_k128_shr/listops/r2/results.json: -------------------------------------------------------------------------------- 1 | {"accuracy": 0.10200000554323196, "loss": 2.427858591079712, "perplexity": 11.33458423614502} -------------------------------------------------------------------------------- /experiments/groove/requirements.txt: -------------------------------------------------------------------------------- 1 | bidict==0.21.2 2 | confugue==0.1.1 3 | flatdict==4.0.1 4 | muspy==0.3.0 5 | neptune-client==0.4.130 6 | numpy==1.19.4 7 | PyYAML==5.3.1 8 | torch==1.7.1 9 | -------------------------------------------------------------------------------- /experiments/pop_piano/download_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | wget -O remi_dataset.tar.gz https://zenodo.org/record/4782721/files/remi_dataset.tar.gz?download=1 4 | tar xzvf remi_dataset.tar.gz 5 | rm remi_dataset.tar.gz -------------------------------------------------------------------------------- /experiments/groove/src/setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | setuptools.setup( 4 | name="spe-music", 5 | author="Ondřej Cífka, Shih-Lun Wu", 6 | description="Stochastic positional encoding - Music experiments", 7 | packages=setuptools.find_packages(), 8 | ) 9 | -------------------------------------------------------------------------------- /experiments/lra/fast_attention/setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | setuptools.setup( 4 | name="fast-self-attention", 5 | author="Google", 6 | description="Fast Attention Via positive Orthogonal Random features", 7 | packages=setuptools.find_packages(), 8 | ) 9 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "experiments/lra/long-range-arena"] 2 | path = experiments/lra/long-range-arena 3 | url = https://github.com/cifkao/long-range-arena.git 4 | [submodule "experiments/groove/lib/fast-transformers"] 5 | path = experiments/groove/lib/fast-transformers 6 | url = https://github.com/cifkao/fast-transformers.git 7 | -------------------------------------------------------------------------------- /experiments/lra/run_cifar10.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | if [[ $# -ne 2 ]]; then 3 | echo "Expected exactly 2 arguments: config, model_dir" >&2 4 | exit 1 5 | fi 6 | 7 | config=$1 8 | model_dir=$2 9 | 10 | set -ex 11 | mkdir -p "$model_dir" 12 | python -m lra_benchmarks.image.train --config="$config" --model_dir="$model_dir/" --task_name=cifar10 13 | -------------------------------------------------------------------------------- /experiments/lra/run_path32.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | if [[ $# -ne 2 ]]; then 3 | echo "Expected exactly 2 arguments: config, model_dir" >&2 4 | exit 1 5 | fi 6 | 7 | config=$1 8 | model_dir=$2 9 | 10 | set -ex 11 | mkdir -p
"$model_dir" 12 | python -m lra_benchmarks.image.train --config="$config" --model_dir="$model_dir/" --task_name=pathfinder32_hard 13 | -------------------------------------------------------------------------------- /experiments/pop_piano/configs/inference/default.yaml: -------------------------------------------------------------------------------- 1 | ckpt_path: ckpt_spe/0128_l24_n2048_SineSPENew-r64-s5_sharePE/params/ep078_loss1.532_params.pt 2 | gpuid: 0 3 | gen_output_dir: generations/SineSPE 4 | gen_n_pieces: 5 5 | gen_max_events: 2048 6 | gen_max_bars: 32 7 | sampling: 8 | temp: 1.2 9 | top_p: 0.9 -------------------------------------------------------------------------------- /experiments/groove/exp/style_eval_drums_config.yaml: -------------------------------------------------------------------------------- 1 | time_pitch_diff: null 2 | 3 | note_stats: 4 | stats: 5 | - name: onset.drum 6 | features: 7 | - name: onset 8 | bins: &onset_bins 9 | bin_size: !!python/object/apply:eval [ 1/6 ] 10 | - name: pitch 11 | bins: &pitch_bins 12 | min_value: 0 13 | max_value: 127 14 | -------------------------------------------------------------------------------- /src/jax/README.md: -------------------------------------------------------------------------------- 1 | # jax-spe 2 | Stochastic Positional Encoding for JAX/Flax. 3 | 4 | ## Installation 5 | 6 | ```bash 7 | pip install -e . 8 | ``` 9 | 10 | # Usage 11 | The `SineSPE` and `ConvSPE` modules generate positional codes Q̅ and K̅, the `SPEGate` applies the optional gating, and the `apply_spe` functions combines Q̅ and K̅ with queries Q and keys K to form new queries Q̂ and keys K̂. 12 | 13 | See the [example notebook](./examples/test_spe.ipynb). 14 | -------------------------------------------------------------------------------- /experiments/lra/run_tc.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | if [[ $# -ne 2 ]]; then 3 | echo "Expected exactly 2 arguments: config, model_dir" >&2 4 | exit 1 5 | fi 6 | 7 | config=$1 8 | model_dir=$2 9 | 10 | set -ex 11 | mkdir -p "$model_dir" 12 | python -m lra_benchmarks.text_classification.train --config="$config" --model_dir="$model_dir/" --task_name=imdb_reviews 13 | python -m lra_benchmarks.text_classification.train --config="$config" --model_dir="$model_dir/" --task_name=imdb_reviews --test_only 14 | -------------------------------------------------------------------------------- /experiments/groove/exp/style_eval_config.yaml: -------------------------------------------------------------------------------- 1 | time_pitch_diff: 2 | bin_size: !!python/object/apply:eval [ 1/6 ] 3 | max_time: 4 4 | pitch_range: 20 5 | 6 | note_stats: 7 | stats: 8 | - name: onset.duration 9 | features: 10 | - name: onset 11 | bins: &onset_bins 12 | bin_size: !!python/object/apply:eval [ 1/6 ] 13 | - name: duration 14 | bins: 15 | bin_size: !!python/object/apply:eval [ 1/6 ] 16 | max_value: 2 17 | -------------------------------------------------------------------------------- /experiments/lra/run_listops.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | data_dir=data/lra_release/listops-1000/ 3 | 4 | if [[ $# -ne 2 ]]; then 5 | echo "Expected exactly 2 arguments: config, model_dir" >&2 6 | exit 1 7 | fi 8 | 9 | config=$1 10 | model_dir=$2 11 | 12 | set -ex 13 | mkdir -p "$model_dir" 14 | python -m lra_benchmarks.listops.train --config="$config" --model_dir="$model_dir/" --task_name=basic 
--data_dir="$data_dir" 15 | python -m lra_benchmarks.listops.train --config="$config" --model_dir="$model_dir/" --task_name=basic --data_dir="$data_dir" --test_only 16 | -------------------------------------------------------------------------------- /experiments/lra/run_aan.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | data_dir=data/lra_release/tsv_data/ 3 | 4 | if [[ $# -ne 2 ]]; then 5 | echo "Expected exactly 2 arguments: config, model_dir" >&2 6 | exit 1 7 | fi 8 | 9 | config=$1 10 | model_dir=$2 11 | 12 | set -ex 13 | mkdir -p "$model_dir" 14 | python -m lra_benchmarks.matching.train --config="$config" --model_dir="$model_dir/" --vocab_file_path="$model_dir"/vocab --data_dir="$data_dir" 15 | python -m lra_benchmarks.matching.train --config="$config" --model_dir="$model_dir/" --vocab_file_path="$model_dir"/vocab --data_dir="$data_dir" --test_only 16 | -------------------------------------------------------------------------------- /src/jax/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | # To use a consistent encoding 3 | from codecs import open 4 | from os import path 5 | 6 | here = path.abspath(path.dirname(__file__)) 7 | 8 | # Get the long description from the README file 9 | with open(path.join(here, 'README.md'), encoding='utf-8') as f: 10 | long_description = f.read() 11 | 12 | # Proceed to setup 13 | setup( 14 | name='jax-spe', 15 | version='0.1', 16 | description='stochastic positional encoding for JAX/Flax', 17 | long_description=long_description, 18 | long_description_content_type='text/markdown', 19 | author='Ondřej Cífka', 20 | author_email='cifkao@gmail.com', 21 | packages=['jax_spe'], 22 | install_requires=[ 23 | 'jax>=0.2.6', 24 | ], 25 | classifiers=[ 26 | "Programming Language :: Python :: 3", 27 | "License :: OSI Approved :: MIT License", 28 | ] 29 | ) 30 | -------------------------------------------------------------------------------- /experiments/pop_piano/configs/train/ape_default.yaml: -------------------------------------------------------------------------------- 1 | data_loader: 2 | batch_size: 4 3 | train_split: ./pickles/train_pieces.pkl 4 | val_split: ./pickles/val_pieces.pkl 5 | max_bars: 48 6 | 7 | model: 8 | pe_type: APE 9 | d_model: 512 10 | d_embed: 512 11 | max_len: 2048 12 | n_layer: 24 13 | n_head: 8 14 | d_ff: 2048 15 | feature_map: 16 | n_dims: 128 17 | 18 | 19 | training: 20 | gpuid: 0 21 | num_epochs: 200 22 | ckpt_dir: './ckpt/APE' 23 | ckpt_interval: 3 # epochs 24 | log_interval: 200 # steps 25 | trained_params: null 26 | trained_optim: null 27 | 28 | lr: 1.0e-4 29 | lr_scheduler: 30 | eta_min: 5.0e-6 31 | T_max: 62880 # 160 epochs 32 | warmup_steps: 200 33 | 34 | feat_redraw_prob: 0.05 35 | -------------------------------------------------------------------------------- /experiments/pop_piano/README.md: -------------------------------------------------------------------------------- 1 | # Experiments on Pop Piano Generation 2 | 3 | This directory corresponds to **Section 3.2** of the paper. 4 | 5 | ## Prerequisites 6 | * Python >3.6 7 | * Additional dependencies 8 | ```bash 9 | pip3 install miditoolkit 10 | pip3 install -e ../groove/lib/fast-transformers 11 | pip3 install -e ../../src/pytorch 12 | ``` 13 | 14 | ## Usage Notes 15 | For detailed configuration settings, please read the `yaml` files under the `configs/` directory.
16 | 17 | ### Training 18 | ```bash 19 | python3 train.py [training config path] 20 | ``` 21 | * e.g. 22 | ```bash 23 | python3 train.py configs/train/sinespe_default.yaml 24 | ``` 25 | 26 | ### Inference 27 | ```bash 28 | python3 inference.py [training config path] [inference config file] 29 | ``` 30 | * e.g. 31 | ```bash 32 | python3 inference.py configs/train/sinespe_default.yaml configs/inference/default.yaml 33 | ``` 34 | 35 | ### Evaluation (NLL Loss vs. Position) 36 | ```bash 37 | python3 eval.py [training config path] [checkpoint path] 38 | ``` 39 | -------------------------------------------------------------------------------- /src/pytorch/README.md: -------------------------------------------------------------------------------- 1 | # spe 2 | Stochastic Positional Encoding for PyTorch. 3 | 4 | ## Installation 5 | 6 | ```bash 7 | pip install -e . 8 | ``` 9 | 10 | ## Usage 11 | 12 | Create an instance of either `SineSPE` or `ConvSPE`, and an instance of `SPEFilter`: 13 | ```python 14 | spe_encoder = spe.SineSPE(num_heads=8, # Number of attention heads 15 | in_features=64, # Dimension of keys and queries 16 | num_realizations=64, # New dimension of keys and queries 17 | num_sines=5) # Number of sinusoidal components 18 | spe_filter = spe.SPEFilter(gated=True, code_shape=spe_encoder.code_shape) 19 | ``` 20 | `SineSPE` and `ConvSPE` take care of generating the positional codes Q̅ and K̅, and `SPEFilter` combines these with queries Q and keys K to form new queries Q̂ and keys K̂: 21 | ```python 22 | pos_codes = spe_encoder(queries.shape[:2]) # pos_codes is a tuple (qbar, kbar) 23 | queries, keys = spe_filter(queries, keys, pos_codes) 24 | ``` 25 | -------------------------------------------------------------------------------- /experiments/groove/exp/trio_performer_softmax_l512_v01/config.yaml: -------------------------------------------------------------------------------- 1 | data_loader: 2 | batch_size: 10 3 | model: 4 | d_model: 512 5 | d_embed: 512 6 | max_len: 512 7 | decoder: 8 | n_layer: 24 9 | n_head: 8 10 | d_ff: 2048 11 | feature_map: 12 | n_dims: 128 13 | representation: 14 | num_tracks: 3 15 | resolution: 12 16 | timing: 17 | resolution: 12 18 | max_shift: 2 # beats 19 | 20 | seed: 0 21 | train_data_path: ../data/train_split/train 22 | data_augmentation: 23 | max_harm_tracks: 1 24 | harm_tracks_shuffle_prob: 1. 25 | track_drop_prob: 0.1 26 | training: 27 | num_epochs: 24 28 | ckpt_interval: 3 # epochs 29 | log_interval: 200 # steps 30 | 31 | lr: 4.0e-4 32 | lr_scheduler: 33 | class: !!python/name:torch.optim.lr_scheduler.CosineAnnealingLR 34 | eta_min: 2.0e-5 35 | T_max: 43184 # 1841 * 24 - 1000 36 | warmup_steps: 1000 37 | 38 | feature_redraw_interval: 40 39 | 40 | val_data_paths: 41 | val: ../data/train_split/val 42 | ival: ../data/train_split/ival 43 | val_data_loader: 44 | batch_size: 8 45 | -------------------------------------------------------------------------------- /src/pytorch/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | # To use a consistent encoding 3 | from codecs import open 4 | from os import path 5 | 6 | # trying to import the required torch package 7 | try: 8 | import torch 9 | except ImportError: 10 | raise Exception('SPE requires PyTorch to be installed. 
aborting') 11 | 12 | here = path.abspath(path.dirname(__file__)) 13 | 14 | # Get the long description from the README file 15 | with open(path.join(here, 'README.md'), encoding='utf-8') as f: 16 | long_description = f.read() 17 | 18 | # Proceed to setup 19 | setup( 20 | name='spe', 21 | version='0.1', 22 | description='stochastic positional encoding for PyTorch', 23 | long_description=long_description, 24 | long_description_content_type='text/markdown', 25 | author='Antoine Liutkus', 26 | author_email='antoine.liutkus@inria.fr', 27 | packages=['spe'], 28 | keywords='pytorch', 29 | install_requires=[ 30 | 'torch>=1.7', 31 | ], 32 | classifiers=[ 33 | "Programming Language :: Python :: 3", 34 | "License :: OSI Approved :: MIT License", 35 | ] 36 | ) 37 | -------------------------------------------------------------------------------- /src/jax/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Ondřej Cífka 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /experiments/pop_piano/configs/train/convspe_default.yaml: -------------------------------------------------------------------------------- 1 | data_loader: 2 | batch_size: 4 3 | train_split: ./pickles/train_pieces.pkl 4 | val_split: ./pickles/val_pieces.pkl 5 | max_bars: 48 6 | 7 | model: 8 | pe_type: ConvSPE 9 | d_model: 512 10 | d_embed: 512 11 | max_len: 2048 12 | n_layer: 24 13 | n_head: 8 14 | d_ff: 2048 15 | feature_map: 16 | n_dims: 128 17 | positional_encoder: 18 | in_features: 64 19 | num_realizations: 64 20 | kernel_size: 128 21 | share_pe: True 22 | share_spe_filter: False 23 | use_gated_filter: True 24 | 25 | 26 | training: 27 | gpuid: 0 28 | num_epochs: 200 29 | ckpt_dir: './ckpt/ConvSPE' 30 | ckpt_interval: 2 # epochs 31 | log_interval: 200 # steps 32 | trained_params: null 33 | trained_optim: null 34 | 35 | lr: 2.0e-4 36 | lr_scheduler: 37 | eta_min: 5.0e-6 38 | T_max: 55020 # 140 epochs 39 | warmup_steps: 200 40 | 41 | feat_redraw_prob: 0.05 -------------------------------------------------------------------------------- /experiments/pop_piano/configs/train/sinespe_default.yaml: -------------------------------------------------------------------------------- 1 | data_loader: 2 | batch_size: 4 3 | train_split: ./pickles/train_pieces.pkl 4 | val_split: ./pickles/val_pieces.pkl 5 | max_bars: 48 6 | 7 | model: 8 | pe_type: SineSPE 9 | d_model: 512 10 | d_embed: 512 11 | max_len: 2048 12 | n_layer: 24 13 | n_head: 8 14 | d_ff: 2048 15 | feature_map: 16 | n_dims: 128 17 | positional_encoder: 18 | in_features: 64 19 | num_realizations: 64 20 | num_sines: 5 21 | share_pe: True 22 | share_spe_filter: False 23 | use_gated_filter: True 24 | 25 | 26 | training: 27 | gpuid: 0 28 | num_epochs: 200 29 | ckpt_dir: './ckpt/SineSPE' 30 | ckpt_interval: 3 # epochs 31 | log_interval: 200 # steps 32 | trained_params: null 33 | trained_optim: null 34 | 35 | lr: 2.0e-4 36 | lr_scheduler: 37 | eta_min: 5.0e-6 38 | T_max: 55020 # 140 epochs 39 | warmup_steps: 200 40 | 41 | feat_redraw_prob: 0.05 -------------------------------------------------------------------------------- /src/pytorch/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Antoine Liutkus, Ondřej Cífka, Shih-Lun Wu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /experiments/groove/README.md: -------------------------------------------------------------------------------- 1 | # Groove continuation experiments 2 | 3 | Code and configuration files for training Performers on the Groove2Groove dataset. 4 | 5 | ## Setup 6 | 7 | In a Python 3.7 environment: 8 | ```bash 9 | pip install -r requirements.txt 10 | pip install -e lib/fast-transformers 11 | pip install -e ../../src/pytorch 12 | pip install -e ./src 13 | ``` 14 | 15 | ## Data 16 | Run the Jupyter notebook [`data/prepare.ipynb`](./data/prepare.ipynb) to download and prepare the dataset. 17 | 18 | ## Training 19 | 20 | ```bash 21 | python -m spe_music.train_performer_grv2grv --model-dir $DIR 22 | ``` 23 | `$DIR` should be a directory containing a `config.yaml` file. To log to [Neptune](https://neptune.ai/), set the `NEPTUNE_API_TOKEN` and `NEPTUNE_PROJECT` environment variables. 24 | Optionally, use `--name` to specify the experiment name for Neptune (otherwise it will be equal to the model directory path). 25 | 26 | ## Evaluation 27 | 28 | The evaluation metrics are implemented in [`spe_music.style_eval`](./src/spe_music/style_eval) module. Running the evaluation consists of two steps: 1. generate continuations using the [`exp/continuation.ipynb`](./exp/continuation.ipynb) notebook, 2. compute metrics using the [`exp/style_eval_midi.ipynb`](./exp/style_eval_midi.ipynb) notebook. 29 | -------------------------------------------------------------------------------- /experiments/groove/exp/trio_performer_softmax_sinespe_l512_v01/config.yaml: -------------------------------------------------------------------------------- 1 | data_loader: 2 | batch_size: 10 3 | model: 4 | d_model: 512 5 | d_embed: 512 6 | max_len: 512 7 | add_positional_encoding: False 8 | decoder: 9 | n_layer: 24 10 | n_head: 8 11 | d_ff: 2048 12 | feature_map: 13 | n_dims: 128 14 | positional_encoder: 15 | class: !!python/name:spe.SineSPE 16 | in_features: 64 17 | num_realizations: &R 64 18 | num_sines: 8 19 | spe_filter: 20 | gated: True 21 | share_pe: True 22 | share_spe_filter: False 23 | attention: 24 | query_dimensions: *R 25 | representation: 26 | num_tracks: 3 27 | resolution: 12 28 | timing: 29 | resolution: 12 30 | max_shift: 2 # beats 31 | 32 | seed: 0 33 | train_data_path: ../data/train_split/train 34 | data_augmentation: 35 | max_harm_tracks: 1 36 | harm_tracks_shuffle_prob: 1. 
37 | track_drop_prob: 0.1 38 | training: 39 | num_epochs: 24 40 | ckpt_interval: 3 # epochs 41 | log_interval: 200 # steps 42 | 43 | lr: 4.0e-4 44 | lr_scheduler: 45 | class: !!python/name:torch.optim.lr_scheduler.CosineAnnealingLR 46 | eta_min: 2.0e-5 47 | T_max: 43184 # 1841 * 24 - 1000 48 | warmup_steps: 1000 49 | 50 | feature_redraw_interval: 40 51 | 52 | val_data_paths: 53 | val: ../data/train_split/val 54 | ival: ../data/train_split/ival 55 | val_data_loader: 56 | batch_size: 8 57 | -------------------------------------------------------------------------------- /experiments/groove/exp/trio_performer_softmax_convspe_l512_v01/config.yaml: -------------------------------------------------------------------------------- 1 | data_loader: 2 | batch_size: 10 3 | model: 4 | d_model: 512 5 | d_embed: 512 6 | max_len: 512 7 | add_positional_encoding: False 8 | decoder: 9 | n_layer: 24 10 | n_head: 8 11 | d_ff: 2048 12 | feature_map: 13 | n_dims: 128 14 | positional_encoder: 15 | class: !!python/name:spe.ConvSPE 16 | in_features: 64 17 | num_realizations: &R 64 18 | kernel_size: 128 19 | spe_filter: 20 | gated: True 21 | share_pe: True 22 | share_spe_filter: False 23 | attention: 24 | query_dimensions: *R 25 | representation: 26 | num_tracks: 3 27 | resolution: 12 28 | timing: 29 | resolution: 12 30 | max_shift: 2 # beats 31 | 32 | seed: 0 33 | train_data_path: ../data/train_split/train 34 | data_augmentation: 35 | max_harm_tracks: 1 36 | harm_tracks_shuffle_prob: 1. 37 | track_drop_prob: 0.1 38 | training: 39 | num_epochs: 24 40 | ckpt_interval: 3 # epochs 41 | log_interval: 200 # steps 42 | 43 | lr: 4.0e-4 44 | lr_scheduler: 45 | class: !!python/name:torch.optim.lr_scheduler.CosineAnnealingLR 46 | eta_min: 2.0e-5 47 | T_max: 43184 # 1841 * 24 - 1000 48 | warmup_steps: 1000 49 | 50 | feature_redraw_interval: 40 51 | 52 | val_data_paths: 53 | val: ../data/train_split/val 54 | ival: ../data/train_split/ival 55 | val_data_loader: 56 | batch_size: 8 57 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax/listops/r1/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Configuration and hyperparameter sweeps.""" 15 | 16 | from fast_self_attention import fast_self_attention as favor 17 | 18 | from lra_benchmarks.listops.configs import base_listops_config 19 | 20 | 21 | def get_config(): 22 | """Get the default hyperparameter configuration.""" 23 | config = base_listops_config.get_config() 24 | config.random_seed = 0 25 | config.model_type = "transformer" 26 | config.attention_fn = favor.make_fast_softmax_attention( 27 | qkv_dim=config.qkv_dim // config.num_heads, 28 | lax_scan_unroll=16) 29 | config.batch_size = 8 30 | config.learning_rate = config.learning_rate / 32 * 8 31 | config.num_train_steps = 10000 32 | return config 33 | 34 | 35 | def get_hyper(hyper): 36 | return hyper.product([]) 37 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax/listops/r2/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Configuration and hyperparameter sweeps.""" 15 | 16 | from fast_self_attention import fast_self_attention as favor 17 | 18 | from lra_benchmarks.listops.configs import base_listops_config 19 | 20 | 21 | def get_config(): 22 | """Get the default hyperparameter configuration.""" 23 | config = base_listops_config.get_config() 24 | config.random_seed = 1 25 | config.model_type = "transformer" 26 | config.attention_fn = favor.make_fast_softmax_attention( 27 | qkv_dim=config.qkv_dim // config.num_heads, 28 | lax_scan_unroll=16) 29 | config.batch_size = 8 30 | config.learning_rate = config.learning_rate / 32 * 8 31 | config.num_train_steps = 10000 32 | return config 33 | 34 | 35 | def get_hyper(hyper): 36 | return hyper.product([]) 37 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax/listops/r3/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Configuration and hyperparameter sweeps.""" 15 | 16 | from fast_self_attention import fast_self_attention as favor 17 | 18 | from lra_benchmarks.listops.configs import base_listops_config 19 | 20 | 21 | def get_config(): 22 | """Get the default hyperparameter configuration.""" 23 | config = base_listops_config.get_config() 24 | config.random_seed = 2 25 | config.model_type = "transformer" 26 | config.attention_fn = favor.make_fast_softmax_attention( 27 | qkv_dim=config.qkv_dim // config.num_heads, 28 | lax_scan_unroll=16) 29 | config.batch_size = 8 30 | config.learning_rate = config.learning_rate / 32 * 8 31 | config.num_train_steps = 10000 32 | return config 33 | 34 | 35 | def get_hyper(hyper): 36 | return hyper.product([]) 37 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax/tc/r1/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Configuration and hyperparameter sweeps.""" 15 | 16 | from fast_self_attention import fast_self_attention as favor 17 | 18 | from lra_benchmarks.text_classification.configs import base_tc_config 19 | 20 | 21 | def get_config(): 22 | """Get the default hyperparameter configuration.""" 23 | config = base_tc_config.get_config() 24 | config.random_seed = 0 25 | config.model_type = "transformer" 26 | config.attention_fn = favor.make_fast_softmax_attention( 27 | qkv_dim=config.qkv_dim // config.num_heads, 28 | lax_scan_unroll=16) 29 | config.batch_size = config.batch_size // 2 30 | config.learning_rate = config.learning_rate / 2 31 | config.num_train_steps = 30000 32 | return config 33 | 34 | 35 | def get_hyper(hyper): 36 | return hyper.product([]) 37 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax/tc/r2/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Configuration and hyperparameter sweeps.""" 15 | 16 | from fast_self_attention import fast_self_attention as favor 17 | 18 | from lra_benchmarks.text_classification.configs import base_tc_config 19 | 20 | 21 | def get_config(): 22 | """Get the default hyperparameter configuration.""" 23 | config = base_tc_config.get_config() 24 | config.random_seed = 1 25 | config.model_type = "transformer" 26 | config.attention_fn = favor.make_fast_softmax_attention( 27 | qkv_dim=config.qkv_dim // config.num_heads, 28 | lax_scan_unroll=16) 29 | config.batch_size = config.batch_size // 2 30 | config.learning_rate = config.learning_rate / 2 31 | config.num_train_steps = 30000 32 | return config 33 | 34 | 35 | def get_hyper(hyper): 36 | return hyper.product([]) 37 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax/tc/r3/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Configuration and hyperparameter sweeps.""" 15 | 16 | from fast_self_attention import fast_self_attention as favor 17 | 18 | from lra_benchmarks.text_classification.configs import base_tc_config 19 | 20 | 21 | def get_config(): 22 | """Get the default hyperparameter configuration.""" 23 | config = base_tc_config.get_config() 24 | config.random_seed = 2 25 | config.model_type = "transformer" 26 | config.attention_fn = favor.make_fast_softmax_attention( 27 | qkv_dim=config.qkv_dim // config.num_heads, 28 | lax_scan_unroll=16) 29 | config.batch_size = config.batch_size // 2 30 | config.learning_rate = config.learning_rate / 2 31 | config.num_train_steps = 30000 32 | return config 33 | 34 | 35 | def get_hyper(hyper): 36 | return hyper.product([]) 37 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax/aan/r1/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Configuration and hyperparameter sweeps.""" 15 | 16 | from fast_self_attention import fast_self_attention as favor 17 | 18 | from lra_benchmarks.matching.configs import base_match_config 19 | 20 | 21 | def get_config(): 22 | """Get the default hyperparameter configuration.""" 23 | config = base_match_config.get_config() 24 | config.random_seed = 0 25 | config.model_type = "transformer" 26 | config.attention_fn = favor.make_fast_softmax_attention( 27 | qkv_dim=config.qkv_dim // config.num_heads, 28 | lax_scan_unroll=16) 29 | config.batch_size = 8 30 | config.learning_rate = 0.005 31 | config.num_train_steps = 15000 32 | config.warmup = 3000 33 | config.eval_frequency = 1500 34 | return config 35 | 36 | 37 | def get_hyper(hyper): 38 | return hyper.product([]) 39 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax/aan/r2/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Configuration and hyperparameter sweeps.""" 15 | 16 | from fast_self_attention import fast_self_attention as favor 17 | 18 | from lra_benchmarks.matching.configs import base_match_config 19 | 20 | 21 | def get_config(): 22 | """Get the default hyperparameter configuration.""" 23 | config = base_match_config.get_config() 24 | config.random_seed = 1 25 | config.model_type = "transformer" 26 | config.attention_fn = favor.make_fast_softmax_attention( 27 | qkv_dim=config.qkv_dim // config.num_heads, 28 | lax_scan_unroll=16) 29 | config.batch_size = 8 30 | config.learning_rate = 0.005 31 | config.num_train_steps = 15000 32 | config.warmup = 3000 33 | config.eval_frequency = 1500 34 | return config 35 | 36 | 37 | def get_hyper(hyper): 38 | return hyper.product([]) 39 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax/aan/r3/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Configuration and hyperparameter sweeps.""" 15 | 16 | from fast_self_attention import fast_self_attention as favor 17 | 18 | from lra_benchmarks.matching.configs import base_match_config 19 | 20 | 21 | def get_config(): 22 | """Get the default hyperparameter configuration.""" 23 | config = base_match_config.get_config() 24 | config.random_seed = 2 25 | config.model_type = "transformer" 26 | config.attention_fn = favor.make_fast_softmax_attention( 27 | qkv_dim=config.qkv_dim // config.num_heads, 28 | lax_scan_unroll=16) 29 | config.batch_size = 8 30 | config.learning_rate = 0.005 31 | config.num_train_steps = 15000 32 | config.warmup = 3000 33 | config.eval_frequency = 1500 34 | return config 35 | 36 | 37 | def get_hyper(hyper): 38 | return hyper.product([]) 39 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu/listops/r1/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Configuration and hyperparameter sweeps.""" 15 | 16 | from fast_self_attention import fast_self_attention as favor 17 | import jax 18 | 19 | from lra_benchmarks.listops.configs import base_listops_config 20 | 21 | 22 | def get_config(): 23 | """Get the default hyperparameter configuration.""" 24 | config = base_listops_config.get_config() 25 | config.random_seed = 0 26 | config.model_type = "transformer" 27 | config.attention_fn = favor.make_fast_generalized_attention( 28 | qkv_dim=config.qkv_dim // config.num_heads, 29 | features_type='deterministic', 30 | kernel_fn=jax.nn.relu, 31 | lax_scan_unroll=16) 32 | config.batch_size = 8 33 | config.learning_rate = config.learning_rate / 32 * 8 34 | config.num_train_steps = 10000 35 | return config 36 | 37 | 38 | def get_hyper(hyper): 39 | return hyper.product([]) 40 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu/listops/r3/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Configuration and hyperparameter sweeps.""" 15 | 16 | from fast_self_attention import fast_self_attention as favor 17 | import jax 18 | 19 | from lra_benchmarks.listops.configs import base_listops_config 20 | 21 | 22 | def get_config(): 23 | """Get the default hyperparameter configuration.""" 24 | config = base_listops_config.get_config() 25 | config.random_seed = 2 26 | config.model_type = "transformer" 27 | config.attention_fn = favor.make_fast_generalized_attention( 28 | qkv_dim=config.qkv_dim // config.num_heads, 29 | features_type='deterministic', 30 | kernel_fn=jax.nn.relu, 31 | lax_scan_unroll=16) 32 | config.batch_size = 8 33 | config.learning_rate = config.learning_rate / 32 * 8 34 | config.num_train_steps = 10000 35 | return config 36 | 37 | 38 | def get_hyper(hyper): 39 | return hyper.product([]) 40 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu/tc/r1/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Configuration and hyperparameter sweeps.""" 15 | 16 | from fast_self_attention import fast_self_attention as favor 17 | import jax 18 | 19 | from lra_benchmarks.text_classification.configs import base_tc_config 20 | 21 | 22 | def get_config(): 23 | """Get the default hyperparameter configuration.""" 24 | config = base_tc_config.get_config() 25 | config.random_seed = 0 26 | config.model_type = "transformer" 27 | config.attention_fn = favor.make_fast_generalized_attention( 28 | qkv_dim=config.qkv_dim // config.num_heads, 29 | features_type='deterministic', 30 | kernel_fn=jax.nn.relu, 31 | lax_scan_unroll=16) 32 | config.batch_size = 8 33 | config.learning_rate = config.learning_rate / 32 * 8 34 | config.num_train_steps = 30000 35 | return config 36 | 37 | 38 | def get_hyper(hyper): 39 | return hyper.product([]) 40 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu/tc/r2/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Configuration and hyperparameter sweeps.""" 15 | 16 | from fast_self_attention import fast_self_attention as favor 17 | import jax 18 | 19 | from lra_benchmarks.text_classification.configs import base_tc_config 20 | 21 | 22 | def get_config(): 23 | """Get the default hyperparameter configuration.""" 24 | config = base_tc_config.get_config() 25 | config.random_seed = 1 26 | config.model_type = "transformer" 27 | config.attention_fn = favor.make_fast_generalized_attention( 28 | qkv_dim=config.qkv_dim // config.num_heads, 29 | features_type='deterministic', 30 | kernel_fn=jax.nn.relu, 31 | lax_scan_unroll=16) 32 | config.batch_size = 8 33 | config.learning_rate = config.learning_rate / 32 * 8 34 | config.num_train_steps = 30000 35 | return config 36 | 37 | 38 | def get_hyper(hyper): 39 | return hyper.product([]) 40 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu/tc/r3/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Configuration and hyperparameter sweeps.""" 15 | 16 | from fast_self_attention import fast_self_attention as favor 17 | import jax 18 | 19 | from lra_benchmarks.text_classification.configs import base_tc_config 20 | 21 | 22 | def get_config(): 23 | """Get the default hyperparameter configuration.""" 24 | config = base_tc_config.get_config() 25 | config.random_seed = 2 26 | config.model_type = "transformer" 27 | config.attention_fn = favor.make_fast_generalized_attention( 28 | qkv_dim=config.qkv_dim // config.num_heads, 29 | features_type='deterministic', 30 | kernel_fn=jax.nn.relu, 31 | lax_scan_unroll=16) 32 | config.batch_size = 8 33 | config.learning_rate = config.learning_rate / 32 * 8 34 | config.num_train_steps = 30000 35 | return config 36 | 37 | 38 | def get_hyper(hyper): 39 | return hyper.product([]) 40 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu/aan/r1/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Configuration and hyperparameter sweeps.""" 15 | 16 | from fast_self_attention import fast_self_attention as favor 17 | import jax 18 | 19 | from lra_benchmarks.matching.configs import base_match_config 20 | 21 | 22 | def get_config(): 23 | """Get the default hyperparameter configuration.""" 24 | config = base_match_config.get_config() 25 | config.random_seed = 0 26 | config.model_type = "transformer" 27 | config.attention_fn = favor.make_fast_generalized_attention( 28 | qkv_dim=config.qkv_dim // config.num_heads, 29 | features_type='deterministic', 30 | kernel_fn=jax.nn.relu, 31 | lax_scan_unroll=16) 32 | config.batch_size = 8 33 | config.learning_rate = 0.005 34 | config.num_train_steps = 15000 35 | config.warmup = 3000 36 | config.eval_frequency = 1500 37 | return config 38 | 39 | 40 | def get_hyper(hyper): 41 | return hyper.product([]) 42 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu/aan/r2/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Configuration and hyperparameter sweeps.""" 15 | 16 | from fast_self_attention import fast_self_attention as favor 17 | import jax 18 | 19 | from lra_benchmarks.matching.configs import base_match_config 20 | 21 | 22 | def get_config(): 23 | """Get the default hyperparameter configuration.""" 24 | config = base_match_config.get_config() 25 | config.random_seed = 1 26 | config.model_type = "transformer" 27 | config.attention_fn = favor.make_fast_generalized_attention( 28 | qkv_dim=config.qkv_dim // config.num_heads, 29 | features_type='deterministic', 30 | kernel_fn=jax.nn.relu, 31 | lax_scan_unroll=16) 32 | config.batch_size = 8 33 | config.learning_rate = 0.005 34 | config.num_train_steps = 15000 35 | config.warmup = 3000 36 | config.eval_frequency = 1500 37 | return config 38 | 39 | 40 | def get_hyper(hyper): 41 | return hyper.product([]) 42 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu/aan/r3/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Configuration and hyperparameter sweeps.""" 15 | 16 | from fast_self_attention import fast_self_attention as favor 17 | import jax 18 | 19 | from lra_benchmarks.matching.configs import base_match_config 20 | 21 | 22 | def get_config(): 23 | """Get the default hyperparameter configuration.""" 24 | config = base_match_config.get_config() 25 | config.random_seed = 2 26 | config.model_type = "transformer" 27 | config.attention_fn = favor.make_fast_generalized_attention( 28 | qkv_dim=config.qkv_dim // config.num_heads, 29 | features_type='deterministic', 30 | kernel_fn=jax.nn.relu, 31 | lax_scan_unroll=16) 32 | config.batch_size = 8 33 | config.learning_rate = 0.005 34 | config.num_train_steps = 15000 35 | config.warmup = 3000 36 | config.eval_frequency = 1500 37 | return config 38 | 39 | 40 | def get_hyper(hyper): 41 | return hyper.product([]) 42 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu/listops/r2/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Configuration and hyperparameter sweeps.""" 15 | 16 | from fast_self_attention import fast_self_attention as favor 17 | import jax 18 | 19 | from lra_benchmarks.listops.configs import base_listops_config 20 | 21 | 22 | def get_config(): 23 | """Get the default hyperparameter configuration.""" 24 | config = base_listops_config.get_config() 25 | config.random_seed = 1 26 | config.model_type = "transformer" 27 | config.attention_fn = favor.make_fast_generalized_attention( 28 | qkv_dim=config.qkv_dim // config.num_heads, 29 | features_type='deterministic', 30 | kernel_fn=jax.nn.relu, 31 | lax_scan_unroll=16) 32 | config.batch_size = 8 33 | config.learning_rate = config.learning_rate / 32 * 8 34 | config.num_train_steps = 10000 35 | config.eval_frequency = config.eval_frequency * 4 36 | return config 37 | 38 | 39 | def get_hyper(hyper): 40 | return hyper.product([]) 41 | -------------------------------------------------------------------------------- /experiments/lra/README.md: -------------------------------------------------------------------------------- 1 | # Long-Range Arena experiments 2 | 3 | Original LRA repository: [google-research/long-range-arena](https://github.com/google-research/long-range-arena) 4 | 5 | **Note:** A known bug in the ListOps task removes the brackets from the input, making the task essentially impossible to solve. The bug appears to have been fixed in the upstream repository in [c209b2a4](https://github.com/google-research/long-range-arena/commit/c209b2a48eedfd7ffcd13c679f97f3fa466c47bc), but the change is not merged here and we have not tested it. Our results were produced before the fix and are therefore affected by the bug. 
6 | 7 | ## Setup 8 | 9 | Install JAX (adjust the `jaxlib` version according to your CUDA version): 10 | ```bash 11 | pip install jax==0.2.6 jaxlib==0.1.57+cuda102 -f https://storage.googleapis.com/jax-releases/jax_releases.html 12 | ``` 13 | 14 | Install the rest of the requirements: 15 | ```bash 16 | pip install -r requirements.txt 17 | pip install -e ./fast_attention ./long-range-arena ../../src/jax 18 | ``` 19 | 20 | ## Data 21 | 22 | Download the data listed [here](https://github.com/google-research/long-range-arena). Adjust the paths in the training scripts (e.g. [`run_aan.sh`](./run_aan.sh)) if needed. 23 | 24 | ## Running the benchmark 25 | 26 | There is a Bash script for each of the LRA tasks. For example, to perform the first run of the Performer on the Retrieval (AAN) task, run: 27 | ```bash 28 | ./run_aan.sh models/gpu_16g/performer_softmax/aan/r1/config.py models/gpu_16g/performer_softmax/aan/r1 29 | ``` 30 | Once the run is finished, the results will be in a `results.json` file inside the model directory. 31 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax_sinespe/listops/r1/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Configuration and hyperparameter sweeps.""" 15 | 16 | import functools 17 | 18 | from fast_self_attention import fast_self_attention as favor 19 | import jax_spe as spe 20 | 21 | from lra_benchmarks.models.layers.spe import make_spe_transform_fn 22 | from lra_benchmarks.listops.configs import base_listops_config 23 | 24 | 25 | def get_config(): 26 | """Get the default hyperparameter configuration.""" 27 | config = base_listops_config.get_config() 28 | config.random_seed = 0 29 | config.model_type = "transformer" 30 | config.attention_fn = favor.make_fast_softmax_attention( 31 | qkv_dim=config.qkv_dim // config.num_heads, 32 | lax_scan_unroll=16) 33 | config.model_kwargs = dict( 34 | add_pos_emb=False, 35 | qk_transform_fn_factory=functools.partial( 36 | make_spe_transform_fn, 37 | spe_cls=spe.SineSPE, 38 | spe_kwargs=dict( 39 | num_realizations=64, 40 | num_sines=10 41 | ), 42 | ) 43 | ) 44 | config.batch_size = 8 45 | config.learning_rate = config.learning_rate / 32 * 8 46 | config.num_train_steps = 10000 47 | return config 48 | 49 | 50 | def get_hyper(hyper): 51 | return hyper.product([]) 52 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax_sinespe/listops/r2/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Configuration and hyperparameter sweeps.""" 15 | 16 | import functools 17 | 18 | from fast_self_attention import fast_self_attention as favor 19 | import jax_spe as spe 20 | 21 | from lra_benchmarks.models.layers.spe import make_spe_transform_fn 22 | from lra_benchmarks.listops.configs import base_listops_config 23 | 24 | 25 | def get_config(): 26 | """Get the default hyperparameter configuration.""" 27 | config = base_listops_config.get_config() 28 | config.random_seed = 1 29 | config.model_type = "transformer" 30 | config.attention_fn = favor.make_fast_softmax_attention( 31 | qkv_dim=config.qkv_dim // config.num_heads, 32 | lax_scan_unroll=16) 33 | config.model_kwargs = dict( 34 | add_pos_emb=False, 35 | qk_transform_fn_factory=functools.partial( 36 | make_spe_transform_fn, 37 | spe_cls=spe.SineSPE, 38 | spe_kwargs=dict( 39 | num_realizations=64, 40 | num_sines=10 41 | ), 42 | ) 43 | ) 44 | config.batch_size = 8 45 | config.learning_rate = config.learning_rate / 32 * 8 46 | config.num_train_steps = 10000 47 | return config 48 | 49 | 50 | def get_hyper(hyper): 51 | return hyper.product([]) 52 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax_sinespe/listops/r3/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Configuration and hyperparameter sweeps.""" 15 | 16 | import functools 17 | 18 | from fast_self_attention import fast_self_attention as favor 19 | import jax_spe as spe 20 | 21 | from lra_benchmarks.models.layers.spe import make_spe_transform_fn 22 | from lra_benchmarks.listops.configs import base_listops_config 23 | 24 | 25 | def get_config(): 26 | """Get the default hyperparameter configuration.""" 27 | config = base_listops_config.get_config() 28 | config.random_seed = 2 29 | config.model_type = "transformer" 30 | config.attention_fn = favor.make_fast_softmax_attention( 31 | qkv_dim=config.qkv_dim // config.num_heads, 32 | lax_scan_unroll=16) 33 | config.model_kwargs = dict( 34 | add_pos_emb=False, 35 | qk_transform_fn_factory=functools.partial( 36 | make_spe_transform_fn, 37 | spe_cls=spe.SineSPE, 38 | spe_kwargs=dict( 39 | num_realizations=64, 40 | num_sines=10 41 | ), 42 | ) 43 | ) 44 | config.batch_size = 8 45 | config.learning_rate = config.learning_rate / 32 * 8 46 | config.num_train_steps = 10000 47 | return config 48 | 49 | 50 | def get_hyper(hyper): 51 | return hyper.product([]) 52 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax_sinespe/tc/r1/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Configuration and hyperparameter sweeps.""" 15 | 16 | import functools 17 | 18 | from fast_self_attention import fast_self_attention as favor 19 | import jax_spe as spe 20 | 21 | from lra_benchmarks.models.layers.spe import make_spe_transform_fn 22 | from lra_benchmarks.text_classification.configs import base_tc_config 23 | 24 | 25 | def get_config(): 26 | """Get the default hyperparameter configuration.""" 27 | config = base_tc_config.get_config() 28 | config.random_seed = 0 29 | config.model_type = "transformer" 30 | config.attention_fn = favor.make_fast_softmax_attention( 31 | qkv_dim=config.qkv_dim // config.num_heads, 32 | lax_scan_unroll=16) 33 | config.model_kwargs = dict( 34 | add_pos_emb=False, 35 | qk_transform_fn_factory=functools.partial( 36 | make_spe_transform_fn, 37 | spe_cls=spe.SineSPE, 38 | spe_kwargs=dict( 39 | num_realizations=64, 40 | num_sines=10 41 | ), 42 | ) 43 | ) 44 | config.batch_size = 8 45 | config.learning_rate = config.learning_rate / 32 * 8 46 | config.num_train_steps = 30000 47 | return config 48 | 49 | 50 | def get_hyper(hyper): 51 | return hyper.product([]) 52 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax_sinespe/tc/r2/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Configuration and hyperparameter sweeps.""" 15 | 16 | import functools 17 | 18 | from fast_self_attention import fast_self_attention as favor 19 | import jax_spe as spe 20 | 21 | from lra_benchmarks.models.layers.spe import make_spe_transform_fn 22 | from lra_benchmarks.text_classification.configs import base_tc_config 23 | 24 | 25 | def get_config(): 26 | """Get the default hyperparameter configuration.""" 27 | config = base_tc_config.get_config() 28 | config.random_seed = 1 29 | config.model_type = "transformer" 30 | config.attention_fn = favor.make_fast_softmax_attention( 31 | qkv_dim=config.qkv_dim // config.num_heads, 32 | lax_scan_unroll=16) 33 | config.model_kwargs = dict( 34 | add_pos_emb=False, 35 | qk_transform_fn_factory=functools.partial( 36 | make_spe_transform_fn, 37 | spe_cls=spe.SineSPE, 38 | spe_kwargs=dict( 39 | num_realizations=64, 40 | num_sines=10 41 | ), 42 | ) 43 | ) 44 | config.batch_size = 8 45 | config.learning_rate = config.learning_rate / 32 * 8 46 | config.num_train_steps = 30000 47 | return config 48 | 49 | 50 | def get_hyper(hyper): 51 | return hyper.product([]) 52 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax_sinespe/tc/r3/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Configuration and hyperparameter sweeps.""" 15 | 16 | import functools 17 | 18 | from fast_self_attention import fast_self_attention as favor 19 | import jax_spe as spe 20 | 21 | from lra_benchmarks.models.layers.spe import make_spe_transform_fn 22 | from lra_benchmarks.text_classification.configs import base_tc_config 23 | 24 | 25 | def get_config(): 26 | """Get the default hyperparameter configuration.""" 27 | config = base_tc_config.get_config() 28 | config.random_seed = 2 29 | config.model_type = "transformer" 30 | config.attention_fn = favor.make_fast_softmax_attention( 31 | qkv_dim=config.qkv_dim // config.num_heads, 32 | lax_scan_unroll=16) 33 | config.model_kwargs = dict( 34 | add_pos_emb=False, 35 | qk_transform_fn_factory=functools.partial( 36 | make_spe_transform_fn, 37 | spe_cls=spe.SineSPE, 38 | spe_kwargs=dict( 39 | num_realizations=64, 40 | num_sines=10 41 | ), 42 | ) 43 | ) 44 | config.batch_size = 8 45 | config.learning_rate = config.learning_rate / 32 * 8 46 | config.num_train_steps = 30000 47 | return config 48 | 49 | 50 | def get_hyper(hyper): 51 | return hyper.product([]) 52 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax_convspe_k128_shr/listops/r1/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Configuration and hyperparameter sweeps.""" 15 | 16 | import functools 17 | 18 | from fast_self_attention import fast_self_attention as favor 19 | import jax_spe as spe 20 | 21 | from lra_benchmarks.models.layers.spe import make_spe_transform_fn 22 | from lra_benchmarks.listops.configs import base_listops_config 23 | 24 | 25 | def get_config(): 26 | """Get the default hyperparameter configuration.""" 27 | config = base_listops_config.get_config() 28 | config.random_seed = 0 29 | config.model_type = "transformer" 30 | config.attention_fn = favor.make_fast_softmax_attention( 31 | qkv_dim=config.qkv_dim // config.num_heads, 32 | lax_scan_unroll=16) 33 | config.model_kwargs = dict( 34 | add_pos_emb=False, 35 | qk_transform_fn_factory=functools.partial( 36 | make_spe_transform_fn, 37 | spe_cls=spe.ConvSPE, 38 | spe_kwargs=dict( 39 | num_realizations=64, 40 | kernel_size=128 41 | ), 42 | shared=True 43 | ) 44 | ) 45 | config.batch_size = 8 46 | config.learning_rate = config.learning_rate / 32 * 8 47 | config.num_train_steps = 10000 48 | return config 49 | 50 | 51 | def get_hyper(hyper): 52 | return hyper.product([]) 53 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax_convspe_k128_shr/listops/r2/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Configuration and hyperparameter sweeps.""" 15 | 16 | import functools 17 | 18 | from fast_self_attention import fast_self_attention as favor 19 | import jax_spe as spe 20 | 21 | from lra_benchmarks.models.layers.spe import make_spe_transform_fn 22 | from lra_benchmarks.listops.configs import base_listops_config 23 | 24 | 25 | def get_config(): 26 | """Get the default hyperparameter configuration.""" 27 | config = base_listops_config.get_config() 28 | config.random_seed = 1 29 | config.model_type = "transformer" 30 | config.attention_fn = favor.make_fast_softmax_attention( 31 | qkv_dim=config.qkv_dim // config.num_heads, 32 | lax_scan_unroll=16) 33 | config.model_kwargs = dict( 34 | add_pos_emb=False, 35 | qk_transform_fn_factory=functools.partial( 36 | make_spe_transform_fn, 37 | spe_cls=spe.ConvSPE, 38 | spe_kwargs=dict( 39 | num_realizations=64, 40 | kernel_size=128 41 | ), 42 | shared=True 43 | ) 44 | ) 45 | config.batch_size = 8 46 | config.learning_rate = config.learning_rate / 32 * 8 47 | config.num_train_steps = 10000 48 | return config 49 | 50 | 51 | def get_hyper(hyper): 52 | return hyper.product([]) 53 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax_convspe_k128_shr/listops/r3/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Configuration and hyperparameter sweeps.""" 15 | 16 | import functools 17 | 18 | from fast_self_attention import fast_self_attention as favor 19 | import jax_spe as spe 20 | 21 | from lra_benchmarks.models.layers.spe import make_spe_transform_fn 22 | from lra_benchmarks.listops.configs import base_listops_config 23 | 24 | 25 | def get_config(): 26 | """Get the default hyperparameter configuration.""" 27 | config = base_listops_config.get_config() 28 | config.random_seed = 2 29 | config.model_type = "transformer" 30 | config.attention_fn = favor.make_fast_softmax_attention( 31 | qkv_dim=config.qkv_dim // config.num_heads, 32 | lax_scan_unroll=16) 33 | config.model_kwargs = dict( 34 | add_pos_emb=False, 35 | qk_transform_fn_factory=functools.partial( 36 | make_spe_transform_fn, 37 | spe_cls=spe.ConvSPE, 38 | spe_kwargs=dict( 39 | num_realizations=64, 40 | kernel_size=128 41 | ), 42 | shared=True 43 | ) 44 | ) 45 | config.batch_size = 8 46 | config.learning_rate = config.learning_rate / 32 * 8 47 | config.num_train_steps = 10000 48 | return config 49 | 50 | 51 | def get_hyper(hyper): 52 | return hyper.product([]) 53 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax_convspe_k128_shr/tc/r1/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Configuration and hyperparameter sweeps.""" 15 | 16 | import functools 17 | 18 | from fast_self_attention import fast_self_attention as favor 19 | import jax_spe as spe 20 | 21 | from lra_benchmarks.models.layers.spe import make_spe_transform_fn 22 | from lra_benchmarks.text_classification.configs import base_tc_config 23 | 24 | 25 | def get_config(): 26 | """Get the default hyperparameter configuration.""" 27 | config = base_tc_config.get_config() 28 | config.random_seed = 0 29 | config.model_type = "transformer" 30 | config.attention_fn = favor.make_fast_softmax_attention( 31 | qkv_dim=config.qkv_dim // config.num_heads, 32 | lax_scan_unroll=16) 33 | config.model_kwargs = dict( 34 | add_pos_emb=False, 35 | qk_transform_fn_factory=functools.partial( 36 | make_spe_transform_fn, 37 | spe_cls=spe.ConvSPE, 38 | spe_kwargs=dict( 39 | num_realizations=64, 40 | kernel_size=128 41 | ), 42 | shared=True 43 | ) 44 | ) 45 | config.batch_size = 8 46 | config.learning_rate = config.learning_rate / 32 * 8 47 | config.num_train_steps = 30000 48 | return config 49 | 50 | 51 | def get_hyper(hyper): 52 | return hyper.product([]) 53 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax_convspe_k128_shr/tc/r3/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Configuration and hyperparameter sweeps.""" 15 | 16 | import functools 17 | 18 | from fast_self_attention import fast_self_attention as favor 19 | import jax_spe as spe 20 | 21 | from lra_benchmarks.models.layers.spe import make_spe_transform_fn 22 | from lra_benchmarks.text_classification.configs import base_tc_config 23 | 24 | 25 | def get_config(): 26 | """Get the default hyperparameter configuration.""" 27 | config = base_tc_config.get_config() 28 | config.random_seed = 2 29 | config.model_type = "transformer" 30 | config.attention_fn = favor.make_fast_softmax_attention( 31 | qkv_dim=config.qkv_dim // config.num_heads, 32 | lax_scan_unroll=16) 33 | config.model_kwargs = dict( 34 | add_pos_emb=False, 35 | qk_transform_fn_factory=functools.partial( 36 | make_spe_transform_fn, 37 | spe_cls=spe.ConvSPE, 38 | spe_kwargs=dict( 39 | num_realizations=64, 40 | kernel_size=128 41 | ), 42 | shared=True 43 | ) 44 | ) 45 | config.batch_size = 8 46 | config.learning_rate = config.learning_rate / 32 * 8 47 | config.num_train_steps = 30000 48 | return config 49 | 50 | 51 | def get_hyper(hyper): 52 | return hyper.product([]) 53 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax_sinespe/aan/r1/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Configuration and hyperparameter sweeps.""" 15 | 16 | import functools 17 | 18 | from fast_self_attention import fast_self_attention as favor 19 | import jax_spe as spe 20 | 21 | from lra_benchmarks.models.layers.spe import make_spe_transform_fn 22 | from lra_benchmarks.matching.configs import base_match_config 23 | 24 | 25 | def get_config(): 26 | """Get the default hyperparameter configuration.""" 27 | config = base_match_config.get_config() 28 | config.random_seed = 0 29 | config.model_type = "transformer" 30 | num_realizations = 64 31 | config.model_kwargs = dict( 32 | add_pos_emb=False, 33 | qk_transform_fn_factory=functools.partial( 34 | make_spe_transform_fn, 35 | spe_cls=spe.SineSPE, 36 | spe_kwargs=dict( 37 | num_realizations=num_realizations, 38 | num_sines=10 39 | ), 40 | ) 41 | ) 42 | config.attention_fn = favor.make_fast_softmax_attention( 43 | qkv_dim=num_realizations, 44 | lax_scan_unroll=16) 45 | config.batch_size = 8 46 | config.learning_rate = 0.005 47 | config.num_train_steps = 15000 48 | config.warmup = 3000 49 | config.eval_frequency = 1500 50 | return config 51 | 52 | 53 | def get_hyper(hyper): 54 | return hyper.product([]) 55 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax_sinespe/aan/r2/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Configuration and hyperparameter sweeps.""" 15 | 16 | import functools 17 | 18 | from fast_self_attention import fast_self_attention as favor 19 | import jax_spe as spe 20 | 21 | from lra_benchmarks.models.layers.spe import make_spe_transform_fn 22 | from lra_benchmarks.matching.configs import base_match_config 23 | 24 | 25 | def get_config(): 26 | """Get the default hyperparameter configuration.""" 27 | config = base_match_config.get_config() 28 | config.random_seed = 1 29 | config.model_type = "transformer" 30 | num_realizations = 64 31 | config.model_kwargs = dict( 32 | add_pos_emb=False, 33 | qk_transform_fn_factory=functools.partial( 34 | make_spe_transform_fn, 35 | spe_cls=spe.SineSPE, 36 | spe_kwargs=dict( 37 | num_realizations=num_realizations, 38 | num_sines=10 39 | ), 40 | ) 41 | ) 42 | config.attention_fn = favor.make_fast_softmax_attention( 43 | qkv_dim=num_realizations, 44 | lax_scan_unroll=16) 45 | config.batch_size = 8 46 | config.learning_rate = 0.005 47 | config.num_train_steps = 15000 48 | config.warmup = 3000 49 | config.eval_frequency = 1500 50 | return config 51 | 52 | 53 | def get_hyper(hyper): 54 | return hyper.product([]) 55 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax_sinespe/aan/r3/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Configuration and hyperparameter sweeps.""" 15 | 16 | import functools 17 | 18 | from fast_self_attention import fast_self_attention as favor 19 | import jax_spe as spe 20 | 21 | from lra_benchmarks.models.layers.spe import make_spe_transform_fn 22 | from lra_benchmarks.matching.configs import base_match_config 23 | 24 | 25 | def get_config(): 26 | """Get the default hyperparameter configuration.""" 27 | config = base_match_config.get_config() 28 | config.random_seed = 2 29 | config.model_type = "transformer" 30 | num_realizations = 64 31 | config.model_kwargs = dict( 32 | add_pos_emb=False, 33 | qk_transform_fn_factory=functools.partial( 34 | make_spe_transform_fn, 35 | spe_cls=spe.SineSPE, 36 | spe_kwargs=dict( 37 | num_realizations=num_realizations, 38 | num_sines=10 39 | ), 40 | ) 41 | ) 42 | config.attention_fn = favor.make_fast_softmax_attention( 43 | qkv_dim=num_realizations, 44 | lax_scan_unroll=16) 45 | config.batch_size = 8 46 | config.learning_rate = 0.005 47 | config.num_train_steps = 15000 48 | config.warmup = 3000 49 | config.eval_frequency = 1500 50 | return config 51 | 52 | 53 | def get_hyper(hyper): 54 | return hyper.product([]) 55 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax/cifar10/tune01/r1/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Configuration and hyperparameter sweeps.""" 16 | 17 | from fast_self_attention import fast_self_attention as favor 18 | 19 | from lra_benchmarks.image.configs.cifar10 import base_cifar10_config 20 | from lra_benchmarks.image.configs.cifar10.base_cifar10_config import TRAIN_EXAMPLES 21 | 22 | 23 | def get_config(): 24 | """Get the hyperparameter configuration.""" 25 | config = base_cifar10_config.get_config() 26 | config.random_seed = 0 27 | config.model_type = "transformer" 28 | config.learning_rate = .00019 29 | config.batch_size = 96 30 | config.factors = 'constant * linear_warmup * cosine_decay' 31 | config.warmup = (TRAIN_EXAMPLES // config.batch_size) * 1 32 | config.model.dropout_rate = 0.3 33 | config.model.attention_dropout_rate = 0.2 34 | config.model.learn_pos_emb = True 35 | config.model.num_layers = 1 36 | config.model.emb_dim = 128 37 | config.model.qkv_dim = 64 38 | config.model.mlp_dim = 128 39 | config.model.num_heads = 8 40 | config.model.classifier_pool = "CLS" 41 | config.attention_fn = favor.make_fast_softmax_attention( 42 | qkv_dim=config.model.qkv_dim // config.model.num_heads, 43 | lax_scan_unroll=16) 44 | return config 45 | 46 | 47 | def get_hyper(hyper): 48 | return hyper.product([]) 49 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax/cifar10/tune01/r2/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Configuration and hyperparameter sweeps.""" 16 | 17 | from fast_self_attention import fast_self_attention as favor 18 | 19 | from lra_benchmarks.image.configs.cifar10 import base_cifar10_config 20 | from lra_benchmarks.image.configs.cifar10.base_cifar10_config import TRAIN_EXAMPLES 21 | 22 | 23 | def get_config(): 24 | """Get the hyperparameter configuration.""" 25 | config = base_cifar10_config.get_config() 26 | config.random_seed = 1 27 | config.model_type = "transformer" 28 | config.learning_rate = .00019 29 | config.batch_size = 96 30 | config.factors = 'constant * linear_warmup * cosine_decay' 31 | config.warmup = (TRAIN_EXAMPLES // config.batch_size) * 1 32 | config.model.dropout_rate = 0.3 33 | config.model.attention_dropout_rate = 0.2 34 | config.model.learn_pos_emb = True 35 | config.model.num_layers = 1 36 | config.model.emb_dim = 128 37 | config.model.qkv_dim = 64 38 | config.model.mlp_dim = 128 39 | config.model.num_heads = 8 40 | config.model.classifier_pool = "CLS" 41 | config.attention_fn = favor.make_fast_softmax_attention( 42 | qkv_dim=config.model.qkv_dim // config.model.num_heads, 43 | lax_scan_unroll=16) 44 | return config 45 | 46 | 47 | def get_hyper(hyper): 48 | return hyper.product([]) 49 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax/cifar10/tune01/r3/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Configuration and hyperparameter sweeps.""" 16 | 17 | from fast_self_attention import fast_self_attention as favor 18 | 19 | from lra_benchmarks.image.configs.cifar10 import base_cifar10_config 20 | from lra_benchmarks.image.configs.cifar10.base_cifar10_config import TRAIN_EXAMPLES 21 | 22 | 23 | def get_config(): 24 | """Get the hyperparameter configuration.""" 25 | config = base_cifar10_config.get_config() 26 | config.random_seed = 2 27 | config.model_type = "transformer" 28 | config.learning_rate = .00019 29 | config.batch_size = 96 30 | config.factors = 'constant * linear_warmup * cosine_decay' 31 | config.warmup = (TRAIN_EXAMPLES // config.batch_size) * 1 32 | config.model.dropout_rate = 0.3 33 | config.model.attention_dropout_rate = 0.2 34 | config.model.learn_pos_emb = True 35 | config.model.num_layers = 1 36 | config.model.emb_dim = 128 37 | config.model.qkv_dim = 64 38 | config.model.mlp_dim = 128 39 | config.model.num_heads = 8 40 | config.model.classifier_pool = "CLS" 41 | config.attention_fn = favor.make_fast_softmax_attention( 42 | qkv_dim=config.model.qkv_dim // config.model.num_heads, 43 | lax_scan_unroll=16) 44 | return config 45 | 46 | 47 | def get_hyper(hyper): 48 | return hyper.product([]) 49 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax_convspe_k128_shr/tc/r2/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Configuration and hyperparameter sweeps.""" 15 | 16 | import functools 17 | 18 | from fast_self_attention import fast_self_attention as favor 19 | import jax_spe as spe 20 | 21 | from lra_benchmarks.models.layers.spe import make_spe_transform_fn 22 | from lra_benchmarks.text_classification.configs import base_tc_config 23 | 24 | 25 | def get_config(): 26 | """Get the default hyperparameter configuration.""" 27 | config = base_tc_config.get_config() 28 | config.random_seed = 1 29 | config.model_type = "transformer" 30 | config.attention_fn = favor.make_fast_softmax_attention( 31 | qkv_dim=config.qkv_dim // config.num_heads, 32 | lax_scan_unroll=16) 33 | config.model_kwargs = dict( 34 | add_pos_emb=False, 35 | qk_transform_fn_factory=functools.partial( 36 | make_spe_transform_fn, 37 | spe_cls=spe.ConvSPE, 38 | spe_kwargs=dict( 39 | num_realizations=64, 40 | kernel_size=128 41 | ), 42 | shared=True 43 | ) 44 | ) 45 | config.batch_size = 8 46 | config.learning_rate = config.learning_rate / 32 * 8 47 | config.num_train_steps = 30000 48 | config.eval_frequency = config.eval_frequency * 4 49 | return config 50 | 51 | 52 | def get_hyper(hyper): 53 | return hyper.product([]) 54 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu_sinespe/tc/r1/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Configuration and hyperparameter sweeps.""" 15 | 16 | import functools 17 | 18 | from fast_self_attention import fast_self_attention as favor 19 | import jax 20 | import jax_spe as spe 21 | 22 | from lra_benchmarks.models.layers.spe import make_spe_transform_fn 23 | from lra_benchmarks.text_classification.configs import base_tc_config 24 | 25 | 26 | def get_config(): 27 | """Get the default hyperparameter configuration.""" 28 | config = base_tc_config.get_config() 29 | config.random_seed = 0 30 | config.model_type = "transformer" 31 | config.attention_fn = favor.make_fast_generalized_attention( 32 | qkv_dim=config.qkv_dim // config.num_heads, 33 | features_type='deterministic', 34 | kernel_fn=jax.nn.relu, 35 | lax_scan_unroll=16) 36 | config.model_kwargs = dict( 37 | add_pos_emb=False, 38 | qk_transform_fn_factory=functools.partial( 39 | make_spe_transform_fn, 40 | spe_cls=spe.SineSPE, 41 | spe_kwargs=dict( 42 | num_realizations=64, 43 | num_sines=10 44 | ), 45 | ) 46 | ) 47 | config.batch_size = 8 48 | config.learning_rate = config.learning_rate / 32 * 8 49 | config.num_train_steps = 30000 50 | return config 51 | 52 | 53 | def get_hyper(hyper): 54 | return hyper.product([]) 55 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu_sinespe/tc/r2/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Configuration and hyperparameter sweeps.""" 15 | 16 | import functools 17 | 18 | from fast_self_attention import fast_self_attention as favor 19 | import jax 20 | import jax_spe as spe 21 | 22 | from lra_benchmarks.models.layers.spe import make_spe_transform_fn 23 | from lra_benchmarks.text_classification.configs import base_tc_config 24 | 25 | 26 | def get_config(): 27 | """Get the default hyperparameter configuration.""" 28 | config = base_tc_config.get_config() 29 | config.random_seed = 1 30 | config.model_type = "transformer" 31 | config.attention_fn = favor.make_fast_generalized_attention( 32 | qkv_dim=config.qkv_dim // config.num_heads, 33 | features_type='deterministic', 34 | kernel_fn=jax.nn.relu, 35 | lax_scan_unroll=16) 36 | config.model_kwargs = dict( 37 | add_pos_emb=False, 38 | qk_transform_fn_factory=functools.partial( 39 | make_spe_transform_fn, 40 | spe_cls=spe.SineSPE, 41 | spe_kwargs=dict( 42 | num_realizations=64, 43 | num_sines=10 44 | ), 45 | ) 46 | ) 47 | config.batch_size = 8 48 | config.learning_rate = config.learning_rate / 32 * 8 49 | config.num_train_steps = 30000 50 | return config 51 | 52 | 53 | def get_hyper(hyper): 54 | return hyper.product([]) 55 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu_sinespe/tc/r3/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Configuration and hyperparameter sweeps.""" 15 | 16 | import functools 17 | 18 | from fast_self_attention import fast_self_attention as favor 19 | import jax 20 | import jax_spe as spe 21 | 22 | from lra_benchmarks.models.layers.spe import make_spe_transform_fn 23 | from lra_benchmarks.text_classification.configs import base_tc_config 24 | 25 | 26 | def get_config(): 27 | """Get the default hyperparameter configuration.""" 28 | config = base_tc_config.get_config() 29 | config.random_seed = 2 30 | config.model_type = "transformer" 31 | config.attention_fn = favor.make_fast_generalized_attention( 32 | qkv_dim=config.qkv_dim // config.num_heads, 33 | features_type='deterministic', 34 | kernel_fn=jax.nn.relu, 35 | lax_scan_unroll=16) 36 | config.model_kwargs = dict( 37 | add_pos_emb=False, 38 | qk_transform_fn_factory=functools.partial( 39 | make_spe_transform_fn, 40 | spe_cls=spe.SineSPE, 41 | spe_kwargs=dict( 42 | num_realizations=64, 43 | num_sines=10 44 | ), 45 | ) 46 | ) 47 | config.batch_size = 8 48 | config.learning_rate = config.learning_rate / 32 * 8 49 | config.num_train_steps = 30000 50 | return config 51 | 52 | 53 | def get_hyper(hyper): 54 | return hyper.product([]) 55 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax_convspe_k128_shr/aan/r1/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Configuration and hyperparameter sweeps.""" 15 | 16 | import functools 17 | 18 | from fast_self_attention import fast_self_attention as favor 19 | import jax_spe as spe 20 | 21 | from lra_benchmarks.models.layers.spe import make_spe_transform_fn 22 | from lra_benchmarks.matching.configs import base_match_config 23 | 24 | 25 | def get_config(): 26 | """Get the default hyperparameter configuration.""" 27 | config = base_match_config.get_config() 28 | config.random_seed = 0 29 | config.model_type = "transformer" 30 | num_realizations = 64 31 | config.model_kwargs = dict( 32 | add_pos_emb=False, 33 | qk_transform_fn_factory=functools.partial( 34 | make_spe_transform_fn, 35 | spe_cls=spe.ConvSPE, 36 | spe_kwargs=dict( 37 | num_realizations=num_realizations, 38 | kernel_size=128 39 | ), 40 | shared=True 41 | ) 42 | ) 43 | config.attention_fn = favor.make_fast_softmax_attention( 44 | qkv_dim=num_realizations, 45 | lax_scan_unroll=16) 46 | config.batch_size = 8 47 | config.learning_rate = 0.005 48 | config.num_train_steps = 15000 49 | config.warmup = 3000 50 | config.eval_frequency = 1500 51 | return config 52 | 53 | 54 | def get_hyper(hyper): 55 | return hyper.product([]) 56 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax_convspe_k128_shr/aan/r2/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Configuration and hyperparameter sweeps.""" 15 | 16 | import functools 17 | 18 | from fast_self_attention import fast_self_attention as favor 19 | import jax_spe as spe 20 | 21 | from lra_benchmarks.models.layers.spe import make_spe_transform_fn 22 | from lra_benchmarks.matching.configs import base_match_config 23 | 24 | 25 | def get_config(): 26 | """Get the default hyperparameter configuration.""" 27 | config = base_match_config.get_config() 28 | config.random_seed = 1 29 | config.model_type = "transformer" 30 | num_realizations = 64 31 | config.model_kwargs = dict( 32 | add_pos_emb=False, 33 | qk_transform_fn_factory=functools.partial( 34 | make_spe_transform_fn, 35 | spe_cls=spe.ConvSPE, 36 | spe_kwargs=dict( 37 | num_realizations=num_realizations, 38 | kernel_size=128 39 | ), 40 | shared=True 41 | ) 42 | ) 43 | config.attention_fn = favor.make_fast_softmax_attention( 44 | qkv_dim=num_realizations, 45 | lax_scan_unroll=16) 46 | config.batch_size = 8 47 | config.learning_rate = 0.005 48 | config.num_train_steps = 15000 49 | config.warmup = 3000 50 | config.eval_frequency = 1500 51 | return config 52 | 53 | 54 | def get_hyper(hyper): 55 | return hyper.product([]) 56 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax_convspe_k128_shr/aan/r3/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Configuration and hyperparameter sweeps.""" 15 | 16 | import functools 17 | 18 | from fast_self_attention import fast_self_attention as favor 19 | import jax_spe as spe 20 | 21 | from lra_benchmarks.models.layers.spe import make_spe_transform_fn 22 | from lra_benchmarks.matching.configs import base_match_config 23 | 24 | 25 | def get_config(): 26 | """Get the default hyperparameter configuration.""" 27 | config = base_match_config.get_config() 28 | config.random_seed = 2 29 | config.model_type = "transformer" 30 | num_realizations = 64 31 | config.model_kwargs = dict( 32 | add_pos_emb=False, 33 | qk_transform_fn_factory=functools.partial( 34 | make_spe_transform_fn, 35 | spe_cls=spe.ConvSPE, 36 | spe_kwargs=dict( 37 | num_realizations=num_realizations, 38 | kernel_size=128 39 | ), 40 | shared=True 41 | ) 42 | ) 43 | config.attention_fn = favor.make_fast_softmax_attention( 44 | qkv_dim=num_realizations, 45 | lax_scan_unroll=16) 46 | config.batch_size = 8 47 | config.learning_rate = 0.005 48 | config.num_train_steps = 15000 49 | config.warmup = 3000 50 | config.eval_frequency = 1500 51 | return config 52 | 53 | 54 | def get_hyper(hyper): 55 | return hyper.product([]) 56 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu_sinespe/aan/r1/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Configuration and hyperparameter sweeps.""" 15 | 16 | import functools 17 | 18 | from fast_self_attention import fast_self_attention as favor 19 | import jax 20 | import jax_spe as spe 21 | 22 | from lra_benchmarks.models.layers.spe import make_spe_transform_fn 23 | from lra_benchmarks.matching.configs import base_match_config 24 | 25 | 26 | def get_config(): 27 | """Get the default hyperparameter configuration.""" 28 | config = base_match_config.get_config() 29 | config.random_seed = 0 30 | config.model_type = "transformer" 31 | num_realizations = 64 32 | config.attention_fn = favor.make_fast_generalized_attention( 33 | qkv_dim=num_realizations, 34 | features_type='deterministic', 35 | kernel_fn=jax.nn.relu, 36 | lax_scan_unroll=16) 37 | config.model_kwargs = dict( 38 | add_pos_emb=False, 39 | qk_transform_fn_factory=functools.partial( 40 | make_spe_transform_fn, 41 | spe_cls=spe.SineSPE, 42 | spe_kwargs=dict( 43 | num_realizations=num_realizations, 44 | num_sines=10 45 | ), 46 | ) 47 | ) 48 | config.batch_size = 8 49 | config.learning_rate = 0.005 50 | config.num_train_steps = 15000 51 | config.warmup = 3000 52 | config.eval_frequency = 1500 53 | return config 54 | 55 | 56 | def get_hyper(hyper): 57 | return hyper.product([]) 58 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu_sinespe/aan/r2/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Configuration and hyperparameter sweeps.""" 15 | 16 | import functools 17 | 18 | from fast_self_attention import fast_self_attention as favor 19 | import jax 20 | import jax_spe as spe 21 | 22 | from lra_benchmarks.models.layers.spe import make_spe_transform_fn 23 | from lra_benchmarks.matching.configs import base_match_config 24 | 25 | 26 | def get_config(): 27 | """Get the default hyperparameter configuration.""" 28 | config = base_match_config.get_config() 29 | config.random_seed = 1 30 | config.model_type = "transformer" 31 | num_realizations = 64 32 | config.attention_fn = favor.make_fast_generalized_attention( 33 | qkv_dim=num_realizations, 34 | features_type='deterministic', 35 | kernel_fn=jax.nn.relu, 36 | lax_scan_unroll=16) 37 | config.model_kwargs = dict( 38 | add_pos_emb=False, 39 | qk_transform_fn_factory=functools.partial( 40 | make_spe_transform_fn, 41 | spe_cls=spe.SineSPE, 42 | spe_kwargs=dict( 43 | num_realizations=num_realizations, 44 | num_sines=10 45 | ), 46 | ) 47 | ) 48 | config.batch_size = 8 49 | config.learning_rate = 0.005 50 | config.num_train_steps = 15000 51 | config.warmup = 3000 52 | config.eval_frequency = 1500 53 | return config 54 | 55 | 56 | def get_hyper(hyper): 57 | return hyper.product([]) 58 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu_sinespe/aan/r3/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Configuration and hyperparameter sweeps.""" 15 | 16 | import functools 17 | 18 | from fast_self_attention import fast_self_attention as favor 19 | import jax 20 | import jax_spe as spe 21 | 22 | from lra_benchmarks.models.layers.spe import make_spe_transform_fn 23 | from lra_benchmarks.matching.configs import base_match_config 24 | 25 | 26 | def get_config(): 27 | """Get the default hyperparameter configuration.""" 28 | config = base_match_config.get_config() 29 | config.random_seed = 2 30 | config.model_type = "transformer" 31 | num_realizations = 64 32 | config.attention_fn = favor.make_fast_generalized_attention( 33 | qkv_dim=num_realizations, 34 | features_type='deterministic', 35 | kernel_fn=jax.nn.relu, 36 | lax_scan_unroll=16) 37 | config.model_kwargs = dict( 38 | add_pos_emb=False, 39 | qk_transform_fn_factory=functools.partial( 40 | make_spe_transform_fn, 41 | spe_cls=spe.SineSPE, 42 | spe_kwargs=dict( 43 | num_realizations=num_realizations, 44 | num_sines=10 45 | ), 46 | ) 47 | ) 48 | config.batch_size = 8 49 | config.learning_rate = 0.005 50 | config.num_train_steps = 15000 51 | config.warmup = 3000 52 | config.eval_frequency = 1500 53 | return config 54 | 55 | 56 | def get_hyper(hyper): 57 | return hyper.product([]) 58 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu_sinespe/listops/r1/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Configuration and hyperparameter sweeps.""" 15 | 16 | import functools 17 | 18 | from fast_self_attention import fast_self_attention as favor 19 | import jax 20 | import jax_spe as spe 21 | 22 | from lra_benchmarks.models.layers.spe import make_spe_transform_fn 23 | from lra_benchmarks.listops.configs import base_listops_config 24 | 25 | 26 | def get_config(): 27 | """Get the default hyperparameter configuration.""" 28 | config = base_listops_config.get_config() 29 | config.random_seed = 0 30 | config.model_type = "transformer" 31 | config.attention_fn = favor.make_fast_generalized_attention( 32 | qkv_dim=config.qkv_dim // config.num_heads, 33 | features_type='deterministic', 34 | kernel_fn=jax.nn.relu, 35 | lax_scan_unroll=16) 36 | config.model_kwargs = dict( 37 | add_pos_emb=False, 38 | qk_transform_fn_factory=functools.partial( 39 | make_spe_transform_fn, 40 | spe_cls=spe.SineSPE, 41 | spe_kwargs=dict( 42 | num_realizations=64, 43 | num_sines=10 44 | ), 45 | ) 46 | ) 47 | config.batch_size = 8 48 | config.learning_rate = config.learning_rate / 32 * 8 49 | config.num_train_steps = 10000 50 | config.eval_frequency = config.eval_frequency * 4 51 | return config 52 | 53 | 54 | def get_hyper(hyper): 55 | return hyper.product([]) 56 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu_sinespe/listops/r2/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Configuration and hyperparameter sweeps.""" 15 | 16 | import functools 17 | 18 | from fast_self_attention import fast_self_attention as favor 19 | import jax 20 | import jax_spe as spe 21 | 22 | from lra_benchmarks.models.layers.spe import make_spe_transform_fn 23 | from lra_benchmarks.listops.configs import base_listops_config 24 | 25 | 26 | def get_config(): 27 | """Get the default hyperparameter configuration.""" 28 | config = base_listops_config.get_config() 29 | config.random_seed = 1 30 | config.model_type = "transformer" 31 | config.attention_fn = favor.make_fast_generalized_attention( 32 | qkv_dim=config.qkv_dim // config.num_heads, 33 | features_type='deterministic', 34 | kernel_fn=jax.nn.relu, 35 | lax_scan_unroll=16) 36 | config.model_kwargs = dict( 37 | add_pos_emb=False, 38 | qk_transform_fn_factory=functools.partial( 39 | make_spe_transform_fn, 40 | spe_cls=spe.SineSPE, 41 | spe_kwargs=dict( 42 | num_realizations=64, 43 | num_sines=10 44 | ), 45 | ) 46 | ) 47 | config.batch_size = 8 48 | config.learning_rate = config.learning_rate / 32 * 8 49 | config.num_train_steps = 10000 50 | config.eval_frequency = config.eval_frequency * 4 51 | return config 52 | 53 | 54 | def get_hyper(hyper): 55 | return hyper.product([]) 56 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu_sinespe/listops/r3/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Configuration and hyperparameter sweeps.""" 15 | 16 | import functools 17 | 18 | from fast_self_attention import fast_self_attention as favor 19 | import jax 20 | import jax_spe as spe 21 | 22 | from lra_benchmarks.models.layers.spe import make_spe_transform_fn 23 | from lra_benchmarks.listops.configs import base_listops_config 24 | 25 | 26 | def get_config(): 27 | """Get the default hyperparameter configuration.""" 28 | config = base_listops_config.get_config() 29 | config.random_seed = 2 30 | config.model_type = "transformer" 31 | config.attention_fn = favor.make_fast_generalized_attention( 32 | qkv_dim=config.qkv_dim // config.num_heads, 33 | features_type='deterministic', 34 | kernel_fn=jax.nn.relu, 35 | lax_scan_unroll=16) 36 | config.model_kwargs = dict( 37 | add_pos_emb=False, 38 | qk_transform_fn_factory=functools.partial( 39 | make_spe_transform_fn, 40 | spe_cls=spe.SineSPE, 41 | spe_kwargs=dict( 42 | num_realizations=64, 43 | num_sines=10 44 | ), 45 | ) 46 | ) 47 | config.batch_size = 8 48 | config.learning_rate = config.learning_rate / 32 * 8 49 | config.num_train_steps = 10000 50 | config.eval_frequency = config.eval_frequency * 4 51 | return config 52 | 53 | 54 | def get_hyper(hyper): 55 | return hyper.product([]) 56 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu_convspe_k128_shr/aan/r1/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Configuration and hyperparameter sweeps.""" 15 | 16 | import functools 17 | 18 | from fast_self_attention import fast_self_attention as favor 19 | import jax 20 | import jax_spe as spe 21 | 22 | from lra_benchmarks.models.layers.spe import make_spe_transform_fn 23 | from lra_benchmarks.matching.configs import base_match_config 24 | 25 | 26 | def get_config(): 27 | """Get the default hyperparameter configuration.""" 28 | config = base_match_config.get_config() 29 | config.random_seed = 0 30 | config.model_type = "transformer" 31 | num_realizations = 64 32 | config.attention_fn = favor.make_fast_generalized_attention( 33 | qkv_dim=num_realizations, 34 | features_type='deterministic', 35 | kernel_fn=jax.nn.relu, 36 | lax_scan_unroll=16) 37 | config.model_kwargs = dict( 38 | add_pos_emb=False, 39 | qk_transform_fn_factory=functools.partial( 40 | make_spe_transform_fn, 41 | spe_cls=spe.ConvSPE, 42 | spe_kwargs=dict( 43 | num_realizations=num_realizations, 44 | kernel_size=128 45 | ), 46 | shared=True 47 | ) 48 | ) 49 | config.batch_size = 8 50 | config.learning_rate = 0.005 51 | config.num_train_steps = 15000 52 | config.warmup = 3000 53 | config.eval_frequency = 1500 54 | return config 55 | 56 | 57 | def get_hyper(hyper): 58 | return hyper.product([]) 59 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu_convspe_k128_shr/aan/r2/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Configuration and hyperparameter sweeps.""" 15 | 16 | import functools 17 | 18 | from fast_self_attention import fast_self_attention as favor 19 | import jax 20 | import jax_spe as spe 21 | 22 | from lra_benchmarks.models.layers.spe import make_spe_transform_fn 23 | from lra_benchmarks.matching.configs import base_match_config 24 | 25 | 26 | def get_config(): 27 | """Get the default hyperparameter configuration.""" 28 | config = base_match_config.get_config() 29 | config.random_seed = 1 30 | config.model_type = "transformer" 31 | num_realizations = 64 32 | config.attention_fn = favor.make_fast_generalized_attention( 33 | qkv_dim=num_realizations, 34 | features_type='deterministic', 35 | kernel_fn=jax.nn.relu, 36 | lax_scan_unroll=16) 37 | config.model_kwargs = dict( 38 | add_pos_emb=False, 39 | qk_transform_fn_factory=functools.partial( 40 | make_spe_transform_fn, 41 | spe_cls=spe.ConvSPE, 42 | spe_kwargs=dict( 43 | num_realizations=num_realizations, 44 | kernel_size=128 45 | ), 46 | shared=True 47 | ) 48 | ) 49 | config.batch_size = 8 50 | config.learning_rate = 0.005 51 | config.num_train_steps = 15000 52 | config.warmup = 3000 53 | config.eval_frequency = 1500 54 | return config 55 | 56 | 57 | def get_hyper(hyper): 58 | return hyper.product([]) 59 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu_convspe_k128_shr/aan/r3/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Configuration and hyperparameter sweeps.""" 15 | 16 | import functools 17 | 18 | from fast_self_attention import fast_self_attention as favor 19 | import jax 20 | import jax_spe as spe 21 | 22 | from lra_benchmarks.models.layers.spe import make_spe_transform_fn 23 | from lra_benchmarks.matching.configs import base_match_config 24 | 25 | 26 | def get_config(): 27 | """Get the default hyperparameter configuration.""" 28 | config = base_match_config.get_config() 29 | config.random_seed = 2 30 | config.model_type = "transformer" 31 | num_realizations = 64 32 | config.attention_fn = favor.make_fast_generalized_attention( 33 | qkv_dim=num_realizations, 34 | features_type='deterministic', 35 | kernel_fn=jax.nn.relu, 36 | lax_scan_unroll=16) 37 | config.model_kwargs = dict( 38 | add_pos_emb=False, 39 | qk_transform_fn_factory=functools.partial( 40 | make_spe_transform_fn, 41 | spe_cls=spe.ConvSPE, 42 | spe_kwargs=dict( 43 | num_realizations=num_realizations, 44 | kernel_size=128 45 | ), 46 | shared=True 47 | ) 48 | ) 49 | config.batch_size = 8 50 | config.learning_rate = 0.005 51 | config.num_train_steps = 15000 52 | config.warmup = 3000 53 | config.eval_frequency = 1500 54 | return config 55 | 56 | 57 | def get_hyper(hyper): 58 | return hyper.product([]) 59 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu_convspe_k128_shr/listops/r1/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Configuration and hyperparameter sweeps.""" 15 | 16 | import functools 17 | 18 | from fast_self_attention import fast_self_attention as favor 19 | import jax 20 | import jax_spe as spe 21 | 22 | from lra_benchmarks.models.layers.spe import make_spe_transform_fn 23 | from lra_benchmarks.listops.configs import base_listops_config 24 | 25 | 26 | def get_config(): 27 | """Get the default hyperparameter configuration.""" 28 | config = base_listops_config.get_config() 29 | config.random_seed = 0 30 | config.model_type = "transformer" 31 | config.attention_fn = favor.make_fast_generalized_attention( 32 | qkv_dim=config.qkv_dim // config.num_heads, 33 | features_type='deterministic', 34 | kernel_fn=jax.nn.relu, 35 | lax_scan_unroll=16) 36 | config.model_kwargs = dict( 37 | add_pos_emb=False, 38 | qk_transform_fn_factory=functools.partial( 39 | make_spe_transform_fn, 40 | spe_cls=spe.ConvSPE, 41 | spe_kwargs=dict( 42 | num_realizations=64, 43 | kernel_size=128 44 | ), 45 | shared=True 46 | ) 47 | ) 48 | config.batch_size = 8 49 | config.learning_rate = config.learning_rate / 32 * 8 50 | config.num_train_steps = 10000 51 | config.eval_frequency = config.eval_frequency * 4 52 | return config 53 | 54 | 55 | def get_hyper(hyper): 56 | return hyper.product([]) 57 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu_convspe_k128_shr/listops/r2/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Configuration and hyperparameter sweeps.""" 15 | 16 | import functools 17 | 18 | from fast_self_attention import fast_self_attention as favor 19 | import jax 20 | import jax_spe as spe 21 | 22 | from lra_benchmarks.models.layers.spe import make_spe_transform_fn 23 | from lra_benchmarks.listops.configs import base_listops_config 24 | 25 | 26 | def get_config(): 27 | """Get the default hyperparameter configuration.""" 28 | config = base_listops_config.get_config() 29 | config.random_seed = 1 30 | config.model_type = "transformer" 31 | config.attention_fn = favor.make_fast_generalized_attention( 32 | qkv_dim=config.qkv_dim // config.num_heads, 33 | features_type='deterministic', 34 | kernel_fn=jax.nn.relu, 35 | lax_scan_unroll=16) 36 | config.model_kwargs = dict( 37 | add_pos_emb=False, 38 | qk_transform_fn_factory=functools.partial( 39 | make_spe_transform_fn, 40 | spe_cls=spe.ConvSPE, 41 | spe_kwargs=dict( 42 | num_realizations=64, 43 | kernel_size=128 44 | ), 45 | shared=True 46 | ) 47 | ) 48 | config.batch_size = 8 49 | config.learning_rate = config.learning_rate / 32 * 8 50 | config.num_train_steps = 10000 51 | config.eval_frequency = config.eval_frequency * 4 52 | return config 53 | 54 | 55 | def get_hyper(hyper): 56 | return hyper.product([]) 57 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu_convspe_k128_shr/listops/r3/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Configuration and hyperparameter sweeps.""" 15 | 16 | import functools 17 | 18 | from fast_self_attention import fast_self_attention as favor 19 | import jax 20 | import jax_spe as spe 21 | 22 | from lra_benchmarks.models.layers.spe import make_spe_transform_fn 23 | from lra_benchmarks.listops.configs import base_listops_config 24 | 25 | 26 | def get_config(): 27 | """Get the default hyperparameter configuration.""" 28 | config = base_listops_config.get_config() 29 | config.random_seed = 2 30 | config.model_type = "transformer" 31 | config.attention_fn = favor.make_fast_generalized_attention( 32 | qkv_dim=config.qkv_dim // config.num_heads, 33 | features_type='deterministic', 34 | kernel_fn=jax.nn.relu, 35 | lax_scan_unroll=16) 36 | config.model_kwargs = dict( 37 | add_pos_emb=False, 38 | qk_transform_fn_factory=functools.partial( 39 | make_spe_transform_fn, 40 | spe_cls=spe.ConvSPE, 41 | spe_kwargs=dict( 42 | num_realizations=64, 43 | kernel_size=128 44 | ), 45 | shared=True 46 | ) 47 | ) 48 | config.batch_size = 8 49 | config.learning_rate = config.learning_rate / 32 * 8 50 | config.num_train_steps = 10000 51 | config.eval_frequency = config.eval_frequency * 4 52 | return config 53 | 54 | 55 | def get_hyper(hyper): 56 | return hyper.product([]) 57 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu_convspe_k128_shr/tc/r1/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Configuration and hyperparameter sweeps.""" 15 | 16 | import functools 17 | 18 | from fast_self_attention import fast_self_attention as favor 19 | import jax 20 | import jax_spe as spe 21 | 22 | from lra_benchmarks.models.layers.spe import make_spe_transform_fn 23 | from lra_benchmarks.text_classification.configs import base_tc_config 24 | 25 | 26 | def get_config(): 27 | """Get the default hyperparameter configuration.""" 28 | config = base_tc_config.get_config() 29 | config.random_seed = 0 30 | config.model_type = "transformer" 31 | config.attention_fn = favor.make_fast_generalized_attention( 32 | qkv_dim=config.qkv_dim // config.num_heads, 33 | features_type='deterministic', 34 | kernel_fn=jax.nn.relu, 35 | lax_scan_unroll=16) 36 | config.model_kwargs = dict( 37 | add_pos_emb=False, 38 | qk_transform_fn_factory=functools.partial( 39 | make_spe_transform_fn, 40 | spe_cls=spe.ConvSPE, 41 | spe_kwargs=dict( 42 | num_realizations=64, 43 | kernel_size=128 44 | ), 45 | shared=True 46 | ) 47 | ) 48 | config.batch_size = 8 49 | config.learning_rate = config.learning_rate / 32 * 8 50 | config.num_train_steps = 30000 51 | config.eval_frequency = config.eval_frequency * 4 52 | return config 53 | 54 | 55 | def get_hyper(hyper): 56 | return hyper.product([]) 57 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu_convspe_k128_shr/tc/r2/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Configuration and hyperparameter sweeps.""" 15 | 16 | import functools 17 | 18 | from fast_self_attention import fast_self_attention as favor 19 | import jax 20 | import jax_spe as spe 21 | 22 | from lra_benchmarks.models.layers.spe import make_spe_transform_fn 23 | from lra_benchmarks.text_classification.configs import base_tc_config 24 | 25 | 26 | def get_config(): 27 | """Get the default hyperparameter configuration.""" 28 | config = base_tc_config.get_config() 29 | config.random_seed = 1 30 | config.model_type = "transformer" 31 | config.attention_fn = favor.make_fast_generalized_attention( 32 | qkv_dim=config.qkv_dim // config.num_heads, 33 | features_type='deterministic', 34 | kernel_fn=jax.nn.relu, 35 | lax_scan_unroll=16) 36 | config.model_kwargs = dict( 37 | add_pos_emb=False, 38 | qk_transform_fn_factory=functools.partial( 39 | make_spe_transform_fn, 40 | spe_cls=spe.ConvSPE, 41 | spe_kwargs=dict( 42 | num_realizations=64, 43 | kernel_size=128 44 | ), 45 | shared=True 46 | ) 47 | ) 48 | config.batch_size = 8 49 | config.learning_rate = config.learning_rate / 32 * 8 50 | config.num_train_steps = 30000 51 | config.eval_frequency = config.eval_frequency * 4 52 | return config 53 | 54 | 55 | def get_hyper(hyper): 56 | return hyper.product([]) 57 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu_convspe_k128_shr/tc/r3/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Configuration and hyperparameter sweeps.""" 15 | 16 | import functools 17 | 18 | from fast_self_attention import fast_self_attention as favor 19 | import jax 20 | import jax_spe as spe 21 | 22 | from lra_benchmarks.models.layers.spe import make_spe_transform_fn 23 | from lra_benchmarks.text_classification.configs import base_tc_config 24 | 25 | 26 | def get_config(): 27 | """Get the default hyperparameter configuration.""" 28 | config = base_tc_config.get_config() 29 | config.random_seed = 2 30 | config.model_type = "transformer" 31 | config.attention_fn = favor.make_fast_generalized_attention( 32 | qkv_dim=config.qkv_dim // config.num_heads, 33 | features_type='deterministic', 34 | kernel_fn=jax.nn.relu, 35 | lax_scan_unroll=16) 36 | config.model_kwargs = dict( 37 | add_pos_emb=False, 38 | qk_transform_fn_factory=functools.partial( 39 | make_spe_transform_fn, 40 | spe_cls=spe.ConvSPE, 41 | spe_kwargs=dict( 42 | num_realizations=64, 43 | kernel_size=128 44 | ), 45 | shared=True 46 | ) 47 | ) 48 | config.batch_size = 8 49 | config.learning_rate = config.learning_rate / 32 * 8 50 | config.num_train_steps = 30000 51 | config.eval_frequency = config.eval_frequency * 4 52 | return config 53 | 54 | 55 | def get_hyper(hyper): 56 | return hyper.product([]) 57 | -------------------------------------------------------------------------------- /experiments/groove/src/spe_music/model/music_performer.py: -------------------------------------------------------------------------------- 1 | from confugue import configurable 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | 6 | from .fast_transformer_decoder import FastTransformerDecoder 7 | from .transformer_helpers import ( 8 | TokenEmbedding, 9 | PositionalEncoding, 10 | weights_init 11 | ) 12 | 13 | @configurable 14 | class MusicPerformer(nn.Module): 15 | def __init__(self, n_token, d_model, d_embed, dropout=0.1, max_len=20480): 16 | super(MusicPerformer, self).__init__() 17 | self.n_token = n_token 18 | self.d_model = d_model 19 | self.max_len = max_len 20 | 21 | self.token_emb = TokenEmbedding(n_token, d_embed, d_model) 22 | self.d_embed = d_embed 23 | 24 | if self._cfg.get('add_positional_encoding', True): 25 | self.pe = self._cfg['positional_encoding'].configure( 26 | PositionalEncoding, d_embed=d_embed, max_pos=max_len) 27 | else: 28 | self.pe = None 29 | self.dec_out_proj = nn.Linear(d_model, n_token) 30 | 31 | self.transformer_decoder = self._cfg['decoder'].configure( 32 | FastTransformerDecoder, 33 | d_model=d_model, 34 | dropout=dropout 35 | ) 36 | 37 | self.emb_dropout = nn.Dropout(dropout) 38 | self.apply(weights_init) 39 | 40 | def forward(self, x, attn_kwargs=None): 41 | x_emb = self.token_emb(x) 42 | x_inp = self.emb_dropout(x_emb) 43 | if self.pe: 44 | x_inp = x_inp + self.pe(x.size(1)).permute(1, 0, 2) 45 | 46 | dec_out = self.transformer_decoder(x_inp, attn_kwargs=attn_kwargs) 47 | dec_logits = self.dec_out_proj(dec_out) 48 | 49 | return dec_logits 50 | 51 | def compute_loss(self, dec_logits, dec_tgt, pad_index=None): 52 | if pad_index is None: 53 | pad_index = -100 54 | recons_loss = F.cross_entropy( 55 | dec_logits.view(-1, dec_logits.size(-1)), dec_tgt.contiguous().view(-1), 56 | ignore_index=pad_index, reduction='mean' 57 | ).float() 58 | 59 | return { 60 | 'recons_loss': recons_loss, 61 | 'total_loss': recons_loss 62 | } 63 | -------------------------------------------------------------------------------- /experiments/pop_piano/utils.py: 
-------------------------------------------------------------------------------- 1 | import pickle 2 | from models.music_performer_ape import MusicPerformer 3 | from models.music_performer_spe import MusicPerformerSPE 4 | 5 | def pickle_load(f): 6 | return pickle.load(open(f, 'rb')) 7 | 8 | def pickle_dump(obj, f): 9 | pickle.dump(obj, open(f, 'wb'), protocol=pickle.HIGHEST_PROTOCOL) 10 | 11 | def load_model(model_conf, gpuid, vocab_size): 12 | if model_conf['pe_type'] == 'APE': 13 | model = MusicPerformer( 14 | vocab_size, model_conf['n_layer'], model_conf['n_head'], 15 | model_conf['d_model'], model_conf['d_ff'], model_conf['d_embed'], 16 | favor_feature_dims=model_conf['feature_map']['n_dims'] 17 | ).cuda(gpuid) 18 | elif model_conf['pe_type'] == 'SineSPE': 19 | model = MusicPerformerSPE( 20 | vocab_size, model_conf['n_layer'], model_conf['n_head'], 21 | model_conf['d_model'], model_conf['d_ff'], model_conf['d_embed'], 22 | favor_feature_dims=model_conf['feature_map']['n_dims'], 23 | share_pe=model_conf['share_pe'], 24 | share_spe_filter=model_conf['share_spe_filter'], 25 | spe_type='SineSPE', 26 | use_gated_filter=model_conf['use_gated_filter'], 27 | spe_module_params={ 28 | 'num_sines': model_conf['positional_encoder']['num_sines'], 29 | 'num_realizations': model_conf['positional_encoder']['num_realizations'] 30 | } 31 | ).cuda(gpuid) 32 | elif model_conf['pe_type'] == 'ConvSPE': 33 | model = MusicPerformerSPE( 34 | vocab_size, model_conf['n_layer'], model_conf['n_head'], 35 | model_conf['d_model'], model_conf['d_ff'], model_conf['d_embed'], 36 | favor_feature_dims=model_conf['feature_map']['n_dims'], 37 | share_pe=model_conf['share_pe'], 38 | share_spe_filter=model_conf['share_spe_filter'], 39 | spe_type='ConvSPE', 40 | use_gated_filter=model_conf['use_gated_filter'], 41 | spe_module_params={ 42 | 'kernel_size': model_conf['positional_encoder']['kernel_size'], 43 | 'num_realizations': model_conf['positional_encoder']['num_realizations'] 44 | } 45 | ).cuda(gpuid) 46 | 47 | return model -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax/cifar10/r1/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Configuration and hyperparameter sweeps.""" 16 | 17 | from fast_self_attention import fast_self_attention as favor 18 | 19 | from lra_benchmarks.image.configs.cifar10 import base_cifar10_config 20 | from lra_benchmarks.image.configs.cifar10.base_cifar10_config import TRAIN_EXAMPLES, VALID_EXAMPLES 21 | 22 | 23 | NUM_EPOCHS = 200 24 | 25 | 26 | def get_config(): 27 | """Get the hyperparameter configuration.""" 28 | config = base_cifar10_config.get_config() 29 | config.random_seed = 0 30 | config.model_type = "transformer" 31 | config.learning_rate = .00025 32 | config.batch_size = 96 33 | config.eval_frequency = TRAIN_EXAMPLES // config.batch_size 34 | config.num_train_steps = (TRAIN_EXAMPLES // config.batch_size) * NUM_EPOCHS 35 | config.num_eval_steps = VALID_EXAMPLES // config.batch_size 36 | config.factors = 'constant * linear_warmup * cosine_decay' 37 | config.warmup = (TRAIN_EXAMPLES // config.batch_size) * 1 38 | 39 | config.model.dropout_rate = 0.3 40 | config.model.attention_dropout_rate = 0.2 41 | config.model.learn_pos_emb = True 42 | config.model.num_layers = 1 43 | config.model.emb_dim = 128 44 | config.model.qkv_dim = 64 45 | config.model.mlp_dim = 128 46 | config.model.num_heads = 8 47 | config.model.classifier_pool = "CLS" 48 | 49 | config.attention_fn = favor.make_fast_softmax_attention( 50 | qkv_dim=config.model.qkv_dim // config.model.num_heads, 51 | lax_scan_unroll=16) 52 | return config 53 | 54 | 55 | def get_hyper(hyper): 56 | return hyper.product([]) 57 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax/cifar10/r2/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Configuration and hyperparameter sweeps.""" 16 | 17 | from fast_self_attention import fast_self_attention as favor 18 | 19 | from lra_benchmarks.image.configs.cifar10 import base_cifar10_config 20 | from lra_benchmarks.image.configs.cifar10.base_cifar10_config import TRAIN_EXAMPLES, VALID_EXAMPLES 21 | 22 | 23 | NUM_EPOCHS = 200 24 | 25 | 26 | def get_config(): 27 | """Get the hyperparameter configuration.""" 28 | config = base_cifar10_config.get_config() 29 | config.random_seed = 1 30 | config.model_type = "transformer" 31 | config.learning_rate = .00025 32 | config.batch_size = 96 33 | config.eval_frequency = TRAIN_EXAMPLES // config.batch_size 34 | config.num_train_steps = (TRAIN_EXAMPLES // config.batch_size) * NUM_EPOCHS 35 | config.num_eval_steps = VALID_EXAMPLES // config.batch_size 36 | config.factors = 'constant * linear_warmup * cosine_decay' 37 | config.warmup = (TRAIN_EXAMPLES // config.batch_size) * 1 38 | 39 | config.model.dropout_rate = 0.3 40 | config.model.attention_dropout_rate = 0.2 41 | config.model.learn_pos_emb = True 42 | config.model.num_layers = 1 43 | config.model.emb_dim = 128 44 | config.model.qkv_dim = 64 45 | config.model.mlp_dim = 128 46 | config.model.num_heads = 8 47 | config.model.classifier_pool = "CLS" 48 | 49 | config.attention_fn = favor.make_fast_softmax_attention( 50 | qkv_dim=config.model.qkv_dim // config.model.num_heads, 51 | lax_scan_unroll=16) 52 | return config 53 | 54 | 55 | def get_hyper(hyper): 56 | return hyper.product([]) 57 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax/cifar10/r3/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Configuration and hyperparameter sweeps.""" 16 | 17 | from fast_self_attention import fast_self_attention as favor 18 | 19 | from lra_benchmarks.image.configs.cifar10 import base_cifar10_config 20 | from lra_benchmarks.image.configs.cifar10.base_cifar10_config import TRAIN_EXAMPLES, VALID_EXAMPLES 21 | 22 | 23 | NUM_EPOCHS = 200 24 | 25 | 26 | def get_config(): 27 | """Get the hyperparameter configuration.""" 28 | config = base_cifar10_config.get_config() 29 | config.random_seed = 2 30 | config.model_type = "transformer" 31 | config.learning_rate = .00025 32 | config.batch_size = 96 33 | config.eval_frequency = TRAIN_EXAMPLES // config.batch_size 34 | config.num_train_steps = (TRAIN_EXAMPLES // config.batch_size) * NUM_EPOCHS 35 | config.num_eval_steps = VALID_EXAMPLES // config.batch_size 36 | config.factors = 'constant * linear_warmup * cosine_decay' 37 | config.warmup = (TRAIN_EXAMPLES // config.batch_size) * 1 38 | 39 | config.model.dropout_rate = 0.3 40 | config.model.attention_dropout_rate = 0.2 41 | config.model.learn_pos_emb = True 42 | config.model.num_layers = 1 43 | config.model.emb_dim = 128 44 | config.model.qkv_dim = 64 45 | config.model.mlp_dim = 128 46 | config.model.num_heads = 8 47 | config.model.classifier_pool = "CLS" 48 | 49 | config.attention_fn = favor.make_fast_softmax_attention( 50 | qkv_dim=config.model.qkv_dim // config.model.num_heads, 51 | lax_scan_unroll=16) 52 | return config 53 | 54 | 55 | def get_hyper(hyper): 56 | return hyper.product([]) 57 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Stochastic Positional Encoding (SPE) 2 | 3 | This is the source code repository for the ICML 2021 paper [*Relative Positional Encoding for Transformers with Linear Complexity*](http://proceedings.mlr.press/v139/liutkus21a.html) by Antoine Liutkus, Ondřej Cífka, Shih-Lun Wu, Umut Şimşekli, Yi-Hsuan Yang and Gaël Richard. 4 | 5 | In this paper, we propose **Stochastic Positional Encoding** (SPE), which provably behaves like relative PE while being compatible with linear-complexity Transformers. We do this by drawing a connection between positional encoding and cross-covariance structures of correlated Gaussian processes. 6 | 7 | ![image](https://user-images.githubusercontent.com/8046580/119335679-fcf09280-bc8c-11eb-9525-bec9372bf6fb.png) 8 | 9 | Check out also the [companion website](https://cifkao.github.io/spe/) with music examples. 10 | 11 | Citation: 12 | ```bibtex 13 | @inproceedings{pmlr-v139-liutkus21a, 14 | title = {Relative Positional Encoding for {Transformers} with Linear Complexity}, 15 | author = {Liutkus, Antoine and C{\'i}fka, Ond{\v r}ej and Wu, Shih-Lun and {\c S}im{\c s}ekli, Umut and Yang, Yi-Hsuan and Richard, Ga{\"e}l}, 16 | booktitle = {Proceedings of the 38th International Conference on Machine Learning}, 17 | pages = {7067--7079}, 18 | year = {2021}, 19 | editor = {Meila, Marina and Zhang, Tong}, 20 | volume = {139}, 21 | series = {Proceedings of Machine Learning Research}, 22 | month = {18--24 Jul}, 23 | publisher = {PMLR}, 24 | pdf = {http://proceedings.mlr.press/v139/liutkus21a/liutkus21a.pdf}, 25 | url = {http://proceedings.mlr.press/v139/liutkus21a.html} 26 | } 27 | ``` 28 | 29 | ## SPE implementation 30 | 31 | We have implemented SPE in PyTorch and JAX/Flax. Each implementation is available as a separate Python package under [`src`](./src). 
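
As a quick orientation, here is a minimal sketch of the intended usage on the PyTorch side. It assumes the package under [`src`](./src) is installed as `spe`; the class and keyword names (`SineSPE`, `SPEFilter`, `num_realizations`, `num_sines`, `code_shape`) mirror how the experiment code configures SPE, but the tensor shapes and the concrete values below are only illustrative — see the packages themselves for the authoritative API.

```python
import torch
from spe import SineSPE, SPEFilter  # SPEFilter applies the sampled positional codes to Q and K

# One SPE module per attention layer (or shared across layers, as in some experiments).
spe_encoder = SineSPE(num_heads=8,          # number of attention heads
                      in_features=64,       # per-head dimension of keys/queries
                      num_realizations=64,  # dimension of the transformed keys/queries
                      num_sines=10)         # number of sinusoidal components
spe_filter = SPEFilter(gated=True, code_shape=spe_encoder.code_shape)

queries = torch.randn(1, 512, 8, 64)  # (batch, length, heads, head_dim)
keys = torch.randn(1, 512, 8, 64)

pos_codes = spe_encoder(keys.shape[:2])            # sample positional codes for this batch/length
qhat, khat = spe_filter(queries, keys, pos_codes)  # relative-PE-aware queries/keys
```

The transformed `qhat`/`khat` can then be fed to any linear-complexity attention mechanism (e.g. Performer's FAVOR+, as in the experiments below).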
32 | 33 | ## Experiments 34 | 35 | Each of the 3 experiments (LRA, pop piano generation, groove continuation) has a dedicated directory under [`experiments`](./experiments). See the README files there for how to set up the environment and prepare the datasets. To make sure you have the custom dependencies for each experiment, clone this repository with `--recurse-submodules` or run `git submodule init && git submodule update` after cloning. 36 | -------------------------------------------------------------------------------- /experiments/pop_piano/models/music_performer_ape.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | from ape_fast_transformer_decoder import FastTransformerDecoder 6 | from transformer_helpers import ( 7 | TokenEmbedding, 8 | PositionalEncoding, 9 | weights_init 10 | ) 11 | 12 | class MusicPerformer(nn.Module): 13 | def __init__(self, n_token, n_layer, n_head, d_model, d_ff, d_embed, 14 | activation='relu', dropout=0.1, use_pe=True, favor_feature_dims=None 15 | ): 16 | super(MusicPerformer, self).__init__() 17 | self.n_token = n_token 18 | self.n_layer = n_layer 19 | self.n_head = n_head 20 | self.d_model = d_model 21 | self.d_ff = d_ff 22 | self.dropout = dropout 23 | self.activation = activation 24 | self.favor_feature_dims = favor_feature_dims 25 | 26 | self.token_emb = TokenEmbedding(n_token, d_embed, d_model) 27 | self.d_embed = d_embed 28 | 29 | self.pe = PositionalEncoding(d_embed) 30 | self.dec_out_proj = nn.Linear(d_model, n_token) 31 | 32 | self.transformer_decoder = FastTransformerDecoder( 33 | n_layer, n_head, d_model, d_ff, dropout, activation, favor_feature_dims 34 | ) 35 | 36 | self.emb_dropout = nn.Dropout(self.dropout) 37 | self.use_pe = use_pe 38 | self.apply(weights_init) 39 | 40 | print ('[info] model init completed') 41 | 42 | def forward(self, x, keep_last_only=False, attn_kwargs=None): 43 | x_emb = self.token_emb(x) 44 | 45 | if self.use_pe: 46 | x_inp = self.emb_dropout(x_emb) + self.pe(x.size(1)).permute(1, 0, 2) 47 | else: 48 | x_inp = self.emb_dropout(x_emb) 49 | 50 | dec_out = self.transformer_decoder(x_inp, attn_kwargs=attn_kwargs) 51 | dec_logits = self.dec_out_proj(dec_out) 52 | 53 | if keep_last_only: 54 | dec_logits = dec_logits[:, -1, :] 55 | 56 | return dec_logits 57 | 58 | def compute_loss(self, dec_logits, dec_tgt, reduction='mean'): 59 | recons_loss = F.cross_entropy( 60 | dec_logits.view(-1, dec_logits.size(-1)), dec_tgt.contiguous().view(-1), 61 | ignore_index=self.n_token - 1, reduction=reduction 62 | ).float() 63 | 64 | return { 65 | 'recons_loss': recons_loss, 66 | 'total_loss': recons_loss 67 | } 68 | 69 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu/cifar10/r1/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Configuration and hyperparameter sweeps.""" 16 | 17 | from fast_self_attention import fast_self_attention as favor 18 | import jax 19 | 20 | from lra_benchmarks.image.configs.cifar10 import base_cifar10_config 21 | from lra_benchmarks.image.configs.cifar10.base_cifar10_config import TRAIN_EXAMPLES, VALID_EXAMPLES 22 | 23 | 24 | NUM_EPOCHS = 200 25 | 26 | 27 | def get_config(): 28 | """Get the hyperparameter configuration.""" 29 | config = base_cifar10_config.get_config() 30 | config.random_seed = 0 31 | config.model_type = "transformer" 32 | config.learning_rate = .00025 33 | config.batch_size = 96 34 | config.eval_frequency = TRAIN_EXAMPLES // config.batch_size 35 | config.num_train_steps = (TRAIN_EXAMPLES // config.batch_size) * NUM_EPOCHS 36 | config.num_eval_steps = VALID_EXAMPLES // config.batch_size 37 | config.factors = 'constant * linear_warmup * cosine_decay' 38 | config.warmup = (TRAIN_EXAMPLES // config.batch_size) * 1 39 | 40 | config.model.dropout_rate = 0.3 41 | config.model.attention_dropout_rate = 0.2 42 | config.model.learn_pos_emb = True 43 | config.model.num_layers = 1 44 | config.model.emb_dim = 128 45 | config.model.qkv_dim = 64 46 | config.model.mlp_dim = 128 47 | config.model.num_heads = 8 48 | config.model.classifier_pool = "CLS" 49 | 50 | config.attention_fn = favor.make_fast_generalized_attention( 51 | qkv_dim=config.model.qkv_dim // config.model.num_heads, 52 | features_type='deterministic', 53 | kernel_fn=jax.nn.relu, 54 | lax_scan_unroll=16) 55 | return config 56 | 57 | 58 | def get_hyper(hyper): 59 | return hyper.product([]) 60 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu/cifar10/r2/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Configuration and hyperparameter sweeps.""" 16 | 17 | from fast_self_attention import fast_self_attention as favor 18 | import jax 19 | 20 | from lra_benchmarks.image.configs.cifar10 import base_cifar10_config 21 | from lra_benchmarks.image.configs.cifar10.base_cifar10_config import TRAIN_EXAMPLES, VALID_EXAMPLES 22 | 23 | 24 | NUM_EPOCHS = 200 25 | 26 | 27 | def get_config(): 28 | """Get the hyperparameter configuration.""" 29 | config = base_cifar10_config.get_config() 30 | config.random_seed = 1 31 | config.model_type = "transformer" 32 | config.learning_rate = .00025 33 | config.batch_size = 96 34 | config.eval_frequency = TRAIN_EXAMPLES // config.batch_size 35 | config.num_train_steps = (TRAIN_EXAMPLES // config.batch_size) * NUM_EPOCHS 36 | config.num_eval_steps = VALID_EXAMPLES // config.batch_size 37 | config.factors = 'constant * linear_warmup * cosine_decay' 38 | config.warmup = (TRAIN_EXAMPLES // config.batch_size) * 1 39 | 40 | config.model.dropout_rate = 0.3 41 | config.model.attention_dropout_rate = 0.2 42 | config.model.learn_pos_emb = True 43 | config.model.num_layers = 1 44 | config.model.emb_dim = 128 45 | config.model.qkv_dim = 64 46 | config.model.mlp_dim = 128 47 | config.model.num_heads = 8 48 | config.model.classifier_pool = "CLS" 49 | 50 | config.attention_fn = favor.make_fast_generalized_attention( 51 | qkv_dim=config.model.qkv_dim // config.model.num_heads, 52 | features_type='deterministic', 53 | kernel_fn=jax.nn.relu, 54 | lax_scan_unroll=16) 55 | return config 56 | 57 | 58 | def get_hyper(hyper): 59 | return hyper.product([]) 60 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu/cifar10/r3/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Configuration and hyperparameter sweeps.""" 16 | 17 | from fast_self_attention import fast_self_attention as favor 18 | import jax 19 | 20 | from lra_benchmarks.image.configs.cifar10 import base_cifar10_config 21 | from lra_benchmarks.image.configs.cifar10.base_cifar10_config import TRAIN_EXAMPLES, VALID_EXAMPLES 22 | 23 | 24 | NUM_EPOCHS = 200 25 | 26 | 27 | def get_config(): 28 | """Get the hyperparameter configuration.""" 29 | config = base_cifar10_config.get_config() 30 | config.random_seed = 2 31 | config.model_type = "transformer" 32 | config.learning_rate = .00025 33 | config.batch_size = 96 34 | config.eval_frequency = TRAIN_EXAMPLES // config.batch_size 35 | config.num_train_steps = (TRAIN_EXAMPLES // config.batch_size) * NUM_EPOCHS 36 | config.num_eval_steps = VALID_EXAMPLES // config.batch_size 37 | config.factors = 'constant * linear_warmup * cosine_decay' 38 | config.warmup = (TRAIN_EXAMPLES // config.batch_size) * 1 39 | 40 | config.model.dropout_rate = 0.3 41 | config.model.attention_dropout_rate = 0.2 42 | config.model.learn_pos_emb = True 43 | config.model.num_layers = 1 44 | config.model.emb_dim = 128 45 | config.model.qkv_dim = 64 46 | config.model.mlp_dim = 128 47 | config.model.num_heads = 8 48 | config.model.classifier_pool = "CLS" 49 | 50 | config.attention_fn = favor.make_fast_generalized_attention( 51 | qkv_dim=config.model.qkv_dim // config.model.num_heads, 52 | features_type='deterministic', 53 | kernel_fn=jax.nn.relu, 54 | lax_scan_unroll=16) 55 | return config 56 | 57 | 58 | def get_hyper(hyper): 59 | return hyper.product([]) 60 | -------------------------------------------------------------------------------- /experiments/pop_piano/models/ape_fast_transformer_decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | from fast_transformers.builders import AttentionBuilder 6 | from fast_transformers.transformers import TransformerEncoderLayer 7 | from fast_transformers.masking import TriangularCausalMask, LengthMask 8 | from fast_transformers.attention import CausalLinearAttention, AttentionLayer 9 | from fast_transformers.feature_maps import Favor 10 | 11 | class FastTransformerDecoder(nn.Module): 12 | def __init__(self, n_layer, n_head, d_model, d_ff, 13 | dropout=0.1, 14 | activation='relu', 15 | favor_feature_dims=None 16 | ): 17 | super(FastTransformerDecoder, self).__init__() 18 | self.n_layer = n_layer 19 | self.n_head = n_head 20 | self.d_model = d_model 21 | self.d_ff = d_ff 22 | self.dropout = dropout 23 | self.activation = activation 24 | 25 | self.favor_feature_dims = 2 * d_model // n_head \ 26 | if favor_feature_dims is None else favor_feature_dims 27 | att_builder = AttentionBuilder.from_kwargs( 28 | query_dimensions=d_model // n_head, 29 | feature_map=Favor.factory(n_dims=self.favor_feature_dims) 30 | ) 31 | 32 | self.attention_layers = [ 33 | AttentionLayer( 34 | att_builder.get("causal-linear"), 35 | d_model, 36 | n_head, 37 | positional_encoder=None 38 | ) 39 | for l in range(n_layer) 40 | ] 41 | 42 | self.decoder_layers = nn.ModuleList() 43 | for l in range(n_layer): 44 | self.decoder_layers.append( 45 | TransformerEncoderLayer( 46 | attention=self.attention_layers[l], 47 | d_model=d_model, 48 | d_ff=d_ff, 49 | dropout=dropout, 50 | activation=activation 51 | ) 52 | ) 53 | 54 | def forward(self, x, lengths=None, attn_kwargs=None): 55 | attn_mask = TriangularCausalMask(x.size(1), device=x.device) 56 | 57 | if 
lengths is not None: 58 | length_mask = LengthMask(lengths, device=x.device) 59 | else: 60 | length_mask = None 61 | 62 | attn_kwargs = dict(attn_kwargs) if attn_kwargs else {} 63 | 64 | out = x 65 | for l in range(self.n_layer): 66 | # print (out.size()) 67 | out = self.decoder_layers[l]( 68 | out, 69 | attn_mask=attn_mask, 70 | length_mask=length_mask, 71 | attn_kwargs=attn_kwargs 72 | ) 73 | 74 | return out -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax_sinespe/cifar10/r1/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Configuration and hyperparameter sweeps.""" 16 | 17 | import functools 18 | 19 | from fast_self_attention import fast_self_attention as favor 20 | import jax_spe as spe 21 | 22 | from lra_benchmarks.models.layers.spe import make_spe_transform_fn 23 | from lra_benchmarks.image.configs.cifar10 import base_cifar10_config 24 | from lra_benchmarks.image.configs.cifar10.base_cifar10_config import TRAIN_EXAMPLES, VALID_EXAMPLES 25 | 26 | 27 | NUM_EPOCHS = 200 28 | 29 | 30 | def get_config(): 31 | """Get the hyperparameter configuration.""" 32 | config = base_cifar10_config.get_config() 33 | config.random_seed = 0 34 | config.model_type = "transformer" 35 | config.learning_rate = .00025 36 | config.batch_size = 96 37 | config.eval_frequency = 4 * TRAIN_EXAMPLES // config.batch_size 38 | config.num_train_steps = (TRAIN_EXAMPLES // config.batch_size) * NUM_EPOCHS 39 | config.num_eval_steps = VALID_EXAMPLES // config.batch_size 40 | config.factors = 'constant * linear_warmup * cosine_decay' 41 | config.warmup = (TRAIN_EXAMPLES // config.batch_size) * 1 42 | 43 | config.model.dropout_rate = 0.3 44 | config.model.attention_dropout_rate = 0.2 45 | config.model.learn_pos_emb = True 46 | config.model.num_layers = 1 47 | config.model.emb_dim = 128 48 | config.model.qkv_dim = 64 49 | config.model.mlp_dim = 128 50 | config.model.num_heads = 8 51 | config.model.classifier_pool = "CLS" 52 | config.model.add_pos_emb=False 53 | num_realizations = 32 54 | config.model.qk_transform_fn_factory = functools.partial( 55 | make_spe_transform_fn, 56 | spe_cls=spe.SineSPE, 57 | spe_kwargs=dict( 58 | num_realizations=num_realizations, 59 | num_sines=10 60 | ) 61 | ) 62 | config.attention_fn = favor.make_fast_softmax_attention( 63 | qkv_dim=num_realizations, 64 | lax_scan_unroll=16) 65 | return config 66 | 67 | 68 | def get_hyper(hyper): 69 | return hyper.product([]) 70 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax_sinespe/cifar10/r2/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in 
compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Configuration and hyperparameter sweeps.""" 16 | 17 | import functools 18 | 19 | from fast_self_attention import fast_self_attention as favor 20 | import jax_spe as spe 21 | 22 | from lra_benchmarks.models.layers.spe import make_spe_transform_fn 23 | from lra_benchmarks.image.configs.cifar10 import base_cifar10_config 24 | from lra_benchmarks.image.configs.cifar10.base_cifar10_config import TRAIN_EXAMPLES, VALID_EXAMPLES 25 | 26 | 27 | NUM_EPOCHS = 200 28 | 29 | 30 | def get_config(): 31 | """Get the hyperparameter configuration.""" 32 | config = base_cifar10_config.get_config() 33 | config.random_seed = 1 34 | config.model_type = "transformer" 35 | config.learning_rate = .00025 36 | config.batch_size = 96 37 | config.eval_frequency = 4 * TRAIN_EXAMPLES // config.batch_size 38 | config.num_train_steps = (TRAIN_EXAMPLES // config.batch_size) * NUM_EPOCHS 39 | config.num_eval_steps = VALID_EXAMPLES // config.batch_size 40 | config.factors = 'constant * linear_warmup * cosine_decay' 41 | config.warmup = (TRAIN_EXAMPLES // config.batch_size) * 1 42 | 43 | config.model.dropout_rate = 0.3 44 | config.model.attention_dropout_rate = 0.2 45 | config.model.learn_pos_emb = True 46 | config.model.num_layers = 1 47 | config.model.emb_dim = 128 48 | config.model.qkv_dim = 64 49 | config.model.mlp_dim = 128 50 | config.model.num_heads = 8 51 | config.model.classifier_pool = "CLS" 52 | config.model.add_pos_emb=False 53 | num_realizations = 32 54 | config.model.qk_transform_fn_factory = functools.partial( 55 | make_spe_transform_fn, 56 | spe_cls=spe.SineSPE, 57 | spe_kwargs=dict( 58 | num_realizations=num_realizations, 59 | num_sines=10 60 | ) 61 | ) 62 | config.attention_fn = favor.make_fast_softmax_attention( 63 | qkv_dim=num_realizations, 64 | lax_scan_unroll=16) 65 | return config 66 | 67 | 68 | def get_hyper(hyper): 69 | return hyper.product([]) 70 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax_sinespe/cifar10/r3/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Configuration and hyperparameter sweeps.""" 16 | 17 | import functools 18 | 19 | from fast_self_attention import fast_self_attention as favor 20 | import jax_spe as spe 21 | 22 | from lra_benchmarks.models.layers.spe import make_spe_transform_fn 23 | from lra_benchmarks.image.configs.cifar10 import base_cifar10_config 24 | from lra_benchmarks.image.configs.cifar10.base_cifar10_config import TRAIN_EXAMPLES, VALID_EXAMPLES 25 | 26 | 27 | NUM_EPOCHS = 200 28 | 29 | 30 | def get_config(): 31 | """Get the hyperparameter configuration.""" 32 | config = base_cifar10_config.get_config() 33 | config.random_seed = 2 34 | config.model_type = "transformer" 35 | config.learning_rate = .00025 36 | config.batch_size = 96 37 | config.eval_frequency = 4 * TRAIN_EXAMPLES // config.batch_size 38 | config.num_train_steps = (TRAIN_EXAMPLES // config.batch_size) * NUM_EPOCHS 39 | config.num_eval_steps = VALID_EXAMPLES // config.batch_size 40 | config.factors = 'constant * linear_warmup * cosine_decay' 41 | config.warmup = (TRAIN_EXAMPLES // config.batch_size) * 1 42 | 43 | config.model.dropout_rate = 0.3 44 | config.model.attention_dropout_rate = 0.2 45 | config.model.learn_pos_emb = True 46 | config.model.num_layers = 1 47 | config.model.emb_dim = 128 48 | config.model.qkv_dim = 64 49 | config.model.mlp_dim = 128 50 | config.model.num_heads = 8 51 | config.model.classifier_pool = "CLS" 52 | config.model.add_pos_emb=False 53 | num_realizations = 32 54 | config.model.qk_transform_fn_factory = functools.partial( 55 | make_spe_transform_fn, 56 | spe_cls=spe.SineSPE, 57 | spe_kwargs=dict( 58 | num_realizations=num_realizations, 59 | num_sines=10 60 | ) 61 | ) 62 | config.attention_fn = favor.make_fast_softmax_attention( 63 | qkv_dim=num_realizations, 64 | lax_scan_unroll=16) 65 | return config 66 | 67 | 68 | def get_hyper(hyper): 69 | return hyper.product([]) 70 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax_convspe_k128_shr/cifar10/r1/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Configuration and hyperparameter sweeps.""" 16 | 17 | import functools 18 | 19 | from fast_self_attention import fast_self_attention as favor 20 | import jax_spe as spe 21 | 22 | from lra_benchmarks.models.layers.spe import make_spe_transform_fn 23 | from lra_benchmarks.image.configs.cifar10 import base_cifar10_config 24 | from lra_benchmarks.image.configs.cifar10.base_cifar10_config import TRAIN_EXAMPLES, VALID_EXAMPLES 25 | 26 | 27 | NUM_EPOCHS = 200 28 | 29 | 30 | def get_config(): 31 | """Get the hyperparameter configuration.""" 32 | config = base_cifar10_config.get_config() 33 | config.random_seed = 0 34 | config.model_type = "transformer" 35 | config.learning_rate = .00025 36 | config.batch_size = 96 37 | config.eval_frequency = 4 * TRAIN_EXAMPLES // config.batch_size 38 | config.num_train_steps = (TRAIN_EXAMPLES // config.batch_size) * NUM_EPOCHS 39 | config.num_eval_steps = VALID_EXAMPLES // config.batch_size 40 | config.factors = 'constant * linear_warmup * cosine_decay' 41 | config.warmup = (TRAIN_EXAMPLES // config.batch_size) * 1 42 | 43 | config.model.dropout_rate = 0.3 44 | config.model.attention_dropout_rate = 0.2 45 | config.model.learn_pos_emb = True 46 | config.model.num_layers = 1 47 | config.model.emb_dim = 128 48 | config.model.qkv_dim = 64 49 | config.model.mlp_dim = 128 50 | config.model.num_heads = 8 51 | config.model.classifier_pool = "CLS" 52 | config.model.add_pos_emb=False 53 | num_realizations = 32 54 | config.model.qk_transform_fn_factory = functools.partial( 55 | make_spe_transform_fn, 56 | spe_cls=spe.ConvSPE, 57 | spe_kwargs=dict( 58 | num_realizations=num_realizations, 59 | kernel_size=128 60 | ), 61 | shared=True 62 | ) 63 | config.attention_fn = favor.make_fast_softmax_attention( 64 | qkv_dim=num_realizations, 65 | lax_scan_unroll=16) 66 | return config 67 | 68 | 69 | def get_hyper(hyper): 70 | return hyper.product([]) 71 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax_convspe_k128_shr/cifar10/r2/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Configuration and hyperparameter sweeps.""" 16 | 17 | import functools 18 | 19 | from fast_self_attention import fast_self_attention as favor 20 | import jax_spe as spe 21 | 22 | from lra_benchmarks.models.layers.spe import make_spe_transform_fn 23 | from lra_benchmarks.image.configs.cifar10 import base_cifar10_config 24 | from lra_benchmarks.image.configs.cifar10.base_cifar10_config import TRAIN_EXAMPLES, VALID_EXAMPLES 25 | 26 | 27 | NUM_EPOCHS = 200 28 | 29 | 30 | def get_config(): 31 | """Get the hyperparameter configuration.""" 32 | config = base_cifar10_config.get_config() 33 | config.random_seed = 1 34 | config.model_type = "transformer" 35 | config.learning_rate = .00025 36 | config.batch_size = 96 37 | config.eval_frequency = 4 * TRAIN_EXAMPLES // config.batch_size 38 | config.num_train_steps = (TRAIN_EXAMPLES // config.batch_size) * NUM_EPOCHS 39 | config.num_eval_steps = VALID_EXAMPLES // config.batch_size 40 | config.factors = 'constant * linear_warmup * cosine_decay' 41 | config.warmup = (TRAIN_EXAMPLES // config.batch_size) * 1 42 | 43 | config.model.dropout_rate = 0.3 44 | config.model.attention_dropout_rate = 0.2 45 | config.model.learn_pos_emb = True 46 | config.model.num_layers = 1 47 | config.model.emb_dim = 128 48 | config.model.qkv_dim = 64 49 | config.model.mlp_dim = 128 50 | config.model.num_heads = 8 51 | config.model.classifier_pool = "CLS" 52 | config.model.add_pos_emb=False 53 | num_realizations = 32 54 | config.model.qk_transform_fn_factory = functools.partial( 55 | make_spe_transform_fn, 56 | spe_cls=spe.ConvSPE, 57 | spe_kwargs=dict( 58 | num_realizations=num_realizations, 59 | kernel_size=128 60 | ), 61 | shared=True 62 | ) 63 | config.attention_fn = favor.make_fast_softmax_attention( 64 | qkv_dim=num_realizations, 65 | lax_scan_unroll=16) 66 | return config 67 | 68 | 69 | def get_hyper(hyper): 70 | return hyper.product([]) 71 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/performer_softmax_convspe_k128_shr/cifar10/r3/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Configuration and hyperparameter sweeps.""" 16 | 17 | import functools 18 | 19 | from fast_self_attention import fast_self_attention as favor 20 | import jax_spe as spe 21 | 22 | from lra_benchmarks.models.layers.spe import make_spe_transform_fn 23 | from lra_benchmarks.image.configs.cifar10 import base_cifar10_config 24 | from lra_benchmarks.image.configs.cifar10.base_cifar10_config import TRAIN_EXAMPLES, VALID_EXAMPLES 25 | 26 | 27 | NUM_EPOCHS = 200 28 | 29 | 30 | def get_config(): 31 | """Get the hyperparameter configuration.""" 32 | config = base_cifar10_config.get_config() 33 | config.random_seed = 2 34 | config.model_type = "transformer" 35 | config.learning_rate = .00025 36 | config.batch_size = 96 37 | config.eval_frequency = 4 * TRAIN_EXAMPLES // config.batch_size 38 | config.num_train_steps = (TRAIN_EXAMPLES // config.batch_size) * NUM_EPOCHS 39 | config.num_eval_steps = VALID_EXAMPLES // config.batch_size 40 | config.factors = 'constant * linear_warmup * cosine_decay' 41 | config.warmup = (TRAIN_EXAMPLES // config.batch_size) * 1 42 | 43 | config.model.dropout_rate = 0.3 44 | config.model.attention_dropout_rate = 0.2 45 | config.model.learn_pos_emb = True 46 | config.model.num_layers = 1 47 | config.model.emb_dim = 128 48 | config.model.qkv_dim = 64 49 | config.model.mlp_dim = 128 50 | config.model.num_heads = 8 51 | config.model.classifier_pool = "CLS" 52 | config.model.add_pos_emb=False 53 | num_realizations = 32 54 | config.model.qk_transform_fn_factory = functools.partial( 55 | make_spe_transform_fn, 56 | spe_cls=spe.ConvSPE, 57 | spe_kwargs=dict( 58 | num_realizations=num_realizations, 59 | kernel_size=128 60 | ), 61 | shared=True 62 | ) 63 | config.attention_fn = favor.make_fast_softmax_attention( 64 | qkv_dim=num_realizations, 65 | lax_scan_unroll=16) 66 | return config 67 | 68 | 69 | def get_hyper(hyper): 70 | return hyper.product([]) 71 | -------------------------------------------------------------------------------- /experiments/pop_piano/models/transformer_helpers.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import numpy as np 4 | 5 | from torch import nn 6 | import torch.nn.functional as F 7 | 8 | def weight_init_normal(weight, normal_std): 9 | nn.init.normal_(weight, 0.0, normal_std) 10 | 11 | def bias_init(bias): 12 | nn.init.constant_(bias, 0.0) 13 | 14 | def weights_init(m): 15 | classname = m.__class__.__name__ 16 | # print ('[{}] initializing ...'.format(classname)) 17 | 18 | if classname.find('Linear') != -1: 19 | if hasattr(m, 'weight') and m.weight is not None: 20 | weight_init_normal(m.weight, 0.01) 21 | if hasattr(m, 'bias') and m.bias is not None: 22 | bias_init(m.bias) 23 | elif classname.find('Embedding') != -1: 24 | if hasattr(m, 'weight'): 25 | weight_init_normal(m.weight, 0.01) 26 | elif classname.find('LayerNorm') != -1: 27 | if hasattr(m, 'weight'): 28 | nn.init.normal_(m.weight, 1.0, 0.01) 29 | if hasattr(m, 'bias') and m.bias is not None: 30 | bias_init(m.bias) 31 | 32 | class PositionalEncoding(nn.Module): 33 | def __init__(self, d_embed, max_pos=20480): 34 | super(PositionalEncoding, self).__init__() 35 | self.d_embed = d_embed 36 | self.max_pos = max_pos 37 | 38 | pe = torch.zeros(max_pos, d_embed) 39 | position = torch.arange(0, max_pos, dtype=torch.float).unsqueeze(1) 40 | div_term = torch.exp(torch.arange(0, d_embed, 2).float() * (-math.log(10000.0) / d_embed)) 41 | pe[:, 0::2] = torch.sin(position * div_term) 42 | pe[:, 1::2] = torch.cos(position * 
div_term) 43 | pe = pe.unsqueeze(0).transpose(0, 1) 44 | self.register_buffer('pe', pe) 45 | 46 | def forward(self, seq_len, bsz=None): 47 | pos_encoding = self.pe[:seq_len, :] 48 | 49 | if bsz is not None: 50 | pos_encoding = pos_encoding.expand(seq_len, bsz, -1) 51 | 52 | return pos_encoding 53 | 54 | class TokenEmbedding(nn.Module): 55 | def __init__(self, n_token, d_embed, d_proj): 56 | super(TokenEmbedding, self).__init__() 57 | 58 | self.n_token = n_token 59 | self.d_embed = d_embed 60 | self.d_proj = d_proj 61 | self.emb_scale = d_proj ** 0.5 62 | 63 | self.emb_lookup = nn.Embedding(n_token, d_embed) 64 | if d_proj != d_embed: 65 | self.emb_proj = nn.Linear(d_embed, d_proj, bias=False) 66 | else: 67 | self.emb_proj = None 68 | 69 | def forward(self, inp_tokens): 70 | inp_emb = self.emb_lookup(inp_tokens) 71 | 72 | if self.emb_proj is not None: 73 | inp_emb = self.emb_proj(inp_emb) 74 | 75 | return inp_emb.mul_(self.emb_scale) -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu_sinespe/cifar10/r1/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Configuration and hyperparameter sweeps.""" 16 | 17 | import functools 18 | 19 | from fast_self_attention import fast_self_attention as favor 20 | import jax 21 | import jax_spe as spe 22 | 23 | from lra_benchmarks.models.layers.spe import make_spe_transform_fn 24 | from lra_benchmarks.image.configs.cifar10 import base_cifar10_config 25 | from lra_benchmarks.image.configs.cifar10.base_cifar10_config import TRAIN_EXAMPLES, VALID_EXAMPLES 26 | 27 | 28 | NUM_EPOCHS = 200 29 | 30 | 31 | def get_config(): 32 | """Get the hyperparameter configuration.""" 33 | config = base_cifar10_config.get_config() 34 | config.random_seed = 0 35 | config.model_type = "transformer" 36 | config.learning_rate = .00025 37 | config.batch_size = 96 38 | config.eval_frequency = 4 * TRAIN_EXAMPLES // config.batch_size 39 | config.num_train_steps = (TRAIN_EXAMPLES // config.batch_size) * NUM_EPOCHS 40 | config.num_eval_steps = VALID_EXAMPLES // config.batch_size 41 | config.factors = 'constant * linear_warmup * cosine_decay' 42 | config.warmup = (TRAIN_EXAMPLES // config.batch_size) * 1 43 | 44 | config.model.dropout_rate = 0.3 45 | config.model.attention_dropout_rate = 0.2 46 | config.model.learn_pos_emb = True 47 | config.model.num_layers = 1 48 | config.model.emb_dim = 128 49 | config.model.qkv_dim = 64 50 | config.model.mlp_dim = 128 51 | config.model.num_heads = 8 52 | config.model.classifier_pool = "CLS" 53 | config.model.add_pos_emb=False 54 | num_realizations = 32 55 | config.model.qk_transform_fn_factory = functools.partial( 56 | make_spe_transform_fn, 57 | spe_cls=spe.SineSPE, 58 | spe_kwargs=dict( 59 | num_realizations=num_realizations, 60 | num_sines=10 61 | ) 62 | ) 63 | config.attention_fn = favor.make_fast_generalized_attention( 64 | qkv_dim=num_realizations, 65 | features_type='deterministic', 66 | kernel_fn=jax.nn.relu, 67 | lax_scan_unroll=16) 68 | return config 69 | 70 | 71 | def get_hyper(hyper): 72 | return hyper.product([]) 73 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu_sinespe/cifar10/r2/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Configuration and hyperparameter sweeps.""" 16 | 17 | import functools 18 | 19 | from fast_self_attention import fast_self_attention as favor 20 | import jax 21 | import jax_spe as spe 22 | 23 | from lra_benchmarks.models.layers.spe import make_spe_transform_fn 24 | from lra_benchmarks.image.configs.cifar10 import base_cifar10_config 25 | from lra_benchmarks.image.configs.cifar10.base_cifar10_config import TRAIN_EXAMPLES, VALID_EXAMPLES 26 | 27 | 28 | NUM_EPOCHS = 200 29 | 30 | 31 | def get_config(): 32 | """Get the hyperparameter configuration.""" 33 | config = base_cifar10_config.get_config() 34 | config.random_seed = 1 35 | config.model_type = "transformer" 36 | config.learning_rate = .00025 37 | config.batch_size = 96 38 | config.eval_frequency = 4 * TRAIN_EXAMPLES // config.batch_size 39 | config.num_train_steps = (TRAIN_EXAMPLES // config.batch_size) * NUM_EPOCHS 40 | config.num_eval_steps = VALID_EXAMPLES // config.batch_size 41 | config.factors = 'constant * linear_warmup * cosine_decay' 42 | config.warmup = (TRAIN_EXAMPLES // config.batch_size) * 1 43 | 44 | config.model.dropout_rate = 0.3 45 | config.model.attention_dropout_rate = 0.2 46 | config.model.learn_pos_emb = True 47 | config.model.num_layers = 1 48 | config.model.emb_dim = 128 49 | config.model.qkv_dim = 64 50 | config.model.mlp_dim = 128 51 | config.model.num_heads = 8 52 | config.model.classifier_pool = "CLS" 53 | config.model.add_pos_emb=False 54 | num_realizations = 32 55 | config.model.qk_transform_fn_factory = functools.partial( 56 | make_spe_transform_fn, 57 | spe_cls=spe.SineSPE, 58 | spe_kwargs=dict( 59 | num_realizations=num_realizations, 60 | num_sines=10 61 | ) 62 | ) 63 | config.attention_fn = favor.make_fast_generalized_attention( 64 | qkv_dim=num_realizations, 65 | features_type='deterministic', 66 | kernel_fn=jax.nn.relu, 67 | lax_scan_unroll=16) 68 | return config 69 | 70 | 71 | def get_hyper(hyper): 72 | return hyper.product([]) 73 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu_sinespe/cifar10/r3/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Configuration and hyperparameter sweeps.""" 16 | 17 | import functools 18 | 19 | from fast_self_attention import fast_self_attention as favor 20 | import jax 21 | import jax_spe as spe 22 | 23 | from lra_benchmarks.models.layers.spe import make_spe_transform_fn 24 | from lra_benchmarks.image.configs.cifar10 import base_cifar10_config 25 | from lra_benchmarks.image.configs.cifar10.base_cifar10_config import TRAIN_EXAMPLES, VALID_EXAMPLES 26 | 27 | 28 | NUM_EPOCHS = 200 29 | 30 | 31 | def get_config(): 32 | """Get the hyperparameter configuration.""" 33 | config = base_cifar10_config.get_config() 34 | config.random_seed = 2 35 | config.model_type = "transformer" 36 | config.learning_rate = .00025 37 | config.batch_size = 96 38 | config.eval_frequency = 4 * TRAIN_EXAMPLES // config.batch_size 39 | config.num_train_steps = (TRAIN_EXAMPLES // config.batch_size) * NUM_EPOCHS 40 | config.num_eval_steps = VALID_EXAMPLES // config.batch_size 41 | config.factors = 'constant * linear_warmup * cosine_decay' 42 | config.warmup = (TRAIN_EXAMPLES // config.batch_size) * 1 43 | 44 | config.model.dropout_rate = 0.3 45 | config.model.attention_dropout_rate = 0.2 46 | config.model.learn_pos_emb = True 47 | config.model.num_layers = 1 48 | config.model.emb_dim = 128 49 | config.model.qkv_dim = 64 50 | config.model.mlp_dim = 128 51 | config.model.num_heads = 8 52 | config.model.classifier_pool = "CLS" 53 | config.model.add_pos_emb=False 54 | num_realizations = 32 55 | config.model.qk_transform_fn_factory = functools.partial( 56 | make_spe_transform_fn, 57 | spe_cls=spe.SineSPE, 58 | spe_kwargs=dict( 59 | num_realizations=num_realizations, 60 | num_sines=10 61 | ) 62 | ) 63 | config.attention_fn = favor.make_fast_generalized_attention( 64 | qkv_dim=num_realizations, 65 | features_type='deterministic', 66 | kernel_fn=jax.nn.relu, 67 | lax_scan_unroll=16) 68 | return config 69 | 70 | 71 | def get_hyper(hyper): 72 | return hyper.product([]) 73 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | !/experiments/*/lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | pip-wheel-metadata/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # pipenv 89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 92 | # install all needed dependencies. 93 | #Pipfile.lock 94 | 95 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 96 | __pypackages__/ 97 | 98 | # Celery stuff 99 | celerybeat-schedule 100 | celerybeat.pid 101 | 102 | # SageMath parsed files 103 | *.sage.py 104 | 105 | # Environments 106 | .env 107 | .venv 108 | env/ 109 | venv/ 110 | ENV/ 111 | env.bak/ 112 | venv.bak/ 113 | **/.vscode/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | 133 | # directories for checkpoints & generations 134 | **/checkpoint 135 | **/ckpt* 136 | **/gen*/ 137 | 138 | # dataset 139 | *.tar.gz 140 | **/*_dataset/ 141 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu_convspe_k128_shr/cifar10/r1/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Configuration and hyperparameter sweeps.""" 16 | 17 | import functools 18 | 19 | from fast_self_attention import fast_self_attention as favor 20 | import jax 21 | import jax_spe as spe 22 | 23 | from lra_benchmarks.models.layers.spe import make_spe_transform_fn 24 | from lra_benchmarks.image.configs.cifar10 import base_cifar10_config 25 | from lra_benchmarks.image.configs.cifar10.base_cifar10_config import TRAIN_EXAMPLES, VALID_EXAMPLES 26 | 27 | 28 | NUM_EPOCHS = 200 29 | 30 | 31 | def get_config(): 32 | """Get the hyperparameter configuration.""" 33 | config = base_cifar10_config.get_config() 34 | config.random_seed = 0 35 | config.model_type = "transformer" 36 | config.learning_rate = .00025 37 | config.batch_size = 96 38 | config.eval_frequency = 4 * TRAIN_EXAMPLES // config.batch_size 39 | config.num_train_steps = (TRAIN_EXAMPLES // config.batch_size) * NUM_EPOCHS 40 | config.num_eval_steps = VALID_EXAMPLES // config.batch_size 41 | config.factors = 'constant * linear_warmup * cosine_decay' 42 | config.warmup = (TRAIN_EXAMPLES // config.batch_size) * 1 43 | 44 | config.model.dropout_rate = 0.3 45 | config.model.attention_dropout_rate = 0.2 46 | config.model.learn_pos_emb = True 47 | config.model.num_layers = 1 48 | config.model.emb_dim = 128 49 | config.model.qkv_dim = 64 50 | config.model.mlp_dim = 128 51 | config.model.num_heads = 8 52 | config.model.classifier_pool = "CLS" 53 | config.model.add_pos_emb=False 54 | num_realizations = 32 55 | config.model.qk_transform_fn_factory = functools.partial( 56 | make_spe_transform_fn, 57 | spe_cls=spe.ConvSPE, 58 | spe_kwargs=dict( 59 | num_realizations=num_realizations, 60 | kernel_size=128 61 | ), 62 | shared=True 63 | ) 64 | config.attention_fn = favor.make_fast_generalized_attention( 65 | qkv_dim=num_realizations, 66 | features_type='deterministic', 67 | kernel_fn=jax.nn.relu, 68 | lax_scan_unroll=16) 69 | return config 70 | 71 | 72 | def get_hyper(hyper): 73 | return hyper.product([]) 74 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu_convspe_k128_shr/cifar10/r2/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Configuration and hyperparameter sweeps.""" 16 | 17 | import functools 18 | 19 | from fast_self_attention import fast_self_attention as favor 20 | import jax 21 | import jax_spe as spe 22 | 23 | from lra_benchmarks.models.layers.spe import make_spe_transform_fn 24 | from lra_benchmarks.image.configs.cifar10 import base_cifar10_config 25 | from lra_benchmarks.image.configs.cifar10.base_cifar10_config import TRAIN_EXAMPLES, VALID_EXAMPLES 26 | 27 | 28 | NUM_EPOCHS = 200 29 | 30 | 31 | def get_config(): 32 | """Get the hyperparameter configuration.""" 33 | config = base_cifar10_config.get_config() 34 | config.random_seed = 1 35 | config.model_type = "transformer" 36 | config.learning_rate = .00025 37 | config.batch_size = 96 38 | config.eval_frequency = 4 * TRAIN_EXAMPLES // config.batch_size 39 | config.num_train_steps = (TRAIN_EXAMPLES // config.batch_size) * NUM_EPOCHS 40 | config.num_eval_steps = VALID_EXAMPLES // config.batch_size 41 | config.factors = 'constant * linear_warmup * cosine_decay' 42 | config.warmup = (TRAIN_EXAMPLES // config.batch_size) * 1 43 | 44 | config.model.dropout_rate = 0.3 45 | config.model.attention_dropout_rate = 0.2 46 | config.model.learn_pos_emb = True 47 | config.model.num_layers = 1 48 | config.model.emb_dim = 128 49 | config.model.qkv_dim = 64 50 | config.model.mlp_dim = 128 51 | config.model.num_heads = 8 52 | config.model.classifier_pool = "CLS" 53 | config.model.add_pos_emb=False 54 | num_realizations = 32 55 | config.model.qk_transform_fn_factory = functools.partial( 56 | make_spe_transform_fn, 57 | spe_cls=spe.ConvSPE, 58 | spe_kwargs=dict( 59 | num_realizations=num_realizations, 60 | kernel_size=128 61 | ), 62 | shared=True 63 | ) 64 | config.attention_fn = favor.make_fast_generalized_attention( 65 | qkv_dim=num_realizations, 66 | features_type='deterministic', 67 | kernel_fn=jax.nn.relu, 68 | lax_scan_unroll=16) 69 | return config 70 | 71 | 72 | def get_hyper(hyper): 73 | return hyper.product([]) 74 | -------------------------------------------------------------------------------- /experiments/lra/models/gpu_16g/linear_transformer_relu_convspe_k128_shr/cifar10/r3/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Configuration and hyperparameter sweeps.""" 16 | 17 | import functools 18 | 19 | from fast_self_attention import fast_self_attention as favor 20 | import jax 21 | import jax_spe as spe 22 | 23 | from lra_benchmarks.models.layers.spe import make_spe_transform_fn 24 | from lra_benchmarks.image.configs.cifar10 import base_cifar10_config 25 | from lra_benchmarks.image.configs.cifar10.base_cifar10_config import TRAIN_EXAMPLES, VALID_EXAMPLES 26 | 27 | 28 | NUM_EPOCHS = 200 29 | 30 | 31 | def get_config(): 32 | """Get the hyperparameter configuration.""" 33 | config = base_cifar10_config.get_config() 34 | config.random_seed = 2 35 | config.model_type = "transformer" 36 | config.learning_rate = .00025 37 | config.batch_size = 96 38 | config.eval_frequency = 4 * TRAIN_EXAMPLES // config.batch_size 39 | config.num_train_steps = (TRAIN_EXAMPLES // config.batch_size) * NUM_EPOCHS 40 | config.num_eval_steps = VALID_EXAMPLES // config.batch_size 41 | config.factors = 'constant * linear_warmup * cosine_decay' 42 | config.warmup = (TRAIN_EXAMPLES // config.batch_size) * 1 43 | 44 | config.model.dropout_rate = 0.3 45 | config.model.attention_dropout_rate = 0.2 46 | config.model.learn_pos_emb = True 47 | config.model.num_layers = 1 48 | config.model.emb_dim = 128 49 | config.model.qkv_dim = 64 50 | config.model.mlp_dim = 128 51 | config.model.num_heads = 8 52 | config.model.classifier_pool = "CLS" 53 | config.model.add_pos_emb=False 54 | num_realizations = 32 55 | config.model.qk_transform_fn_factory = functools.partial( 56 | make_spe_transform_fn, 57 | spe_cls=spe.ConvSPE, 58 | spe_kwargs=dict( 59 | num_realizations=num_realizations, 60 | kernel_size=128 61 | ), 62 | shared=True 63 | ) 64 | config.attention_fn = favor.make_fast_generalized_attention( 65 | qkv_dim=num_realizations, 66 | features_type='deterministic', 67 | kernel_fn=jax.nn.relu, 68 | lax_scan_unroll=16) 69 | return config 70 | 71 | 72 | def get_hyper(hyper): 73 | return hyper.product([]) 74 | -------------------------------------------------------------------------------- /experiments/pop_piano/models/music_performer_spe.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | from spe_fast_transformer_decoder import SPEFastTransformerDecoder 6 | from transformer_helpers import ( 7 | TokenEmbedding, 8 | weights_init 9 | ) 10 | 11 | from spe import SineSPE, ConvSPE 12 | 13 | class MusicPerformerSPE(nn.Module): 14 | def __init__(self, n_token, n_layer, n_head, d_model, d_ff, d_embed, spe_type, 15 | activation='relu', dropout=0.1, favor_feature_dims=None, 16 | share_pe=False, 17 | share_spe_filter=False, 18 | use_gated_filter=True, 19 | spe_module_params=None 20 | ): 21 | super(MusicPerformerSPE, self).__init__() 22 | self.n_token = n_token 23 | self.n_layer = n_layer 24 | self.n_head = n_head 25 | self.d_model = d_model 26 | self.d_ff = d_ff 27 | self.dropout = dropout 28 | self.activation = activation 29 | self.favor_feature_dims = favor_feature_dims 30 | 31 | self.token_emb = TokenEmbedding(n_token, d_embed, d_model) 32 | self.d_embed = d_embed 33 | 34 | self.spe_type = spe_type 35 | if self.spe_type == 'SineSPE': 36 | self.pe = SineSPE 37 | elif self.spe_type == 'ConvSPE': 38 | self.pe = ConvSPE 39 | else: 40 | raise ValueError('unrecognized SPE implementation: {}'.format(self.spe_type)) 41 | 42 | self.dec_out_proj = nn.Linear(d_model, n_token) 43 | self.d_head = d_model // n_head 44 | self.transformer_decoder = 
SPEFastTransformerDecoder( 45 | n_layer, n_head, d_model, d_ff, dropout, activation, favor_feature_dims, 46 | spe_module=self.pe, 47 | share_pe=share_pe, 48 | share_spe_filter=share_spe_filter, 49 | use_gated_filter=use_gated_filter, 50 | spe_module_params=spe_module_params 51 | ) 52 | 53 | self.emb_dropout = nn.Dropout(self.dropout) 54 | self.use_gated_filter = use_gated_filter 55 | self.apply(weights_init) 56 | 57 | print ('[info] model init completed') 58 | 59 | def forward(self, x, keep_last_only=False, attn_kwargs=None): 60 | x_emb = self.token_emb(x) 61 | x_inp = self.emb_dropout(x_emb) 62 | 63 | dec_out = self.transformer_decoder(x_inp, attn_kwargs=attn_kwargs) 64 | dec_logits = self.dec_out_proj(dec_out) 65 | 66 | if keep_last_only: 67 | dec_logits = dec_logits[:, -1, :] 68 | 69 | return dec_logits 70 | 71 | def compute_loss(self, dec_logits, dec_tgt, reduction='mean'): 72 | recons_loss = F.cross_entropy( 73 | dec_logits.view(-1, dec_logits.size(-1)), dec_tgt.contiguous().view(-1), 74 | ignore_index=self.n_token - 1, reduction=reduction 75 | ).float() 76 | 77 | return { 78 | 'recons_loss': recons_loss, 79 | 'total_loss': recons_loss 80 | } -------------------------------------------------------------------------------- /experiments/pop_piano/inference.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | sys.path.append('./models') 3 | import torch 4 | import numpy as np 5 | 6 | from utils import pickle_load, load_model 7 | from convert2midi import event_to_midi 8 | from generate_utils import generate_fast 9 | 10 | import yaml 11 | train_conf_path = sys.argv[1] 12 | inf_conf_path = sys.argv[2] 13 | train_conf = yaml.load(open(train_conf_path, 'r'), Loader=yaml.FullLoader) 14 | inf_conf = yaml.load(open(inf_conf_path, 'r'), Loader=yaml.FullLoader) 15 | 16 | REMI_MODEL_VOCAB_SIZE = 333 17 | gpuid = inf_conf['gpuid'] 18 | torch.cuda.set_device(gpuid) 19 | 20 | def word2event(word_seq, idx2event): 21 | return [ idx2event[w] for w in word_seq ] 22 | 23 | if __name__ == "__main__": 24 | event2idx, idx2event = pickle_load('./pickles/remi_vocab.pkl') 25 | 26 | ckpt_path = inf_conf['ckpt_path'] 27 | n_pieces = inf_conf['gen_n_pieces'] 28 | gen_output_dir = inf_conf['gen_output_dir'] 29 | gen_max_events = inf_conf['gen_max_events'] 30 | gen_max_bars = inf_conf['gen_max_bars'] 31 | sampling_temp = inf_conf['sampling']['temp'] 32 | sampling_top_p = inf_conf['sampling']['top_p'] 33 | 34 | model = load_model(train_conf['model'], gpuid, REMI_MODEL_VOCAB_SIZE) 35 | 36 | pretrained_dict = torch.load(ckpt_path) 37 | pretrained_dict = { 38 | k:v for k, v in pretrained_dict.items() if 'feature_map.omega' not in k 39 | } 40 | model_state_dict = model.state_dict() 41 | model_state_dict.update(pretrained_dict) 42 | model.load_state_dict(model_state_dict) 43 | model.eval() 44 | print ('[info] trained model weights loaded') 45 | 46 | if not os.path.exists(gen_output_dir): 47 | os.makedirs(gen_output_dir) 48 | 49 | all_ents = np.zeros((1, gen_max_events)) 50 | with torch.no_grad(): 51 | for p in range(n_pieces): 52 | print ('model:', type(model), '{}'.format(model.spe_type if hasattr(model, 'spe_type') else '')) 53 | print ('piece:', p+1) 54 | out_file = os.path.join(gen_output_dir, 55 | 'samp{:02d}'.format(p + 1) 56 | ) 57 | 58 | song, entropies = generate_fast(model, event2idx, idx2event, 59 | max_bars=gen_max_bars, max_events=gen_max_events, skip_check=False, 60 | temp=sampling_temp, top_p=sampling_top_p 61 | ) 62 | 63 | song = word2event(song, 
idx2event) 64 | print (*song, sep='\n', file=open(out_file + '.txt', 'w')) 65 | event_to_midi(song, out_file + '.mid') 66 | 67 | # [optional] Save per-token entropies during generation 68 | # print (entropies.shape) 69 | # all_ents = np.concatenate( 70 | # (all_ents, np.expand_dims(entropies, axis=0)), 71 | # axis=0 72 | # ) 73 | # print (all_ents.shape) 74 | # np.save(os.path.join(gen_output_dir, 'entropies'), all_ents[1:, :]) -------------------------------------------------------------------------------- /experiments/groove/src/spe_music/style_eval/note_features.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable 2 | import muspy 3 | import numpy as np 4 | 5 | _EPSILON = 1e-9 6 | 7 | 8 | def extract_features(data: Iterable[muspy.Music], features): 9 | """Extract a set of features from the given note sequences. 10 | 11 | Args: 12 | data: an iterable of MusPy `Music` objects. 13 | features: a dictionary with feature objects as values. 14 | 15 | Returns: 16 | A dictionary mapping keys from `features` to lists of feature values. 17 | """ 18 | results = {key: [] for key in features} 19 | for music in data: 20 | for key, feature in features.items(): 21 | results[key].extend(list(feature.extract(music))) 22 | 23 | assert len(set(len(x) for x in results.values())) <= 1 24 | 25 | return results 26 | 27 | 28 | class Pitch: 29 | """The MIDI pitch of the note.""" 30 | 31 | def extract(self, music: muspy.Music): 32 | for track in music.tracks: 33 | for note in track.notes: 34 | yield note.pitch 35 | 36 | def get_bins(self, min_value=0, max_value=127): 37 | return np.arange(min_value, max_value + 1) - 0.5 38 | 39 | 40 | class Duration: 41 | """The duration of the note. 42 | 43 | It is assumed that the tempo is normalized (typically to 60 BPM) so that the duration is 44 | expressed in beats. 45 | """ 46 | 47 | def extract(self, music: muspy.Music): 48 | for track in music.tracks: 49 | for note in track.notes: 50 | yield note.duration / music.resolution 51 | 52 | def get_bins(self, bin_size=1/6, max_value=2): 53 | return np.arange(0., max_value + bin_size - _EPSILON, bin_size) 54 | 55 | 56 | class Velocity: 57 | """The MIDI velocity of the note.""" 58 | 59 | def extract(self, music: muspy.Music): 60 | for track in music.tracks: 61 | for note in track.notes: 62 | yield note.velocity 63 | 64 | def get_bins(self, num_bins=8): 65 | return np.arange(0, 127, 128 / num_bins) - 0.5 66 | 67 | 68 | class OnsetPositionInBar: 69 | """The time of the note onset expressed in beats from the most recent downbeat.""" 70 | 71 | def extract(self, music: muspy.Music): 72 | if music.time_signatures: 73 | if (len(set((ts.numerator, ts.denominator) for ts in music.time_signatures)) > 1 74 | or music.time_signatures[0].time > 0): 75 | raise NotImplementedError('Only a single time signature starting at time 0 is supported') 76 | bar_duration = music.time_signatures[0].numerator * music.resolution 77 | # TODO: Make sure this is correct and handle compound time signatures. 
78 | else: 79 | # Assume 4/4 80 | bar_duration = 4 * music.resolution 81 | 82 | for track in music.tracks: 83 | for note in track.notes: 84 | yield (note.start % bar_duration) / music.resolution # in beats, matching get_bins() below 85 | 86 | def get_bins(self, bin_size=1/6, max_beats=4): 87 | return np.arange(0., max_beats + bin_size - _EPSILON, bin_size) 88 | -------------------------------------------------------------------------------- /experiments/lra/fast_attention/README.md: -------------------------------------------------------------------------------- 1 | # Performer's Fast Self Attention Module. 2 | 3 | Copied from https://github.com/google-research/google-research/tree/2260bcc3f9946ae8f07e39bc0ab4a98f2acfacd4/performer. 4 | 5 | 6 | See ["Rethinking Attention with Performers"](https://arxiv.org/abs/2009.14794) for the paper associated with this library, as well as the corresponding [Google AI Blog post](https://ai.googleblog.com/2020/10/rethinking-attention-with-performers.html). 7 | 8 | There are two main attention variants, constructed using Fast Attention Via positive Orthogonal Random features (FAVOR+): 9 | 10 | * `make_fast_softmax_attention` - An unbiased and tight approximation of regular softmax attention. Can be used in Transformer models, as well as standalone for applications that need raw softmax attention. 11 | * `make_fast_generalized_attention` - Allows for generalized attention functions producing different attention kernels, as described in the paper. 12 | 13 | Both functions create an `attention_fn` with the same API as `flax.nn.attention.dot_product_attention`, so it can serve as a drop-in replacement in a Transformer built on top of `flax.nn.attention` modules (see the sketch after the notes below). 14 | 15 | The default hyperparameters have worked well across a variety of tasks, such as protein modelling, image generation, and natural language processing. 16 | 17 | The protein language modelling code can be found in [/google-research/protein_lm/](https://github.com/google-research/google-research/tree/master/protein_lm). To replace regular attention with fast attention there, set via gin: `FlaxModel.attention_fn = @make_fast_softmax_attention()` or `FlaxModel.attention_fn = @make_fast_generalized_attention()`. 18 | 19 | ## Notes: 20 | 21 | * Set `lax_scan_unroll=16` for both attention functions when running on a GPU; loop unrolling gives roughly 4x speedups in the unidirectional (causal) case. Keep `lax_scan_unroll=1` (the default) when using a TPU. 22 | * The unidirectional variant uses custom gradients via JAX in order to provide significant memory reductions. 23 | * FAVOR has also been integrated into the [Reformer library](https://github.com/google/trax/blob/master/trax/layers/research/sparsity.py) as `CausalFavor`, providing additional memory gains via reversible layers. 
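The following is a minimal sketch of how such an `attention_fn` can be built, mirroring the way the LRA configs in this repository call the library (the import path and the hyperparameter values are taken from those configs and are illustrative, not canonical):

```python
# Sketch: construct a FAVOR+ generalized attention function, as done in the
# LRA configs in this repository (values shown are illustrative).
import jax
from fast_self_attention import fast_self_attention as favor

# ReLU-kernel generalized attention; unroll the causal scan for GPU speed
# (keep lax_scan_unroll=1, the default, on TPU).
attention_fn = favor.make_fast_generalized_attention(
    qkv_dim=64,                     # per-head dimension of the queries/keys it will see
    features_type='deterministic',  # apply kernel_fn directly, no random features
    kernel_fn=jax.nn.relu,
    lax_scan_unroll=16)

# attention_fn now has the same signature as
# flax.nn.attention.dot_product_attention and can be passed to any module
# that accepts such a function.
```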
24 | 25 | If you found this codebase useful, please consider citing the paper: 26 | 27 | ``` 28 | @article{performer, 29 | author = {Krzysztof Choromanski and 30 | Valerii Likhosherstov and 31 | David Dohan and 32 | Xingyou Song and 33 | Andreea Gane and 34 | Tam{\'{a}}s Sarl{\'{o}}s and 35 | Peter Hawkins and 36 | Jared Davis and 37 | Afroz Mohiuddin and 38 | Lukasz Kaiser and 39 | David Belanger and 40 | Lucy Colwell and 41 | Adrian Weller}, 42 | title = {Rethinking Attention with Performers}, 43 | journal = {CoRR}, 44 | volume = {abs/2009.14794}, 45 | year = {2020}, 46 | url = {https://arxiv.org/abs/2009.14794}, 47 | archivePrefix = {arXiv}, 48 | eprint = {2009.14794} 49 | } 50 | ``` 51 | 52 | --------------------------------------------------------------------------------
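As a usage sketch for the `style_eval` feature extractors defined in `note_features.py` above, the snippet below builds per-feature histograms from a few MIDI files (the import path follows the package layout in this repository; the input file names and the normalization step are illustrative assumptions):

```python
# Sketch: per-feature note histograms for style evaluation, using the
# extractors from spe_music.style_eval.note_features. Input files and the
# normalization step are illustrative assumptions.
import muspy
import numpy as np

from spe_music.style_eval.note_features import (
    Duration, OnsetPositionInBar, Pitch, extract_features)

musics = [muspy.read_midi(path) for path in ['sample_a.mid', 'sample_b.mid']]
features = {
    'pitch': Pitch(),
    'duration': Duration(),
    'onset': OnsetPositionInBar(),
}

values = extract_features(musics, features)  # key -> list of per-note values

histograms = {}
for key, feature in features.items():
    counts, _ = np.histogram(values[key], bins=feature.get_bins())
    histograms[key] = counts / max(counts.sum(), 1)  # normalize to a distribution
```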