├── LICENSE
├── README.md
├── assets
│   └── gpt2_124M_loss.png
├── bench.py
├── config
│   ├── gpt2.yaml
│   └── test.yaml
├── data
│   └── openwebtext
│       └── prepare.py
├── dataset.py
├── model.py
├── requirements_tpu.txt
├── test.py
└── train.py

/LICENSE:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jenkspt/gpt-jax/HEAD/LICENSE
--------------------------------------------------------------------------------

/README.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jenkspt/gpt-jax/HEAD/README.md
--------------------------------------------------------------------------------

/assets/gpt2_124M_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jenkspt/gpt-jax/HEAD/assets/gpt2_124M_loss.png
--------------------------------------------------------------------------------

/bench.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jenkspt/gpt-jax/HEAD/bench.py
--------------------------------------------------------------------------------

/config/gpt2.yaml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jenkspt/gpt-jax/HEAD/config/gpt2.yaml
--------------------------------------------------------------------------------

/config/test.yaml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jenkspt/gpt-jax/HEAD/config/test.yaml
--------------------------------------------------------------------------------

/data/openwebtext/prepare.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jenkspt/gpt-jax/HEAD/data/openwebtext/prepare.py
--------------------------------------------------------------------------------

/dataset.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jenkspt/gpt-jax/HEAD/dataset.py
--------------------------------------------------------------------------------

/model.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jenkspt/gpt-jax/HEAD/model.py
--------------------------------------------------------------------------------

/requirements_tpu.txt:
--------------------------------------------------------------------------------
tyro
wandb
jax[tpu]
flax
--------------------------------------------------------------------------------

/test.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jenkspt/gpt-jax/HEAD/test.py
--------------------------------------------------------------------------------

/train.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jenkspt/gpt-jax/HEAD/train.py
--------------------------------------------------------------------------------