├── .gitignore ├── LICENSE.md ├── Makefile ├── README.md ├── clip_jax ├── __init__.py ├── data.py ├── maxtext │ ├── README.md │ ├── __init__.py │ ├── accelerator_to_spec_map.py │ ├── checkpointing.py │ ├── common_types.py │ ├── inference_utils.py │ ├── input_pipeline │ │ ├── _grain_data_processing.py │ │ ├── _grain_operations.py │ │ ├── _grain_tokenizer.py │ │ ├── _hf_data_processing.py │ │ ├── _input_pipeline_utils.py │ │ ├── _tfds_data_processing.py │ │ ├── _tfds_data_processing_c4_mlperf.py │ │ └── input_pipeline_interface.py │ ├── layers │ │ ├── attentions.py │ │ ├── embeddings.py │ │ ├── gemma.py │ │ ├── gpt3.py │ │ ├── initializers.py │ │ ├── linears.py │ │ ├── llama2.py │ │ ├── mistral.py │ │ ├── models.py │ │ ├── normalizations.py │ │ ├── pipeline.py │ │ ├── quantizations.py │ │ └── simple_layer.py │ ├── max_logging.py │ ├── max_utils.py │ ├── multihost_dataloading.py │ ├── pyconfig.py │ ├── sequence_packing.py │ └── tokenizer.py ├── maxtext_model.ipynb ├── modeling.py ├── partitions.py ├── tokenizer.py ├── utils.py └── wandb_utils.py ├── configs ├── large-patch16-cappa.json ├── large-patch16-clip.json ├── mini-patch16-cappa.json └── mini-vision-maxtext.json ├── pyproject.toml ├── training ├── adafactor.py ├── kron.py ├── muon.py ├── precondition_local │ └── distributed_shampoo.py ├── quad.py ├── quad_pipeline_simple.py ├── run.sh └── train.py └── utils ├── convert_cappa_to_clip.ipynb ├── demo.ipynb ├── demo_cappa.ipynb ├── run_benchmark.py ├── run_benchmark_cappa.py └── run_benchmark_sugarcrepe.ipynb /.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/.gitignore -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/LICENSE.md -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/Makefile -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/README.md -------------------------------------------------------------------------------- /clip_jax/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/clip_jax/__init__.py -------------------------------------------------------------------------------- /clip_jax/data.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/clip_jax/data.py -------------------------------------------------------------------------------- /clip_jax/maxtext/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/clip_jax/maxtext/README.md -------------------------------------------------------------------------------- /clip_jax/maxtext/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/clip_jax/maxtext/__init__.py -------------------------------------------------------------------------------- /clip_jax/maxtext/accelerator_to_spec_map.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/clip_jax/maxtext/accelerator_to_spec_map.py -------------------------------------------------------------------------------- /clip_jax/maxtext/checkpointing.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/clip_jax/maxtext/checkpointing.py -------------------------------------------------------------------------------- /clip_jax/maxtext/common_types.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/clip_jax/maxtext/common_types.py -------------------------------------------------------------------------------- /clip_jax/maxtext/inference_utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/clip_jax/maxtext/inference_utils.py -------------------------------------------------------------------------------- /clip_jax/maxtext/input_pipeline/_grain_data_processing.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/clip_jax/maxtext/input_pipeline/_grain_data_processing.py -------------------------------------------------------------------------------- /clip_jax/maxtext/input_pipeline/_grain_operations.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/clip_jax/maxtext/input_pipeline/_grain_operations.py -------------------------------------------------------------------------------- /clip_jax/maxtext/input_pipeline/_grain_tokenizer.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/clip_jax/maxtext/input_pipeline/_grain_tokenizer.py -------------------------------------------------------------------------------- /clip_jax/maxtext/input_pipeline/_hf_data_processing.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/clip_jax/maxtext/input_pipeline/_hf_data_processing.py -------------------------------------------------------------------------------- /clip_jax/maxtext/input_pipeline/_input_pipeline_utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/clip_jax/maxtext/input_pipeline/_input_pipeline_utils.py -------------------------------------------------------------------------------- /clip_jax/maxtext/input_pipeline/_tfds_data_processing.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/clip_jax/maxtext/input_pipeline/_tfds_data_processing.py -------------------------------------------------------------------------------- /clip_jax/maxtext/input_pipeline/_tfds_data_processing_c4_mlperf.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/clip_jax/maxtext/input_pipeline/_tfds_data_processing_c4_mlperf.py -------------------------------------------------------------------------------- /clip_jax/maxtext/input_pipeline/input_pipeline_interface.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/clip_jax/maxtext/input_pipeline/input_pipeline_interface.py -------------------------------------------------------------------------------- /clip_jax/maxtext/layers/attentions.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/clip_jax/maxtext/layers/attentions.py -------------------------------------------------------------------------------- /clip_jax/maxtext/layers/embeddings.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/clip_jax/maxtext/layers/embeddings.py -------------------------------------------------------------------------------- /clip_jax/maxtext/layers/gemma.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/clip_jax/maxtext/layers/gemma.py -------------------------------------------------------------------------------- /clip_jax/maxtext/layers/gpt3.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/clip_jax/maxtext/layers/gpt3.py -------------------------------------------------------------------------------- /clip_jax/maxtext/layers/initializers.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/clip_jax/maxtext/layers/initializers.py -------------------------------------------------------------------------------- /clip_jax/maxtext/layers/linears.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/clip_jax/maxtext/layers/linears.py -------------------------------------------------------------------------------- /clip_jax/maxtext/layers/llama2.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/clip_jax/maxtext/layers/llama2.py -------------------------------------------------------------------------------- /clip_jax/maxtext/layers/mistral.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/clip_jax/maxtext/layers/mistral.py -------------------------------------------------------------------------------- /clip_jax/maxtext/layers/models.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/clip_jax/maxtext/layers/models.py -------------------------------------------------------------------------------- /clip_jax/maxtext/layers/normalizations.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/clip_jax/maxtext/layers/normalizations.py -------------------------------------------------------------------------------- /clip_jax/maxtext/layers/pipeline.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/clip_jax/maxtext/layers/pipeline.py -------------------------------------------------------------------------------- /clip_jax/maxtext/layers/quantizations.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/clip_jax/maxtext/layers/quantizations.py -------------------------------------------------------------------------------- /clip_jax/maxtext/layers/simple_layer.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/clip_jax/maxtext/layers/simple_layer.py -------------------------------------------------------------------------------- /clip_jax/maxtext/max_logging.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/clip_jax/maxtext/max_logging.py -------------------------------------------------------------------------------- /clip_jax/maxtext/max_utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/clip_jax/maxtext/max_utils.py -------------------------------------------------------------------------------- /clip_jax/maxtext/multihost_dataloading.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/clip_jax/maxtext/multihost_dataloading.py -------------------------------------------------------------------------------- /clip_jax/maxtext/pyconfig.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/clip_jax/maxtext/pyconfig.py -------------------------------------------------------------------------------- /clip_jax/maxtext/sequence_packing.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/clip_jax/maxtext/sequence_packing.py -------------------------------------------------------------------------------- /clip_jax/maxtext/tokenizer.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/clip_jax/maxtext/tokenizer.py -------------------------------------------------------------------------------- /clip_jax/maxtext_model.ipynb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/clip_jax/maxtext_model.ipynb -------------------------------------------------------------------------------- /clip_jax/modeling.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/clip_jax/modeling.py -------------------------------------------------------------------------------- /clip_jax/partitions.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/clip_jax/partitions.py -------------------------------------------------------------------------------- /clip_jax/tokenizer.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/clip_jax/tokenizer.py -------------------------------------------------------------------------------- /clip_jax/utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/clip_jax/utils.py -------------------------------------------------------------------------------- /clip_jax/wandb_utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/clip_jax/wandb_utils.py -------------------------------------------------------------------------------- /configs/large-patch16-cappa.json: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/configs/large-patch16-cappa.json -------------------------------------------------------------------------------- /configs/large-patch16-clip.json: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/configs/large-patch16-clip.json -------------------------------------------------------------------------------- /configs/mini-patch16-cappa.json: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/configs/mini-patch16-cappa.json -------------------------------------------------------------------------------- /configs/mini-vision-maxtext.json: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/configs/mini-vision-maxtext.json -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/pyproject.toml -------------------------------------------------------------------------------- /training/adafactor.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/training/adafactor.py -------------------------------------------------------------------------------- /training/kron.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/training/kron.py -------------------------------------------------------------------------------- /training/muon.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/training/muon.py -------------------------------------------------------------------------------- /training/precondition_local/distributed_shampoo.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/training/precondition_local/distributed_shampoo.py -------------------------------------------------------------------------------- /training/quad.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/training/quad.py -------------------------------------------------------------------------------- /training/quad_pipeline_simple.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/training/quad_pipeline_simple.py -------------------------------------------------------------------------------- /training/run.sh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/training/run.sh -------------------------------------------------------------------------------- /training/train.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/training/train.py -------------------------------------------------------------------------------- /utils/convert_cappa_to_clip.ipynb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/utils/convert_cappa_to_clip.ipynb -------------------------------------------------------------------------------- /utils/demo.ipynb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/utils/demo.ipynb -------------------------------------------------------------------------------- /utils/demo_cappa.ipynb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/utils/demo_cappa.ipynb -------------------------------------------------------------------------------- /utils/run_benchmark.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/utils/run_benchmark.py -------------------------------------------------------------------------------- /utils/run_benchmark_cappa.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/utils/run_benchmark_cappa.py -------------------------------------------------------------------------------- /utils/run_benchmark_sugarcrepe.ipynb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borisdayma/clip-jax/HEAD/utils/run_benchmark_sugarcrepe.ipynb --------------------------------------------------------------------------------