├── .env.local ├── .gitignore ├── README.md ├── data ├── README.md ├── make_folder.py ├── multiprocessing │ ├── fineweb.py │ └── main.py ├── ray_distributed │ ├── main.py │ └── test_cluster.py ├── requirements.txt └── scripts │ ├── batch_test.sh │ └── setup_tpu.sh ├── dataset.py ├── debug_tpu.sh ├── launcher.sh ├── main.py ├── model.py ├── public ├── banner-light.png ├── banner.png ├── experts.png ├── loss-load.png └── loss-val.png ├── run.sh ├── setupTpu.sh └── utils.py /.env.local: -------------------------------------------------------------------------------- 1 | WANDB_KEY=your_wandb_key -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divyamakkar0/JAXformer/HEAD/.gitignore -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divyamakkar0/JAXformer/HEAD/README.md -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divyamakkar0/JAXformer/HEAD/data/README.md -------------------------------------------------------------------------------- /data/make_folder.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divyamakkar0/JAXformer/HEAD/data/make_folder.py -------------------------------------------------------------------------------- /data/multiprocessing/fineweb.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divyamakkar0/JAXformer/HEAD/data/multiprocessing/fineweb.py -------------------------------------------------------------------------------- /data/multiprocessing/main.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/ray_distributed/main.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divyamakkar0/JAXformer/HEAD/data/ray_distributed/main.py -------------------------------------------------------------------------------- /data/ray_distributed/test_cluster.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divyamakkar0/JAXformer/HEAD/data/ray_distributed/test_cluster.py -------------------------------------------------------------------------------- /data/requirements.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divyamakkar0/JAXformer/HEAD/data/requirements.txt -------------------------------------------------------------------------------- /data/scripts/batch_test.sh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divyamakkar0/JAXformer/HEAD/data/scripts/batch_test.sh -------------------------------------------------------------------------------- /data/scripts/setup_tpu.sh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divyamakkar0/JAXformer/HEAD/data/scripts/setup_tpu.sh -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divyamakkar0/JAXformer/HEAD/dataset.py -------------------------------------------------------------------------------- /debug_tpu.sh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divyamakkar0/JAXformer/HEAD/debug_tpu.sh -------------------------------------------------------------------------------- /launcher.sh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divyamakkar0/JAXformer/HEAD/launcher.sh -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divyamakkar0/JAXformer/HEAD/main.py -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divyamakkar0/JAXformer/HEAD/model.py -------------------------------------------------------------------------------- /public/banner-light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divyamakkar0/JAXformer/HEAD/public/banner-light.png -------------------------------------------------------------------------------- /public/banner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divyamakkar0/JAXformer/HEAD/public/banner.png -------------------------------------------------------------------------------- /public/experts.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divyamakkar0/JAXformer/HEAD/public/experts.png -------------------------------------------------------------------------------- /public/loss-load.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divyamakkar0/JAXformer/HEAD/public/loss-load.png -------------------------------------------------------------------------------- /public/loss-val.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divyamakkar0/JAXformer/HEAD/public/loss-val.png -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divyamakkar0/JAXformer/HEAD/run.sh -------------------------------------------------------------------------------- /setupTpu.sh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divyamakkar0/JAXformer/HEAD/setupTpu.sh -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divyamakkar0/JAXformer/HEAD/utils.py --------------------------------------------------------------------------------