├── .readthedocs.yml ├── docs ├── .gitignore ├── flax.png ├── guides │ ├── quantization │ │ └── index.rst │ ├── parallel_training │ │ └── index.rst │ ├── data_preprocessing │ │ └── index.rst │ ├── model_inspection │ │ └── index.rst │ ├── training_techniques │ │ └── index.rst │ ├── flax_fundamentals │ │ └── index.rst │ ├── converting_and_upgrading │ │ └── index.rst │ └── index.rst ├── api_reference │ ├── flax.linen │ │ ├── inspection.rst │ │ ├── variable.rst │ │ ├── decorators.rst │ │ ├── init_apply.rst │ │ ├── profiling.rst │ │ ├── index.rst │ │ ├── transformations.rst │ │ ├── module.rst │ │ ├── spmd.rst │ │ ├── initializers.rst │ │ └── activation_functions.rst │ ├── flax.errors.rst │ ├── flax.struct.rst │ ├── index.rst │ ├── flax.traceback_util.rst │ ├── flax.core.frozen_dict.rst │ ├── flax.jax_utils.rst │ ├── flax.serialization.rst │ ├── flax.training.rst │ └── flax.cursor.rst ├── robots.txt ├── developer_notes │ └── index.rst ├── examples │ ├── index.rst │ └── repositories_that_use_flax.rst ├── _static │ └── css │ │ └── flax_theme.css ├── Makefile ├── flip │ ├── 0000-template.md │ └── README.md ├── .readthedocs.yaml └── _templates │ └── autosummary │ └── flax_module.rst ├── docs_nnx ├── .gitignore ├── flax.png ├── guides │ ├── images │ │ ├── performance-graph.png │ │ └── stateful-transforms.png │ ├── index.rst │ └── blog.md ├── examples │ ├── index.rst │ └── core_examples.rst ├── api_reference │ ├── flax.nnx │ │ ├── summary.rst │ │ ├── visualization.rst │ │ ├── nn │ │ │ ├── stochastic.rst │ │ │ ├── dtypes.rst │ │ │ ├── lora.rst │ │ │ ├── attention.rst │ │ │ ├── index.rst │ │ │ ├── linear.rst │ │ │ ├── normalization.rst │ │ │ ├── recurrent.rst │ │ │ ├── initializers.rst │ │ │ └── activations.rst │ │ ├── module.rst │ │ ├── training │ │ │ ├── optimizer.rst │ │ │ ├── index.rst │ │ │ └── metrics.rst │ │ ├── spmd.rst │ │ ├── helpers.rst │ │ ├── rnglib.rst │ │ ├── filterlib.rst │ │ ├── bridge.rst │ │ ├── index.rst │ │ ├── object.rst │ │ ├── transforms.rst │ │ ├── state.rst │ │ ├── variables.rst │ │ └── graph.rst │ ├── index.rst │ ├── flax.config.rst │ ├── flax.struct.rst │ ├── flax.training.rst │ ├── flax.traverse_util.rst │ └── flax.core.frozen_dict.rst ├── robots.txt ├── migrating │ └── index.rst ├── guides_advanced.rst ├── guides_basic.rst ├── _static │ └── css │ │ └── flax_theme.css ├── .readthedocs.yaml ├── Makefile ├── flip │ ├── 0000-template.md │ └── README.md ├── _templates │ └── autosummary │ │ └── flax_module.rst └── hijax │ └── index.rst ├── examples ├── vae │ ├── results │ │ └── .gitignore │ ├── sample.png │ ├── reconstruction.png │ ├── requirements.txt │ ├── configs │ │ └── default.py │ ├── README.md │ ├── input_pipeline.py │ ├── models.py │ └── main.py ├── nnx_toy_examples │ ├── requirements.txt │ ├── 06_scan_over_layers.py │ ├── 09_parameter_surgery.py │ └── 08_save_load_checkpoints.py ├── nlp_seq │ ├── requirements.txt │ ├── README.md │ ├── configs │ │ └── default.py │ └── main.py ├── seq2seq │ ├── requirements.txt │ ├── README.md │ ├── configs │ │ └── default.py │ └── main.py ├── sst2 │ ├── requirements.txt │ ├── configs │ │ └── default.py │ ├── README.md │ ├── build_vocabulary.py │ ├── train_test.py │ └── main.py ├── gemma │ └── requirements.txt ├── ppo │ ├── requirements.txt │ ├── ppo_main.py │ ├── configs │ │ └── default.py │ └── test_episodes.py ├── mnist │ ├── requirements.txt │ ├── configs │ │ └── default.py │ ├── README.md │ └── main.py ├── imagenet │ ├── requirements.txt │ ├── configs │ │ ├── v100_x8.py │ │ ├── tpu.py │ │ ├── v100_x8_mixed_precision.py │ │ ├── 
fake_data_benchmark.py │ │ └── default.py │ ├── models_test.py │ └── main.py ├── ogbg_molpcba │ ├── requirements.txt │ └── configs │ │ ├── test.py │ │ ├── default.py │ │ ├── default_graph_net.py │ │ └── hparam_sweep.py ├── lm1b │ ├── requirements.txt │ ├── temperature_sampler_test.py │ └── train_test.py ├── wmt │ └── requirements.txt ├── __init__.py ├── README.md ├── linen_design_test │ ├── dense.py │ ├── linear_regression.py │ ├── mlp_inline.py │ └── mlp_lazy.py └── cloud │ └── startup_script.sh ├── flax ├── nnx │ ├── scripts │ │ ├── requirements.txt │ │ └── run-all-examples.bash │ ├── nn │ │ └── __init__.py │ ├── transforms │ │ └── __init__.py │ ├── training │ │ └── __init__.py │ ├── bridge │ │ └── __init__.py │ └── ids.py ├── py.typed ├── oss │ └── .git-blame-ignore-revs ├── metrics │ └── __init__.py ├── experimental │ ├── __init__.py │ └── nnx.py ├── training │ └── __init__.py ├── version.py ├── testing │ └── __init__.py ├── __init__.py ├── core │ ├── tracers.py │ ├── __init__.py │ ├── nn │ │ ├── stochastic.py │ │ └── __init__.py │ └── variables.py ├── ids.py ├── linen │ └── README.md └── traceback_util.py ├── .git-blame-ignore-revs ├── .github ├── analytics │ ├── requirements.txt │ ├── README.md │ ├── issue_activity_since_date.gql │ └── pr_data_query.gql ├── ISSUE_TEMPLATE │ └── bug_report.md ├── pull_request_template.md └── workflows │ └── jax_nightly.yml ├── images ├── flax_logo.png ├── flax_logo_250px.png └── flax_logo_500px.png ├── benchmarks ├── tracing │ ├── requirements.txt │ ├── README.md │ ├── __init__.py │ ├── mnist.py │ ├── vae.py │ └── sst2.py └── README.md ├── contributing.md ├── AUTHORS ├── flaxlib_src ├── Cargo.toml ├── pyproject.toml ├── README.md ├── src │ └── flaxlib │ │ ├── __init__.py │ │ └── flaxlib_cpp.pyi └── .gitignore ├── .gitignore ├── tests ├── nnx │ ├── __init__.py │ ├── ids_test.py │ ├── filters_test.py │ └── containers_test.py ├── flaxlib_test.py ├── download_dataset_metadata.sh ├── pickle_test.py ├── linen │ ├── linen_activation_test.py │ ├── linen_dtypes_test.py │ └── initializers_test.py ├── import_test.ipynb ├── colab_tpu_jax_version.ipynb └── configurations_test.py └── .pre-commit-config.yaml /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | # deprecated -------------------------------------------------------------------------------- /docs/.gitignore: -------------------------------------------------------------------------------- 1 | _formatted_howtos 2 | -------------------------------------------------------------------------------- /docs_nnx/.gitignore: -------------------------------------------------------------------------------- 1 | _formatted_howtos 2 | -------------------------------------------------------------------------------- /examples/vae/results/.gitignore: -------------------------------------------------------------------------------- 1 | *.png 2 | -------------------------------------------------------------------------------- /flax/nnx/scripts/requirements.txt: -------------------------------------------------------------------------------- 1 | datasets>=2.12.0 2 | -------------------------------------------------------------------------------- /docs/flax.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/flax/HEAD/docs/flax.png -------------------------------------------------------------------------------- /flax/py.typed: 
-------------------------------------------------------------------------------- 1 | # Marker file for PEP 561. The package uses inline types. 2 | -------------------------------------------------------------------------------- /.git-blame-ignore-revs: -------------------------------------------------------------------------------- 1 | # apply pyink 2 | 40a6e074e5224d733f964be00e21e0a1cb98bd2e -------------------------------------------------------------------------------- /.github/analytics/requirements.txt: -------------------------------------------------------------------------------- 1 | pandas 2 | absl-py 3 | requests 4 | matplotlib -------------------------------------------------------------------------------- /docs_nnx/flax.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/flax/HEAD/docs_nnx/flax.png -------------------------------------------------------------------------------- /examples/nnx_toy_examples/requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib>=3.7.1 2 | datasets>=2.12.0 -------------------------------------------------------------------------------- /images/flax_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/flax/HEAD/images/flax_logo.png -------------------------------------------------------------------------------- /examples/vae/sample.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/flax/HEAD/examples/vae/sample.png -------------------------------------------------------------------------------- /images/flax_logo_250px.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/flax/HEAD/images/flax_logo_250px.png -------------------------------------------------------------------------------- /images/flax_logo_500px.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/flax/HEAD/images/flax_logo_500px.png -------------------------------------------------------------------------------- /examples/nlp_seq/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.0.0 2 | flax==0.3.6 3 | numpy==1.22.0 4 | tensorflow==2.11.1 5 | -------------------------------------------------------------------------------- /examples/vae/reconstruction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/flax/HEAD/examples/vae/reconstruction.png -------------------------------------------------------------------------------- /examples/seq2seq/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.0.0 2 | clu==0.0.6 3 | flax==0.3.6 4 | numpy==1.22.0 5 | optax==0.1.0 6 | -------------------------------------------------------------------------------- /benchmarks/tracing/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py 2 | flax 3 | google-benchmark 4 | jax 5 | ml_collections 6 | numpy 7 | optax -------------------------------------------------------------------------------- /docs/guides/quantization/index.rst: -------------------------------------------------------------------------------- 1 | Quantization 
2 | ============ 3 | 4 | .. toctree:: 5 | :maxdepth: 1 6 | 7 | fp8_basics 8 | -------------------------------------------------------------------------------- /contributing.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | Please see https://flax.readthedocs.io/en/latest/contributing.html for more information. 4 | -------------------------------------------------------------------------------- /docs_nnx/guides/images/performance-graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/flax/HEAD/docs_nnx/guides/images/performance-graph.png -------------------------------------------------------------------------------- /docs_nnx/guides/index.rst: -------------------------------------------------------------------------------- 1 | Guides 2 | ====== 3 | 4 | .. toctree:: 5 | :maxdepth: 2 6 | 7 | index_basic 8 | index_advanced 9 | -------------------------------------------------------------------------------- /docs_nnx/guides/images/stateful-transforms.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/flax/HEAD/docs_nnx/guides/images/stateful-transforms.png -------------------------------------------------------------------------------- /docs_nnx/examples/index.rst: -------------------------------------------------------------------------------- 1 | Examples 2 | ======== 3 | 4 | .. toctree:: 5 | :maxdepth: 2 6 | 7 | gemma 8 | core_examples 9 | 10 | 11 | -------------------------------------------------------------------------------- /docs/api_reference/flax.linen/inspection.rst: -------------------------------------------------------------------------------- 1 | 2 | Inspection 3 | ---------------------- 4 | 5 | .. currentmodule:: flax.linen 6 | 7 | .. autofunction:: tabulate 8 | -------------------------------------------------------------------------------- /docs/guides/parallel_training/index.rst: -------------------------------------------------------------------------------- 1 | Parallel training 2 | ================= 3 | 4 | .. toctree:: 5 | :maxdepth: 1 6 | 7 | ensembling 8 | flax_on_pjit 9 | -------------------------------------------------------------------------------- /examples/vae/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.4.0 2 | flax==0.6.9 3 | numpy==1.23.5 4 | optax==0.1.5 5 | Pillow==10.2.0 6 | tensorflow==2.12.0 7 | tensorflow-datasets==4.9.2 -------------------------------------------------------------------------------- /docs/guides/data_preprocessing/index.rst: -------------------------------------------------------------------------------- 1 | Data preprocessing 2 | ================== 3 | 4 | .. toctree:: 5 | :maxdepth: 1 6 | 7 | full_eval 8 | loading_datasets 9 | -------------------------------------------------------------------------------- /docs/api_reference/flax.linen/variable.rst: -------------------------------------------------------------------------------- 1 | 2 | Variable dictionary 3 | ---------------------- 4 | 5 | .. automodule:: flax.core.variables 6 | .. autoclass:: flax.linen.Variable 7 | -------------------------------------------------------------------------------- /docs/guides/model_inspection/index.rst: -------------------------------------------------------------------------------- 1 | Model inspection 2 | ================ 3 | 4 | ..
toctree:: 5 | :maxdepth: 1 6 | 7 | model_surgery 8 | extracting_intermediates 9 | -------------------------------------------------------------------------------- /docs_nnx/api_reference/flax.nnx/summary.rst: -------------------------------------------------------------------------------- 1 | summary 2 | ------------------------ 3 | 4 | .. automodule:: flax.nnx 5 | .. currentmodule:: flax.nnx 6 | 7 | .. autofunction:: tabulate -------------------------------------------------------------------------------- /docs/api_reference/flax.linen/decorators.rst: -------------------------------------------------------------------------------- 1 | Decorators 2 | ---------------------- 3 | 4 | .. currentmodule:: flax.linen 5 | 6 | .. autofunction:: compact 7 | .. autofunction:: nowrap 8 | -------------------------------------------------------------------------------- /docs_nnx/api_reference/flax.nnx/visualization.rst: -------------------------------------------------------------------------------- 1 | visualization 2 | ------------------------ 3 | 4 | .. automodule:: flax.nnx 5 | .. currentmodule:: flax.nnx 6 | 7 | .. autofunction:: display -------------------------------------------------------------------------------- /docs_nnx/api_reference/flax.nnx/nn/stochastic.rst: -------------------------------------------------------------------------------- 1 | Stochastic 2 | ------------------------ 3 | 4 | .. automodule:: flax.nnx 5 | .. currentmodule:: flax.nnx 6 | 7 | .. autoclass:: Dropout 8 | :members: -------------------------------------------------------------------------------- /docs/robots.txt: -------------------------------------------------------------------------------- 1 | User-agent: * 2 | 3 | Disallow: /api_reference/flax.linen/_autosummary/ # for SEO, since Google still indexes this deprecated link 4 | 5 | Sitemap: https://flax.readthedocs.io/sitemap.xml 6 | -------------------------------------------------------------------------------- /docs_nnx/api_reference/index.rst: -------------------------------------------------------------------------------- 1 | API Reference 2 | ============= 3 | 4 | .. toctree:: 5 | :maxdepth: 4 6 | 7 | flax.nnx/index 8 | flax.core.frozen_dict 9 | flax.struct 10 | flax.training -------------------------------------------------------------------------------- /docs_nnx/robots.txt: -------------------------------------------------------------------------------- 1 | User-agent: * 2 | 3 | Disallow: /api_reference/flax.linen/_autosummary/ # for SEO, since Google still indexes this deprecated link 4 | 5 | Sitemap: https://flax.readthedocs.io/sitemap.xml 6 | -------------------------------------------------------------------------------- /docs/api_reference/flax.errors.rst: -------------------------------------------------------------------------------- 1 | 2 | flax.errors package 3 | =================== 4 | 5 | Flax has the following classes of errors. 6 | 7 | .. automodule:: flax.errors 8 | :members: 9 | :exclude-members: FlaxError -------------------------------------------------------------------------------- /docs_nnx/migrating/index.rst: -------------------------------------------------------------------------------- 1 | Migrating 2 | ------------------------ 3 | 4 | .. 
toctree:: 5 | :maxdepth: 2 6 | 7 | convert_pytorch_to_flax 8 | nnx_010_to_nnx_011 9 | linen_to_nnx 10 | haiku_to_flax 11 | -------------------------------------------------------------------------------- /docs/api_reference/flax.linen/init_apply.rst: -------------------------------------------------------------------------------- 1 | 2 | Init/Apply 3 | ============== 4 | 5 | .. currentmodule:: flax.linen 6 | 7 | .. autofunction:: apply 8 | .. autofunction:: init 9 | .. autofunction:: init_with_output 10 | -------------------------------------------------------------------------------- /docs/developer_notes/index.rst: -------------------------------------------------------------------------------- 1 | Developer notes 2 | =============== 3 | 4 | .. toctree:: 5 | :maxdepth: 1 6 | 7 | module_lifecycle 8 | lift 9 | FLIPs 10 | -------------------------------------------------------------------------------- /docs/examples/index.rst: -------------------------------------------------------------------------------- 1 | Examples 2 | ======== 3 | 4 | .. toctree:: 5 | :maxdepth: 2 6 | 7 | core_examples 8 | google_research_examples 9 | repositories_that_use_flax 10 | community_examples 11 | 12 | 13 | -------------------------------------------------------------------------------- /examples/sst2/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.0.0 2 | clu==0.0.6 3 | flax==0.3.6 4 | ml-collections==0.1.0 5 | numpy==1.22.0 6 | optax==0.1.0 7 | tensorflow==2.11.1 8 | tensorflow-datasets==4.4.0 9 | tensorflow-text==2.7.0 10 | -------------------------------------------------------------------------------- /docs_nnx/api_reference/flax.nnx/module.rst: -------------------------------------------------------------------------------- 1 | module 2 | ------------------------ 3 | 4 | .. automodule:: flax.nnx 5 | :members: iter_children, iter_modules 6 | .. currentmodule:: flax.nnx 7 | .. autoclass:: Module 8 | :members: 9 | -------------------------------------------------------------------------------- /docs/guides/training_techniques/index.rst: -------------------------------------------------------------------------------- 1 | Training techniques 2 | =================== 3 | 4 | .. toctree:: 5 | :maxdepth: 1 6 | 7 | batch_norm 8 | dropout 9 | lr_schedule 10 | transfer_learning 11 | use_checkpointing -------------------------------------------------------------------------------- /docs_nnx/api_reference/flax.config.rst: -------------------------------------------------------------------------------- 1 | 2 | flax.config package 3 | ==================== 4 | 5 | .. automodule:: flax.configurations 6 | :members: 7 | :undoc-members: 8 | :exclude-members: FlagHolder, bool_flag, static_bool_env 9 | -------------------------------------------------------------------------------- /docs_nnx/api_reference/flax.nnx/nn/dtypes.rst: -------------------------------------------------------------------------------- 1 | Dtypes 2 | ------------------------ 3 | 4 | .. automodule:: flax.nnx.nn.dtypes 5 | .. currentmodule:: flax.nnx.nn.dtypes 6 | 7 | .. autofunction:: canonicalize_dtype 8 | .. autofunction:: promote_dtype -------------------------------------------------------------------------------- /docs_nnx/guides_advanced.rst: -------------------------------------------------------------------------------- 1 | Advanced Guides 2 | =============== 3 | 4 | ..
toctree:: 5 | :maxdepth: 1 6 | :caption: Advanced 7 | 8 | guides/flax_gspmd 9 | guides/performance 10 | guides/bridge_guide 11 | guides/surgery 12 | -------------------------------------------------------------------------------- /docs/api_reference/flax.linen/profiling.rst: -------------------------------------------------------------------------------- 1 | Profiling 2 | ---------------------- 3 | 4 | .. currentmodule:: flax.linen 5 | 6 | .. autofunction:: enable_named_call 7 | .. autofunction:: disable_named_call 8 | .. autofunction:: override_named_call 9 | -------------------------------------------------------------------------------- /docs_nnx/api_reference/flax.nnx/training/optimizer.rst: -------------------------------------------------------------------------------- 1 | Optimizer 2 | ------------------------ 3 | 4 | .. automodule:: flax.nnx.optimizer 5 | .. currentmodule:: flax.nnx.optimizer 6 | 7 | .. autoclass:: Optimizer 8 | :members: __init__, update 9 | -------------------------------------------------------------------------------- /docs/api_reference/flax.struct.rst: -------------------------------------------------------------------------------- 1 | 2 | flax.struct package 3 | ===================== 4 | 5 | .. currentmodule:: flax.struct 6 | 7 | .. automodule:: flax.struct 8 | 9 | 10 | .. autofunction:: dataclass 11 | 12 | 13 | .. autoclass:: PyTreeNode -------------------------------------------------------------------------------- /docs_nnx/api_reference/flax.struct.rst: -------------------------------------------------------------------------------- 1 | 2 | flax.struct package 3 | ===================== 4 | 5 | .. currentmodule:: flax.struct 6 | 7 | .. automodule:: flax.struct 8 | 9 | 10 | .. autofunction:: dataclass 11 | 12 | 13 | .. autoclass:: PyTreeNode -------------------------------------------------------------------------------- /examples/gemma/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py~=2.2 2 | clu==0.0.12 3 | flax~=0.10 4 | jax~=0.6 5 | mlcroissant~=1.0 6 | numpy~=2.1 7 | optax~=0.2 8 | sentencepiece~=0.2 9 | jaxtyping~=0.3 10 | tensorflow~=2.19 11 | tensorflow-datasets~=4.9 12 | tensorflow-text~=2.19 -------------------------------------------------------------------------------- /examples/ppo/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.0.0 2 | atari-py==0.2.5 3 | opencv-python==4.5.4.60 4 | flax==0.3.6 5 | gym==0.18.3 6 | gymnasium[atari, accept-rom-license]==0.29.0 7 | ml-collections==0.1.0 8 | numpy==1.22.0 9 | optax==0.1.5 10 | tensorflow==2.11.1 11 | -------------------------------------------------------------------------------- /docs_nnx/api_reference/flax.nnx/spmd.rst: -------------------------------------------------------------------------------- 1 | spmd 2 | ------------------------ 3 | 4 | .. automodule:: flax.nnx 5 | .. currentmodule:: flax.nnx 6 | 7 | .. autofunction:: get_partition_spec 8 | .. autofunction:: get_named_sharding 9 | .. autofunction:: with_partitioning 10 | -------------------------------------------------------------------------------- /docs_nnx/api_reference/flax.training.rst: -------------------------------------------------------------------------------- 1 | 2 | flax.training package 3 | ===================== 4 | 5 | Train state 6 | ------------------------ 7 | 8 | .. currentmodule:: flax.training.train_state 9 | 10 | .. 
autoclass:: TrainState 11 | :members: apply_gradients, create 12 | 13 | -------------------------------------------------------------------------------- /docs_nnx/guides_basic.rst: -------------------------------------------------------------------------------- 1 | Basic Guides 2 | ============ 3 | 4 | .. toctree:: 5 | :maxdepth: 1 6 | :caption: Basic 7 | 8 | guides/pytree 9 | guides/transforms 10 | guides/filters_guide 11 | guides/randomness 12 | guides/checkpointing 13 | guides/jax_and_nnx_transforms -------------------------------------------------------------------------------- /docs_nnx/api_reference/flax.nnx/training/index.rst: -------------------------------------------------------------------------------- 1 | training 2 | ---------------------------- 3 | 4 | Experimental API. See the `NNX page <https://flax.readthedocs.io/en/latest/nnx/index.html>`__ for more details. 5 | 6 | .. toctree:: 7 | :maxdepth: 3 8 | 9 | metrics 10 | optimizer 11 | 12 | -------------------------------------------------------------------------------- /docs/guides/flax_fundamentals/index.rst: -------------------------------------------------------------------------------- 1 | Flax fundamentals 2 | ================= 3 | 4 | .. toctree:: 5 | :maxdepth: 1 6 | 7 | JAX 101 <https://jax.readthedocs.io/en/latest/jax-101/index.html> 8 | flax_basics 9 | state_params 10 | setup_or_nncompact 11 | arguments 12 | rng_guide 13 | -------------------------------------------------------------------------------- /AUTHORS: -------------------------------------------------------------------------------- 1 | # This is the list of Flax authors for copyright purposes. 2 | # 3 | # This does not necessarily list everyone who has contributed code, since in 4 | # some cases, their employer may be the copyright holder. To see the full list 5 | # of contributors, see the revision history in source control. 6 | Google LLC 7 | -------------------------------------------------------------------------------- /flaxlib_src/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "flaxlib" 3 | version = "0.0.1-a1" 4 | edition = "2021" 5 | 6 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 7 | [lib] 8 | name = "flaxlib" 9 | crate-type = ["cdylib"] 10 | 11 | [dependencies] 12 | pyo3 = "0.21.2" 13 | -------------------------------------------------------------------------------- /docs_nnx/api_reference/flax.nnx/nn/lora.rst: -------------------------------------------------------------------------------- 1 | LoRA 2 | ------------------------ 3 | 4 | NNX LoRA classes. 5 | 6 | .. automodule:: flax.nnx 7 | .. currentmodule:: flax.nnx 8 | 9 | .. flax_module:: 10 | :module: flax.nnx 11 | :class: LoRA 12 | 13 | .. flax_module:: 14 | :module: flax.nnx 15 | :class: LoRALinear 16 | -------------------------------------------------------------------------------- /flax/nnx/scripts/run-all-examples.bash: -------------------------------------------------------------------------------- 1 | set -e 2 | 3 | source .venv/bin/activate 4 | 5 | for f in $(find examples/nnx_toy_examples -name "*.py" -maxdepth 1); do 6 | echo -e "\n---------------------------------" 7 | echo "$f" 8 | echo "---------------------------------" 9 | MPLBACKEND=Agg python "$f" 10 | done 11 | -------------------------------------------------------------------------------- /docs_nnx/api_reference/flax.nnx/helpers.rst: -------------------------------------------------------------------------------- 1 | helpers 2 | ------------------------ 3 | 4 | .. automodule:: flax.nnx 5 | ..
currentmodule:: flax.nnx 6 | 7 | 8 | .. autoclass:: Sequential 9 | :members: 10 | .. autoclass:: List 11 | :members: 12 | .. autoclass:: Dict 13 | :members: 14 | .. autoclass:: TrainState 15 | :members: -------------------------------------------------------------------------------- /docs_nnx/api_reference/flax.nnx/rnglib.rst: -------------------------------------------------------------------------------- 1 | rnglib 2 | ------------------------ 3 | 4 | .. automodule:: flax.nnx 5 | .. currentmodule:: flax.nnx 6 | 7 | .. autoclass:: Rngs 8 | :members: __init__ 9 | .. autoclass:: RngStream 10 | :members: 11 | .. autofunction:: split_rngs 12 | .. autofunction:: fork_rngs 13 | .. autofunction:: reseed 14 | -------------------------------------------------------------------------------- /docs/guides/converting_and_upgrading/index.rst: -------------------------------------------------------------------------------- 1 | Converting and upgrading 2 | ======================== 3 | 4 | .. toctree:: 5 | :maxdepth: 1 6 | 7 | haiku_migration_guide 8 | convert_pytorch_to_flax 9 | orbax_upgrade_guide 10 | optax_update_guide 11 | linen_upgrade_guide 12 | rnncell_upgrade_guide 13 | regular_dict_upgrade_guide -------------------------------------------------------------------------------- /docs/guides/index.rst: -------------------------------------------------------------------------------- 1 | Guides 2 | ====== 3 | 4 | .. toctree:: 5 | :maxdepth: 2 6 | 7 | flax_fundamentals/index 8 | data_preprocessing/index 9 | training_techniques/index 10 | parallel_training/index 11 | model_inspection/index 12 | converting_and_upgrading/index 13 | quantization/index 14 | The Sharp Bits <https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html> 15 | -------------------------------------------------------------------------------- /docs/api_reference/index.rst: -------------------------------------------------------------------------------- 1 | API Reference 2 | ============= 3 | 4 | .. toctree:: 5 | :maxdepth: 4 6 | 7 | flax.config 8 | flax.core.frozen_dict 9 | flax.cursor 10 | flax.errors 11 | flax.jax_utils 12 | flax.linen/index 13 | flax.serialization 14 | flax.struct 15 | flax.traceback_util 16 | flax.training 17 | flax.traverse_util -------------------------------------------------------------------------------- /docs_nnx/api_reference/flax.traverse_util.rst: -------------------------------------------------------------------------------- 1 | flax.traverse_util package 2 | ============================ 3 | 4 | .. currentmodule:: flax.traverse_util 5 | 6 | .. automodule:: flax.traverse_util 7 | 8 | Dict utils 9 | ------------ 10 | 11 | .. autofunction:: flatten_dict 12 | 13 | .. autofunction:: unflatten_dict 14 | 15 | .. autofunction:: path_aware_map 16 | -------------------------------------------------------------------------------- /docs/api_reference/flax.traceback_util.rst: -------------------------------------------------------------------------------- 1 | flax.traceback_util package 2 | ============================ 3 | 4 | .. currentmodule:: flax.traceback_util 5 | 6 | .. automodule:: flax.traceback_util 7 | 8 | 9 | Traceback filtering utils 10 | -------------------------- 11 | 12 | .. autofunction:: hide_flax_in_tracebacks 13 | 14 | ..
autofunction:: show_flax_in_tracebacks 15 | -------------------------------------------------------------------------------- /examples/mnist/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.0.0 2 | clu==0.0.6 3 | flax==0.4.1 4 | jax==0.3.4 5 | --find-links https://storage.googleapis.com/jax-releases/jax_releases.html 6 | jaxlib==0.3.2+cuda11.cudnn82 # Make sure CUDA version matches the base image. 7 | ml-collections==0.1.0 8 | numpy==1.22.0 9 | optax==0.1.0 10 | tensorflow==2.11.1 11 | tensorflow-datasets==4.4.0 12 | -------------------------------------------------------------------------------- /docs_nnx/guides/blog.md: -------------------------------------------------------------------------------- 1 | ### Do we need another JAX NN library? 2 | 3 | Hello! Today I want to talk to you about a new JAX library that I have been working on, but before I do that, I want to discuss a question: do we need another JAX NN library? 4 | 5 | ### JAX Libraries 6 | 7 | JAX NN libraries come in a wide variety, ranging from functional ones like Flax and Haiku to Pytree-based ones like Equinox. -------------------------------------------------------------------------------- /benchmarks/README.md: -------------------------------------------------------------------------------- 1 | # Benchmarks 2 | 3 | These are mini benchmarks to measure the performance of NNX operations. 4 | 5 | Sample profile command: 6 | 7 | ```shell 8 | python -m cProfile -o ~/tmp/overhead.prof benchmarks/nnx_graph_overhead.py --mode=nnx --depth=100 --total_steps=1000 9 | ``` 10 | 11 | Sample profile inspection: 12 | 13 | ```shell 14 | snakeviz ~/tmp/overhead.prof 15 | ``` -------------------------------------------------------------------------------- /docs_nnx/api_reference/flax.nnx/nn/attention.rst: -------------------------------------------------------------------------------- 1 | Attention 2 | ------------------------ 3 | 4 | .. automodule:: flax.nnx 5 | .. currentmodule:: flax.nnx 6 | 7 | .. flax_module:: 8 | :module: flax.nnx 9 | :class: MultiHeadAttention 10 | 11 | .. autofunction:: combine_masks 12 | .. autofunction:: dot_product_attention 13 | .. autofunction:: make_attention_mask 14 | .. autofunction:: make_causal_mask -------------------------------------------------------------------------------- /examples/imagenet/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.0.0 2 | clu==0.0.6 3 | flax==0.6.5 4 | -f https://storage.googleapis.com/jax-releases/libtpu_releases.html 5 | -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 6 | jax[cuda11_cudnn805]>=0.3.16 # change to jax[tpu] if running on tpus 7 | ml-collections==0.1.0 8 | numpy==1.22.0 9 | optax==0.1.3 10 | tensorflow==2.11.1 11 | tensorflow-datasets==4.4.0 12 | -------------------------------------------------------------------------------- /examples/ogbg_molpcba/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.0.0 2 | clu==0.0.6 3 | flax==0.4.1 4 | jax==0.3.4 5 | --find-links https://storage.googleapis.com/jax-releases/jax_releases.html 6 | jaxlib==0.3.2+cuda11.cudnn82 # Make sure CUDA version matches the base image. 
7 | jraph==0.0.2.dev0 8 | ml-collections==0.1.0 9 | numpy==1.22.0 10 | optax==0.1.0 11 | sklearn==0.0 12 | tensorflow==2.11.1 13 | tensorflow-datasets==4.4.0 14 | -------------------------------------------------------------------------------- /docs_nnx/api_reference/flax.nnx/filterlib.rst: -------------------------------------------------------------------------------- 1 | filterlib 2 | ------------------------ 3 | 4 | .. automodule:: flax.nnx 5 | .. currentmodule:: flax.nnx 6 | 7 | 8 | .. autofunction:: flax.nnx.filterlib.to_predicate 9 | .. autoclass:: WithTag 10 | .. autoclass:: PathContains 11 | .. autoclass:: OfType 12 | .. autoclass:: Any 13 | .. autoclass:: All 14 | .. autoclass:: Not 15 | .. autoclass:: Everything 16 | .. autoclass:: Nothing -------------------------------------------------------------------------------- /examples/lm1b/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.4.0 2 | clu==0.0.9 3 | flax==0.6.11 4 | jax==0.4.13 5 | --find-links https://storage.googleapis.com/jax-releases/jax_releases.html 6 | jaxlib==0.4.13+cuda11.cudnn82 # Make sure CUDA version matches the base image. 7 | ml-collections==0.1.1 8 | numpy==1.24.3 9 | optax==0.1.5 10 | sentencepiece==0.1.99 11 | tensorflow==2.13.0 12 | tensorflow-datasets==4.9.2 13 | tensorflow-text==2.13.0 14 | -------------------------------------------------------------------------------- /docs_nnx/api_reference/flax.nnx/bridge.rst: -------------------------------------------------------------------------------- 1 | bridge 2 | ------------------------ 3 | 4 | .. automodule:: flax.nnx.bridge 5 | .. currentmodule:: flax.nnx.bridge 6 | 7 | .. flax_module:: 8 | :module: flax.nnx.bridge 9 | :class: ToNNX 10 | 11 | .. flax_module:: 12 | :module: flax.nnx.bridge 13 | :class: ToLinen 14 | 15 | .. autofunction:: to_linen 16 | 17 | .. flax_module:: 18 | :module: flax.nnx.bridge 19 | :class: NNXMeta 20 | -------------------------------------------------------------------------------- /docs/_static/css/flax_theme.css: -------------------------------------------------------------------------------- 1 | @import url("theme.css"); 2 | 3 | .wy-nav-content { 4 | max-width: 1290px; 5 | } 6 | 7 | .rst-content table.docutils { 8 | width: 100%; 9 | } 10 | 11 | .rst-content table.docutils td { 12 | vertical-align: top; 13 | padding: 0; 14 | } 15 | 16 | .rst-content table.docutils td p { 17 | padding: 8px; 18 | } 19 | 20 | .rst-content div[class^=highlight] { 21 | border: 0; 22 | margin: 0; 23 | } 24 | -------------------------------------------------------------------------------- /docs/api_reference/flax.core.frozen_dict.rst: -------------------------------------------------------------------------------- 1 | 2 | flax.core.frozen_dict package 3 | ============================= 4 | 5 | .. currentmodule:: flax.core.frozen_dict 6 | 7 | .. autoclass:: FrozenDict 8 | :members: pretty_repr, copy, pop, unfreeze, tree_flatten 9 | 10 | .. autofunction:: freeze 11 | 12 | .. autofunction:: unfreeze 13 | 14 | .. autofunction:: copy 15 | 16 | .. autofunction:: pop 17 | 18 | .. 
autofunction:: pretty_repr 19 | -------------------------------------------------------------------------------- /docs_nnx/_static/css/flax_theme.css: -------------------------------------------------------------------------------- 1 | @import url("theme.css"); 2 | 3 | .wy-nav-content { 4 | max-width: 1290px; 5 | } 6 | 7 | .rst-content table.docutils { 8 | width: 100%; 9 | } 10 | 11 | .rst-content table.docutils td { 12 | vertical-align: top; 13 | padding: 0; 14 | } 15 | 16 | .rst-content table.docutils td p { 17 | padding: 8px; 18 | } 19 | 20 | .rst-content div[class^=highlight] { 21 | border: 0; 22 | margin: 0; 23 | } 24 | -------------------------------------------------------------------------------- /docs_nnx/api_reference/flax.core.frozen_dict.rst: -------------------------------------------------------------------------------- 1 | 2 | flax.core.frozen_dict package 3 | ============================= 4 | 5 | .. currentmodule:: flax.core.frozen_dict 6 | 7 | .. autoclass:: FrozenDict 8 | :members: pretty_repr, copy, pop, unfreeze, tree_flatten 9 | 10 | .. autofunction:: freeze 11 | 12 | .. autofunction:: unfreeze 13 | 14 | .. autofunction:: copy 15 | 16 | .. autofunction:: pop 17 | 18 | .. autofunction:: pretty_repr 19 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *~ 2 | \#*\# 3 | *.pyc 4 | .tfds 5 | .DS_Store 6 | dist/ 7 | build/ 8 | *.egg-info 9 | *.rej 10 | .pytype 11 | .vscode/* 12 | /.devcontainer 13 | docs*/**/_autosummary 14 | docs*/_build 15 | docs*/**/tmp 16 | flaxlib_src/build 17 | flaxlib_src/builddir 18 | flaxlib_src/dist 19 | flaxlib_src/subprojects 20 | .venv 21 | venv/ 22 | venv.bak/ 23 | 24 | # used by direnv 25 | .envrc 26 | 27 | # uv 28 | uv.lock 29 | 30 | # custom 31 | /tmp-files 32 | -------------------------------------------------------------------------------- /docs_nnx/api_reference/flax.nnx/index.rst: -------------------------------------------------------------------------------- 1 | flax.nnx 2 | ------------------------ 3 | 4 | Experimental API. See the `NNX page <https://flax.readthedocs.io/en/latest/nnx/index.html>`__ for more details. 5 | 6 | .. toctree:: 7 | :maxdepth: 3 8 | 9 | graph 10 | object 11 | module 12 | nn/index 13 | rnglib 14 | spmd 15 | state 16 | training/index 17 | transforms 18 | variables 19 | helpers 20 | visualization 21 | filterlib 22 | bridge 23 | -------------------------------------------------------------------------------- /docs_nnx/api_reference/flax.nnx/nn/index.rst: -------------------------------------------------------------------------------- 1 | nn 2 | ---------------------------- 3 | 4 | Neural network layers and activation functions used in NNX :class:`Module`\ s. 5 | See the `NNX page <https://flax.readthedocs.io/en/latest/nnx/index.html>`__ for more details. 6 | 7 | .. toctree:: 8 | :maxdepth: 3 9 | 10 | activations 11 | attention 12 | dtypes 13 | initializers 14 | linear 15 | lora 16 | normalization 17 | recurrent 18 | stochastic 19 | 20 | -------------------------------------------------------------------------------- /docs/api_reference/flax.linen/index.rst: -------------------------------------------------------------------------------- 1 | 2 | flax.linen 3 | ========== 4 | 5 | Linen is the Flax Module system. Read more about our design goals in the `Linen README <https://github.com/google/flax/blob/main/flax/linen/README.md>`_. 6 | 7 | ..
toctree:: 8 | :maxdepth: 2 9 | 10 | module 11 | init_apply 12 | layers 13 | activation_functions 14 | initializers 15 | transformations 16 | inspection 17 | variable 18 | spmd 19 | decorators 20 | profiling -------------------------------------------------------------------------------- /examples/wmt/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.0.0 2 | clu==0.0.6 3 | flax==0.6.0 4 | -f https://storage.googleapis.com/jax-releases/libtpu_releases.html 5 | -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 6 | jax[cuda11_cudnn805]>=0.3.16 # change to jax[tpu] if running on tpus 7 | ml-collections==0.1.0 8 | numpy==1.22.0 9 | optax==0.1.0 10 | sentencepiece==0.1.96 11 | six==1.15.0 12 | tensorflow==2.11.1 13 | tensorflow-datasets==4.4.0 14 | tensorflow-text==2.8.1 15 | -------------------------------------------------------------------------------- /docs_nnx/api_reference/flax.nnx/object.rst: -------------------------------------------------------------------------------- 1 | object 2 | ------------------------ 3 | 4 | .. automodule:: flax.nnx 5 | .. currentmodule:: flax.nnx 6 | 7 | .. autoclass:: Pytree 8 | :members: 9 | .. autoclass:: Object 10 | :members: 11 | .. autofunction:: data 12 | .. autodata:: Data 13 | :annotation: 14 | .. autofunction:: static 15 | .. autodata:: Static 16 | :annotation: 17 | .. autofunction:: is_data 18 | .. autofunction:: register_data_type 19 | .. autofunction:: check_pytree -------------------------------------------------------------------------------- /docs/api_reference/flax.jax_utils.rst: -------------------------------------------------------------------------------- 1 | 2 | flax.jax_utils package 3 | ======================== 4 | 5 | .. currentmodule:: flax.jax_utils 6 | 7 | .. automodule:: flax.jax_utils 8 | 9 | 10 | .. autofunction:: partial_eval_by_shape 11 | 12 | 13 | Multi device utilities 14 | ------------------------ 15 | 16 | .. autofunction:: replicate 17 | .. autofunction:: unreplicate 18 | 19 | .. autofunction:: prefetch_to_device 20 | 21 | .. autofunction:: pmean 22 | 23 | .. autofunction:: pad_shard_unpad 24 | -------------------------------------------------------------------------------- /flax/oss/.git-blame-ignore-revs: -------------------------------------------------------------------------------- 1 | # .git-blame-ignore-revs 2 | # 3 | # These commits will be ignored by the github blame view. 4 | # The git blame CLI can ignore them as well by doing: 5 | # git blame --ignore-revs-file .git-blame-ignore-revs 6 | # or via global config: 7 | # git config --global blame.ignoreRevsFile .git-blame-ignore-revs 8 | # see the blame.markIgnoredLines and blame.markUnblamableLines options. 9 | # 10 | # remove all trailing whitespaces 11 | 442df07ca1a90f04c685cdae9f8e488bbffc2f83 12 | -------------------------------------------------------------------------------- /docs/api_reference/flax.linen/transformations.rst: -------------------------------------------------------------------------------- 1 | Transformations 2 | ---------------------- 3 | 4 | .. automodule:: flax.linen.transforms 5 | .. currentmodule:: flax.linen 6 | 7 | .. autofunction:: vmap 8 | .. autofunction:: scan 9 | .. autofunction:: jit 10 | .. autofunction:: remat 11 | .. autofunction:: remat_scan 12 | .. autofunction:: map_variables 13 | .. autofunction:: jvp 14 | .. autofunction:: vjp 15 | .. autofunction:: custom_vjp 16 | .. autofunction:: while_loop 17 | .. autofunction:: cond 18 | ..
autofunction:: switch 19 | -------------------------------------------------------------------------------- /docs_nnx/api_reference/flax.nnx/transforms.rst: -------------------------------------------------------------------------------- 1 | transforms 2 | ------------------------ 3 | 4 | .. automodule:: flax.nnx 5 | .. currentmodule:: flax.nnx 6 | .. autofunction:: grad 7 | .. autofunction:: jit 8 | .. autofunction:: shard_map 9 | .. autofunction:: remat 10 | .. autofunction:: scan 11 | .. autoclass:: Carry 12 | .. autofunction:: value_and_grad 13 | .. autofunction:: vmap 14 | .. autofunction:: eval_shape 15 | .. autofunction:: custom_vjp 16 | .. autofunction:: cond 17 | .. autofunction:: switch 18 | .. autofunction:: while_loop 19 | .. autofunction:: fori_loop 20 | -------------------------------------------------------------------------------- /docs_nnx/api_reference/flax.nnx/training/metrics.rst: -------------------------------------------------------------------------------- 1 | Metrics 2 | ------------------------ 3 | 4 | .. automodule:: flax.nnx.metrics 5 | .. currentmodule:: flax.nnx.metrics 6 | 7 | 8 | .. autoclass:: Metric 9 | :members: __init__, reset, update, compute 10 | 11 | .. autoclass:: Average 12 | :members: __init__, reset, update, compute 13 | 14 | .. autoclass:: Accuracy 15 | :members: update 16 | 17 | .. autoclass:: Welford 18 | :members: __init__, reset, update, compute 19 | 20 | .. autoclass:: MultiMetric 21 | :members: __init__, reset, update, compute 22 | 23 | -------------------------------------------------------------------------------- /docs_nnx/api_reference/flax.nnx/state.rst: -------------------------------------------------------------------------------- 1 | state 2 | ------------------------ 3 | 4 | .. automodule:: flax.nnx 5 | .. currentmodule:: flax.nnx 6 | 7 | 8 | .. autoclass:: State 9 | :members: 10 | 11 | .. autoclass:: FlatState 12 | :members: 13 | 14 | .. autofunction:: filter_state 15 | .. autofunction:: from_flat_state 16 | .. autofunction:: map_state 17 | .. autofunction:: merge_state 18 | .. autofunction:: replace_by_pure_dict 19 | .. autofunction:: restore_int_paths 20 | .. autofunction:: to_flat_state 21 | .. autofunction:: to_pure_dict 22 | .. autofunction:: split_state 23 | -------------------------------------------------------------------------------- /benchmarks/tracing/README.md: -------------------------------------------------------------------------------- 1 | # Tracing and lowering benchmarks for Flax examples 2 | 3 | See the Flax 4 | [documentation](https://flax.readthedocs.io/en/latest/examples/index.html) on 5 | these examples. 6 | 7 | ## Getting started 8 | 9 | ```bash 10 | pip install -r benchmarks/tracing/requirements.txt 11 | 12 | # Benchmark trace and lower timing for all workloads. 13 | python tracing_benchmark.py 14 | 15 | # Profile a single example. 16 | python tracing_benchmark.py --example=wmt 17 | 18 | # Profile just tracing for a single example. 19 | python tracing_benchmark.py --example=wmt --mode=trace 20 | ``` -------------------------------------------------------------------------------- /docs_nnx/api_reference/flax.nnx/variables.rst: -------------------------------------------------------------------------------- 1 | variables 2 | ------------------------ 3 | 4 | .. automodule:: flax.nnx 5 | .. currentmodule:: flax.nnx 6 | 7 | .. autoclass:: BatchStat 8 | :members: 9 | .. autoclass:: Cache 10 | :members: 11 | .. autoclass:: Intermediate 12 | :members: 13 | .. autoclass:: Param 14 | :members: 15 | ..
autoclass:: Variable 16 | :members: 17 | .. autoclass:: VariableMetadata 18 | :members: 19 | 20 | .. autofunction:: with_metadata 21 | 22 | .. autofunction:: variable_name_from_type 23 | .. autofunction:: variable_type_from_name 24 | .. autofunction:: register_variable_name 25 | -------------------------------------------------------------------------------- /docs/api_reference/flax.linen/module.rst: -------------------------------------------------------------------------------- 1 | Module 2 | ------------------------ 3 | 4 | .. automodule:: flax.linen 5 | .. currentmodule:: flax.linen 6 | 7 | .. autoclass:: Module 8 | :members: setup, variable, param, bind, unbind, apply, init, init_with_output, copy, make_rng, sow, variables, Variable, __setattr__, tabulate, module_paths, is_initializing, perturb, put_variable, has_variable, has_rng, lazy_init, get_variable, path, is_mutable_collection 9 | 10 | .. autofunction:: apply 11 | .. autofunction:: init 12 | .. autofunction:: init_with_output 13 | .. autofunction:: intercept_methods 14 | .. autofunction:: share_scope 15 | -------------------------------------------------------------------------------- /docs/api_reference/flax.serialization.rst: -------------------------------------------------------------------------------- 1 | 2 | flax.serialization package 3 | ============================ 4 | 5 | .. currentmodule:: flax.serialization 6 | 7 | .. automodule:: flax.serialization 8 | 9 | 10 | State dicts 11 | ------------------------ 12 | 13 | .. autofunction:: from_state_dict 14 | .. autofunction:: to_state_dict 15 | 16 | .. autofunction:: register_serialization_state 17 | 18 | 19 | Serialization with MessagePack 20 | -------------------------------- 21 | 22 | .. autofunction:: msgpack_serialize 23 | .. autofunction:: msgpack_restore 24 | 25 | .. autofunction:: to_bytes 26 | .. autofunction:: from_bytes 27 | -------------------------------------------------------------------------------- /docs_nnx/api_reference/flax.nnx/nn/linear.rst: -------------------------------------------------------------------------------- 1 | Linear 2 | ------------------------ 3 | 4 | NNX linear layer classes. 5 | 6 | .. automodule:: flax.nnx 7 | .. currentmodule:: flax.nnx 8 | 9 | .. flax_module:: 10 | :module: flax.nnx 11 | :class: Conv 12 | 13 | .. flax_module:: 14 | :module: flax.nnx 15 | :class: ConvTranspose 16 | 17 | .. flax_module:: 18 | :module: flax.nnx 19 | :class: Embed 20 | 21 | .. flax_module:: 22 | :module: flax.nnx 23 | :class: Linear 24 | 25 | .. flax_module:: 26 | :module: flax.nnx 27 | :class: LinearGeneral 28 | 29 | .. flax_module:: 30 | :module: flax.nnx 31 | :class: Einsum -------------------------------------------------------------------------------- /.github/analytics/README.md: -------------------------------------------------------------------------------- 1 | # Repo Analytics 2 | 3 | To run the repo analytics, follow the steps below: 4 | 5 | 1. You must have a GitHub token; if you don't have one, you can create one by following [this guide](https://docs.github.com/en/enterprise-server@3.4/authentication/keeping-your-account-and-data-secure/creating-a-personal-access-token). 6 | 2. Install the requirements: 7 | 8 | ```bash 9 | pip install -r .github/analytics/requirements.txt 10 | ``` 11 | 3. 
Run the analytics: 12 | 13 | ```bash 14 | GITHUB_TOKEN=<your-token> \ 15 | python .github/analytics/get_repo_metrics.py \ 16 | --repo-owner google \ 17 | --repo-name flax 18 | ``` -------------------------------------------------------------------------------- /tests/nnx/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /flax/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /flax/nnx/nn/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | -------------------------------------------------------------------------------- /benchmarks/tracing/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The JAX Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /flax/nnx/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /flax/experimental/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /flax/nnx/training/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | -------------------------------------------------------------------------------- /flax/training/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Flax training utilities.""" 16 | -------------------------------------------------------------------------------- /docs_nnx/api_reference/flax.nnx/nn/normalization.rst: -------------------------------------------------------------------------------- 1 | Normalization 2 | ------------------------ 3 | 4 | .. automodule:: flax.nnx 5 | .. currentmodule:: flax.nnx 6 | 7 | .. flax_module:: 8 | :module: flax.nnx 9 | :class: BatchNorm 10 | 11 | .. flax_module:: 12 | :module: flax.nnx 13 | :class: LayerNorm 14 | 15 | .. flax_module:: 16 | :module: flax.nnx 17 | :class: RMSNorm 18 | 19 | .. flax_module:: 20 | :module: flax.nnx 21 | :class: GroupNorm 22 | 23 | .. flax_module:: 24 | :module: flax.nnx 25 | :class: InstanceNorm 26 | 27 | .. flax_module:: 28 | :module: flax.nnx 29 | :class: SpectralNorm 30 | 31 | .. flax_module:: 32 | :module: flax.nnx 33 | :class: WeightNorm 34 | -------------------------------------------------------------------------------- /docs/api_reference/flax.linen/spmd.rst: -------------------------------------------------------------------------------- 1 | 2 | SPMD 3 | ---------------------- 4 | 5 | .. automodule:: flax.linen.spmd 6 | .. currentmodule:: flax.linen 7 | 8 | .. autofunction:: Partitioned 9 | .. autofunction:: with_partitioning 10 | .. autofunction:: get_partition_spec 11 | .. autofunction:: get_sharding 12 | .. autofunction:: LogicallyPartitioned 13 | .. autofunction:: logical_axis_rules 14 | .. autofunction:: set_logical_axis_rules 15 | .. autofunction:: get_logical_axis_rules 16 | .. autofunction:: logical_to_mesh_axes 17 | .. autofunction:: logical_to_mesh 18 | .. autofunction:: logical_to_mesh_sharding 19 | .. autofunction:: with_logical_constraint 20 | .. autofunction:: with_logical_partitioning 21 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
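# For example (a hypothetical invocation; any Sphinx target is routed the
# same way), extra sphinx-build options can be forwarded via the $(O) shortcut:
#   make html O="-W --keep-going"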
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs_nnx/.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | build: 9 | os: ubuntu-22.04 10 | tools: 11 | python: "3.12" 12 | jobs: 13 | pre_build: 14 | - pip install ".[all, testing, docs]" 15 | - pip install -U --pre jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ 16 | 17 | # Build documentation in the docs/ directory with Sphinx 18 | sphinx: 19 | configuration: docs_nnx/conf.py 20 | 21 | # Optionally build your docs in additional formats such as PDF and ePub 22 | formats: 23 | - htmlzip 24 | - epub 25 | # - pdf 26 | -------------------------------------------------------------------------------- /flax/version.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Current Flax version at head on Github.""" 16 | __version__ = '0.12.1' 17 | -------------------------------------------------------------------------------- /docs_nnx/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /flax/testing/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | """Flax testing utilities.""" 17 | 18 | from .benchmark import Benchmark 19 | -------------------------------------------------------------------------------- /docs_nnx/api_reference/flax.nnx/nn/recurrent.rst: -------------------------------------------------------------------------------- 1 | Recurrent 2 | ------------------------ 3 | 4 | .. automodule:: flax.nnx.nn.recurrent 5 | .. currentmodule:: flax.nnx.nn.recurrent 6 | 7 | .. flax_module:: 8 | :module: flax.nnx.nn.recurrent 9 | :class: LSTMCell 10 | 11 | .. flax_module:: 12 | :module: flax.nnx.nn.recurrent 13 | :class: OptimizedLSTMCell 14 | 15 | .. flax_module:: 16 | :module: flax.nnx.nn.recurrent 17 | :class: SimpleCell 18 | 19 | .. flax_module:: 20 | :module: flax.nnx.nn.recurrent 21 | :class: GRUCell 22 | 23 | .. flax_module:: 24 | :module: flax.nnx.nn.recurrent 25 | :class: RNN 26 | 27 | .. flax_module:: 28 | :module: flax.nnx.nn.recurrent 29 | :class: Bidirectional 30 | 31 | 32 | .. autofunction:: flip_sequences -------------------------------------------------------------------------------- /docs/flip/0000-template.md: -------------------------------------------------------------------------------- 1 | - Start Date: (fill me in with today's date, YYYY-MM-DD) 2 | - FLIP PR: [#0000](https://github.com/google/flax/pull/0000) 3 | - FLIP Issue: [#0000](https://github.com/google/flax/issues/0000) 4 | 5 | (Below sections are just a possible structure - please adapt to your FLIP.) 6 | 7 | # Summary 8 | [summary]: #summary 9 | 10 | One paragraph explanation of the FLIP. 11 | 12 | # Motivation 13 | [motivation]: #motivation 14 | 15 | Why are we doing this? What use cases does it support? What is the expected outcome? 16 | 17 | # Implementation 18 | [implementation]: #implementation 19 | 20 | The technical part. 21 | 22 | # Discussion 23 | [discussion]: #discussion 24 | 25 | Summarize the discussion from the original issue and from the pull request. 26 | -------------------------------------------------------------------------------- /docs/.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | build: 9 | os: ubuntu-22.04 10 | tools: 11 | python: "3.12" 12 | 13 | # Build documentation in the docs/ directory with Sphinx 14 | sphinx: 15 | configuration: docs/conf.py 16 | 17 | # Optionally build your docs in additional formats such as PDF and ePub 18 | formats: 19 | - htmlzip 20 | - epub 21 | # - pdf 22 | 23 | # Optionally set the version of Python and requirements required to build your docs 24 | python: 25 | install: 26 | - method: pip 27 | path: . 28 | extra_requirements: 29 | - all 30 | - testing 31 | - docs 32 | -------------------------------------------------------------------------------- /docs_nnx/flip/0000-template.md: -------------------------------------------------------------------------------- 1 | - Start Date: (fill me in with today's date, YYYY-MM-DD) 2 | - FLIP PR: [#0000](https://github.com/google/flax/pull/0000) 3 | - FLIP Issue: [#0000](https://github.com/google/flax/issues/0000) 4 | 5 | (Below sections are just a possible structure - please adapt to your FLIP.) 6 | 7 | # Summary 8 | [summary]: #summary 9 | 10 | One paragraph explanation of the FLIP. 
11 | 12 | # Motivation 13 | [motivation]: #motivation 14 | 15 | Why are we doing this? What use cases does it support? What is the expected outcome? 16 | 17 | # Implementation 18 | [implementation]: #implementation 19 | 20 | The technical part. 21 | 22 | # Discussion 23 | [discussion]: #discussion 24 | 25 | Summarize the discussion from the original issue and from the pull request. 26 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # Flax Examples 2 | 3 | Each example is designed to be **self-contained and easily forkable**, while 4 | reproducing relevant results in different areas of machine learning. 5 | 6 | As discussed in [#231](https://github.com/google/flax/issues/231), we decided 7 | to go for a standard pattern for all examples including the simplest ones 8 | (like MNIST). This makes every example a bit more verbose, but once you know 9 | one example, you know the structure of all of them. Having unit tests and 10 | integration tests is also very useful when you fork these examples. 11 | 12 | For more examples including contributions from the community and other projects currently using Flax see the **[Examples](https://flax.readthedocs.io/en/latest/examples.html)** section in the documentation. 13 | -------------------------------------------------------------------------------- /docs/_templates/autosummary/flax_module.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline }} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. autoclass:: {{ objname }} 6 | :exclude-members: 7 | 8 | .. automethod:: __call__ 9 | 10 | {% block methods %} 11 | 12 | {% for item in methods %} 13 | {%- if item not in inherited_members and item not in annotations and not item in ['__init__', 'setup'] %} 14 | .. automethod:: {{ item }} 15 | {%- endif %} 16 | {%- endfor %} 17 | 18 | {% if methods %} 19 | .. rubric:: Methods 20 | 21 | .. autosummary:: 22 | 23 | {% for item in methods %} 24 | {%- if item not in inherited_members and item not in annotations and not item in ['__init__', 'setup'] %} 25 | ~{{ name }}.{{ item }} 26 | {%- endif %} 27 | {%- endfor %} 28 | {% endif %} 29 | {% endblock %} -------------------------------------------------------------------------------- /docs_nnx/_templates/autosummary/flax_module.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline }} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. autoclass:: {{ objname }} 6 | :exclude-members: 7 | 8 | .. automethod:: __call__ 9 | 10 | {% block methods %} 11 | 12 | {% for item in methods %} 13 | {%- if item not in inherited_members and item not in annotations and not item in ['__init__', 'setup'] %} 14 | .. automethod:: {{ item }} 15 | {%- endif %} 16 | {%- endfor %} 17 | 18 | {% if methods %} 19 | .. rubric:: Methods 20 | 21 | .. autosummary:: 22 | 23 | {% for item in methods %} 24 | {%- if item not in inherited_members and item not in annotations and not item in ['__init__', 'setup'] %} 25 | ~{{ name }}.{{ item }} 26 | {%- endif %} 27 | {%- endfor %} 28 | {% endif %} 29 | {% endblock %} -------------------------------------------------------------------------------- /flax/experimental/nnx.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from absl import logging 16 | 17 | from flax.nnx import * 18 | 19 | 20 | logging.warning( 21 | "Using 'flax.experimental.nnx' is deprecated. Please use 'flax.nnx' instead." 22 | ) -------------------------------------------------------------------------------- /flaxlib_src/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["scikit-build-core >=0.4.3", "nanobind >=1.3.2"] 3 | build-backend = "scikit_build_core.build" 4 | 5 | [project] 6 | name = "flaxlib" 7 | version = "0.0.1" 8 | requires-python = ">=3.10" 9 | classifiers = [ 10 | "Programming Language :: C++", 11 | "Programming Language :: Python :: Implementation :: CPython", 12 | "Programming Language :: Python :: Implementation :: PyPy", 13 | ] 14 | 15 | [project.optional-dependencies] 16 | tests = [ 17 | "pytest", 18 | ] 19 | 20 | [tool.scikit-build] 21 | # Protect the configuration against future changes in scikit-build-core 22 | minimum-version = "0.4" 23 | 24 | # Setuptools-style build caching in a local directory 25 | build-dir = "build/{wheel_tag}" 26 | 27 | # Build stable ABI wheels for CPython 3.12+ 28 | wheel.py-api = "cp312" -------------------------------------------------------------------------------- /flaxlib_src/README.md: -------------------------------------------------------------------------------- 1 | # flaxlib 2 | 3 | ## Build flaxlib from source 4 | 5 | Install necessary dependencies to build the C++ based package. 6 | 7 | ```shell 8 | pip install meson-python ninja build 9 | ``` 10 | 11 | Clone the Flax repository, navigate to the flaxlib source directory. 12 | 13 | ```shell 14 | git clone git@github.com:google/flax.git 15 | cd flax/flaxlib_src 16 | ``` 17 | 18 | Configure the build. 19 | 20 | ```shell 21 | mkdir -p subprojects 22 | meson wrap install robin-map 23 | meson wrap install nanobind 24 | meson setup builddir 25 | ``` 26 | 27 | Compile the code. You'll need to run this repeatedly if you modify the source 28 | code. Note that the actual wheel name will differ depending on your system. 29 | 30 | ```shell 31 | meson compile -C builddir 32 | python -m build . -w 33 | pip install dist/flaxlib-0.0.1-cp311-cp311-macosx_14_0_arm64.whl --force-reinstall 34 | ``` 35 | -------------------------------------------------------------------------------- /docs/api_reference/flax.linen/initializers.rst: -------------------------------------------------------------------------------- 1 | Initializers 2 | ------------------------ 3 | 4 | .. automodule:: flax.linen.initializers 5 | .. currentmodule:: flax.linen.initializers 6 | 7 | .. autofunction:: constant 8 | .. autofunction:: delta_orthogonal 9 | .. autofunction:: glorot_normal 10 | .. autofunction:: glorot_uniform 11 | .. autofunction:: he_normal 12 | .. autofunction:: he_uniform 13 | .. autofunction:: kaiming_normal 14 | .. autofunction:: kaiming_uniform 15 | .. 
autofunction:: lecun_normal 16 | .. autofunction:: lecun_uniform 17 | .. autofunction:: normal 18 | .. autofunction:: truncated_normal 19 | .. autofunction:: ones 20 | .. autofunction:: ones_init 21 | .. autofunction:: orthogonal 22 | .. autofunction:: uniform 23 | .. autofunction:: variance_scaling 24 | .. autofunction:: xavier_normal 25 | .. autofunction:: xavier_uniform 26 | .. autofunction:: zeros 27 | .. autofunction:: zeros_init 28 | -------------------------------------------------------------------------------- /docs_nnx/api_reference/flax.nnx/nn/initializers.rst: -------------------------------------------------------------------------------- 1 | Initializers 2 | ------------------------ 3 | 4 | .. automodule:: flax.nnx.initializers 5 | .. currentmodule:: flax.nnx.initializers 6 | 7 | .. autofunction:: constant 8 | .. autofunction:: delta_orthogonal 9 | .. autofunction:: glorot_normal 10 | .. autofunction:: glorot_uniform 11 | .. autofunction:: he_normal 12 | .. autofunction:: he_uniform 13 | .. autofunction:: kaiming_normal 14 | .. autofunction:: kaiming_uniform 15 | .. autofunction:: lecun_normal 16 | .. autofunction:: lecun_uniform 17 | .. autofunction:: normal 18 | .. autofunction:: truncated_normal 19 | .. autofunction:: ones 20 | .. autofunction:: ones_init 21 | .. autofunction:: orthogonal 22 | .. autofunction:: uniform 23 | .. autofunction:: variance_scaling 24 | .. autofunction:: xavier_normal 25 | .. autofunction:: xavier_uniform 26 | .. autofunction:: zeros 27 | .. autofunction:: zeros_init 28 | -------------------------------------------------------------------------------- /flaxlib_src/src/flaxlib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .flaxlib_cpp import RefMap as RefMap 16 | from .flaxlib_cpp import IndexMap as IndexMap 17 | from .flaxlib_cpp import NodeDef as NodeDef 18 | from .flaxlib_cpp import VariableDef as VariableDef 19 | from .flaxlib_cpp import NodeRef as NodeRef 20 | -------------------------------------------------------------------------------- /docs_nnx/api_reference/flax.nnx/nn/activations.rst: -------------------------------------------------------------------------------- 1 | Activation functions 2 | ------------------------ 3 | 4 | .. automodule:: flax.nnx 5 | .. currentmodule:: flax.nnx 6 | 7 | .. autofunction:: celu 8 | .. autofunction:: elu 9 | .. autofunction:: gelu 10 | .. autofunction:: glu 11 | .. autofunction:: hard_sigmoid 12 | .. autofunction:: hard_silu 13 | .. autofunction:: hard_swish 14 | .. autofunction:: hard_tanh 15 | .. autofunction:: leaky_relu 16 | .. autofunction:: log_sigmoid 17 | .. autofunction:: log_softmax 18 | .. autofunction:: logsumexp 19 | .. autofunction:: one_hot 20 | .. autofunction:: relu 21 | .. autofunction:: relu6 22 | .. autofunction:: selu 23 | .. autofunction:: sigmoid 24 | .. autofunction:: identity 25 | .. 
autofunction:: silu 26 | .. autofunction:: soft_sign 27 | .. autofunction:: softmax 28 | .. autofunction:: softplus 29 | .. autofunction:: standardize 30 | .. autofunction:: swish 31 | .. autofunction:: tanh -------------------------------------------------------------------------------- /tests/flaxlib_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | # TODO: Re-enable this test after setting up CI build for flaxlib CC. 17 | 18 | # from absl.testing import absltest 19 | # import flaxlib 20 | 21 | 22 | # class TestFlaxlib(absltest.TestCase): 23 | 24 | # def test_flaxlib(self): 25 | # self.assertEqual(flaxlib.sum_as_string(1, 2), '3') 26 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: bug 6 | assignees: '' 7 | 8 | --- 9 | 10 | Provide as much information as possible. At least, this should include a description of your issue and steps to reproduce the problem. If possible, also provide a summary of what steps or workarounds you have already tried. 11 | 12 | ### System information 13 | - OS Platform and Distribution (e.g., Linux Ubuntu 16.04): 14 | - Flax, jax, jaxlib versions (obtain with `pip show flax jax jaxlib`): 15 | - Python version: 16 | - GPU/TPU model and memory: 17 | - CUDA version (if applicable): 18 | 19 | 20 | ### Problem you have encountered: 21 | 22 | 23 | ### What you expected to happen: 24 | 25 | 26 | ### Logs, error messages, etc: 27 | 28 | 29 | 30 | ### Steps to reproduce: 31 | Whenever possible, please provide a *minimal example*. Please consider submitting it as a Colab link. 32 | -------------------------------------------------------------------------------- /tests/download_dataset_metadata.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # If you get an error like: 3 | # Cloning into 'datasets'... 4 | # fatal: cannot change to 'https://github.com/tensorflow/datasets/': No such file or directory 5 | # error: failed to initialize sparse-checkout 6 | # This means your git version is outdated. Just update it. 7 | 8 | 9 | set -e 10 | 11 | # Download TFDS metadata to flax/.tfds/metadata directory. 12 | # This allows the tests to specify the `data_dir` when using tfds.testing.mock_data(). 13 | cd "$( dirname "$0" )" 14 | 15 | if [ -d "../.tfds/metadata" ]; then 16 | echo 'TFDS metadata already exists.'; 17 | else 18 | echo 'TFDS metadata does not exist. 
Downloading...'; 19 | git clone --branch v4.8.2 --depth 3 --filter=blob:none --sparse https://github.com/tensorflow/datasets/ 20 | cd datasets 21 | git sparse-checkout set tensorflow_datasets/testing/metadata 22 | mkdir ../../.tfds 23 | mv tensorflow_datasets/testing/metadata/ ../../.tfds/metadata/ 24 | cd .. 25 | rm -rf datasets 26 | fi 27 | -------------------------------------------------------------------------------- /docs/api_reference/flax.linen/activation_functions.rst: -------------------------------------------------------------------------------- 1 | 2 | Activation functions 3 | ------------------------ 4 | 5 | .. automodule:: flax.linen.activation 6 | .. currentmodule:: flax.linen.activation 7 | 8 | .. autoclass:: PReLU 9 | :members: 10 | :special-members: __call__ 11 | 12 | .. autofunction:: celu 13 | .. autofunction:: elu 14 | .. autofunction:: gelu 15 | .. autofunction:: glu 16 | .. autofunction:: hard_sigmoid 17 | .. autofunction:: hard_silu 18 | .. autofunction:: hard_swish 19 | .. autofunction:: hard_tanh 20 | .. autofunction:: leaky_relu 21 | .. autofunction:: log_sigmoid 22 | .. autofunction:: log_softmax 23 | .. autofunction:: logsumexp 24 | .. autofunction:: one_hot 25 | .. autofunction:: relu 26 | .. autofunction:: relu6 27 | .. autofunction:: selu 28 | .. autofunction:: sigmoid 29 | .. autofunction:: silu 30 | .. autofunction:: soft_sign 31 | .. autofunction:: softmax 32 | .. autofunction:: softplus 33 | .. autofunction:: standardize 34 | .. autofunction:: swish 35 | .. autofunction:: tanh 36 | -------------------------------------------------------------------------------- /examples/vae/configs/default.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Default Hyperparameter configuration.""" 16 | 17 | import ml_collections 18 | 19 | 20 | def get_config(): 21 | """Get the default hyperparameter configuration.""" 22 | config = ml_collections.ConfigDict() 23 | 24 | config.learning_rate = 0.001 25 | config.latents = 20 26 | config.batch_size = 128 27 | config.num_epochs = 30 28 | return config 29 | -------------------------------------------------------------------------------- /examples/seq2seq/README.md: -------------------------------------------------------------------------------- 1 | ## seq2seq addition 2 | 3 | This example trains a simple LSTM on a sequence-to-sequence addition task using 4 | an encoder-decoder architecture. The data is generated on the fly. 
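As a rough illustration of the on-the-fly generation, a training pair can be sampled as below. This is a hypothetical sketch (the helper name is illustrative; the example's real input pipeline lives in `main.py`), mirroring the `max_len_query_digit = 3` default from `configs/default.py`:

```python
import random

def make_example(max_len_query_digit=3):
  # Sample two operands with at most `max_len_query_digit` digits each.
  a = random.randint(0, 10**max_len_query_digit - 1)
  b = random.randint(0, 10**max_len_query_digit - 1)
  # Query string and expected decoder output, e.g. ('14+381', '395').
  return f'{a}+{b}', str(a + b)
```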
 5 | 6 | Colab lets you edit the source files and interact with the model: 7 | 8 | https://colab.research.google.com/github/google/flax/blob/main/examples/seq2seq/seq2seq.ipynb 9 | 10 | ### Example output 11 | 12 | From a Colab run that also generated the metrics at [tensorboard.dev]: 13 | 14 | ``` 15 | INFO:absl:[1800] accuracy=1.0, loss=0.0020284138154238462 16 | INFO:absl:DECODE: 14+381 = 395 (CORRECT) 17 | INFO:absl:DECODE: 68+91 = 159 (CORRECT) 18 | INFO:absl:DECODE: 0+807 = 707 (INCORRECT) correct=807 19 | INFO:absl:DECODE: 95+532 = 627 (CORRECT) 20 | INFO:absl:DECODE: 6+600 = 606 (CORRECT) 21 | ``` 22 | 23 | [tensorboard.dev]: https://tensorboard.dev/experiment/TwvKVBqzTaKWgEbyebillw/#scalars&_smoothingWeight=0 24 | 25 | ### How to run 26 | 27 | `python main.py` 28 | 29 | The total runtime for 1200 steps on CPU (3.5GHz Intel Core i7, 16GB memory) is 30 | about 4 minutes. 31 | -------------------------------------------------------------------------------- /examples/mnist/configs/default.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Default Hyperparameter configuration.""" 16 | 17 | import ml_collections 18 | 19 | 20 | def get_config(): 21 | """Get the default hyperparameter configuration.""" 22 | config = ml_collections.ConfigDict() 23 | 24 | config.learning_rate = 0.1 25 | config.momentum = 0.9 26 | config.batch_size = 128 27 | config.num_epochs = 10 28 | return config 29 | 30 | 31 | def metrics(): 32 | return [] 33 | -------------------------------------------------------------------------------- /flaxlib_src/.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | .pytest_cache/ 6 | *.py[cod] 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | .venv/ 14 | env/ 15 | bin/ 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | include/ 26 | man/ 27 | venv/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | 32 | # Installer logs 33 | pip-log.txt 34 | pip-delete-this-directory.txt 35 | pip-selfcheck.json 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .cache 42 | nosetests.xml 43 | coverage.xml 44 | 45 | # Translations 46 | *.mo 47 | 48 | # Mr Developer 49 | .mr.developer.cfg 50 | .project 51 | .pydevproject 52 | 53 | # Rope 54 | .ropeproject 55 | 56 | # Django stuff: 57 | *.log 58 | *.pot 59 | 60 | .DS_Store 61 | 62 | # Sphinx documentation 63 | docs/_build/ 64 | 65 | # PyCharm 66 | .idea/ 67 | 68 | # VSCode 69 | .vscode/ 70 | 71 | # Pyenv 72 | .python-version 73 | 74 | # cibuildwheel 75 | /wheelhouse -------------------------------------------------------------------------------- /docs_nnx/api_reference/flax.nnx/graph.rst: 
-------------------------------------------------------------------------------- 1 | graph 2 | ------------------------ 3 | 4 | .. automodule:: flax.nnx 5 | .. currentmodule:: flax.nnx 6 | 7 | 8 | .. autofunction:: split 9 | .. autofunction:: merge 10 | .. autofunction:: update 11 | .. autofunction:: pop 12 | .. autofunction:: state 13 | .. autofunction:: variables 14 | .. autofunction:: graph 15 | .. autofunction:: graphdef 16 | .. autofunction:: iter_graph 17 | .. autofunction:: recursive_map 18 | .. autofunction:: clone 19 | .. autofunction:: call 20 | .. autofunction:: set_metadata 21 | .. autofunction:: cached_partial 22 | 23 | .. autoclass:: GraphDef 24 | :members: 25 | 26 | .. autoclass:: UpdateContext 27 | :members: 28 | 29 | .. autofunction:: update_context 30 | .. autofunction:: current_update_context 31 | 32 | .. autofunction:: find_duplicates 33 | .. autofunction:: pure 34 | .. autofunction:: as_immutable_vars 35 | .. autofunction:: as_mutable_vars 36 | .. autofunction:: as_hijax_vars 37 | .. autofunction:: as_pytree_vars 38 | .. autofunction:: as_ref_vars 39 | .. autofunction:: as_array_vars 40 | .. autofunction:: flatten 41 | .. autofunction:: unflatten 42 | -------------------------------------------------------------------------------- /examples/seq2seq/configs/default.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Default Hyperparameter configuration.""" 16 | 17 | import ml_collections 18 | 19 | 20 | def get_config(): 21 | """Get the default hyperparameter configuration.""" 22 | config = ml_collections.ConfigDict() 23 | 24 | config.workdir = '/tmp/seq2seq' 25 | config.learning_rate = 0.003 26 | config.batch_size = 128 27 | config.hidden_size = 512 28 | config.num_train_steps = 10000 29 | config.decode_frequency = 200 30 | config.max_len_query_digit = 3 31 | 32 | return config 33 | -------------------------------------------------------------------------------- /tests/nnx/ids_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
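# Tests for flax.nnx.ids: uuid() values must be hashable and unique, and
# shallow or deep copies must not collide with the original id's hash.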
14 | 15 | import copy 16 | 17 | from absl.testing import absltest 18 | from flax.nnx import ids 19 | 20 | 21 | class TestIds(absltest.TestCase): 22 | def test_hashable(self): 23 | id1 = ids.uuid() 24 | id2 = ids.uuid() 25 | assert id1 == id1 26 | assert id1 != id2 27 | assert hash(id1) != hash(id2) 28 | id1c = copy.copy(id1) 29 | id1dc = copy.deepcopy(id1) 30 | assert hash(id1) != hash(id1c) 31 | assert hash(id1) != hash(id1dc) 32 | 33 | 34 | if __name__ == '__main__': 35 | absltest.main() 36 | -------------------------------------------------------------------------------- /docs_nnx/examples/core_examples.rst: -------------------------------------------------------------------------------- 1 | Core examples 2 | ============= 3 | 4 | Core examples are hosted on the GitHub Flax repository in the `examples `__ 5 | directory. 6 | 7 | Each example is designed to be **self-contained and easily forkable**, while 8 | reproducing relevant results in different areas of machine learning. 9 | 10 | Some of the examples below have a link "Interactive🕹" that lets you run them 11 | directly in Colab. 12 | 13 | Transformers 14 | ******************** 15 | 16 | - :octicon:`mark-github;0.9em` `Gemma `__ : 17 | A family of open-weights Large Language Model (LLM) by Google DeepMind, based on Gemini research and technology. 18 | 19 | - :octicon:`mark-github;0.9em` `LM1B `__ : 20 | Transformer encoder trained on the One Billion Word Benchmark. 21 | 22 | Toy examples 23 | ******************** 24 | 25 | `NNX toy examples `__ 26 | directory contains a few smaller, standalone toy examples for simple training scenarios. 27 | -------------------------------------------------------------------------------- /examples/imagenet/configs/v100_x8.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Hyperparameter configuration to run the example on 8 x Nvidia V100 GPUs.""" 16 | 17 | from configs import default as default_lib 18 | 19 | 20 | def get_config(): 21 | """Get the hyperparameter configuration to train on 8 x Nvidia V100 GPUs.""" 22 | # Override default configuration to avoid duplication of field definition. 23 | config = default_lib.get_config() 24 | 25 | config.batch_size = 512 26 | config.shuffle_buffer_size = 16 * 512 27 | config.cache = True 28 | 29 | return config 30 | 31 | 32 | metrics = default_lib.metrics 33 | -------------------------------------------------------------------------------- /examples/nlp_seq/README.md: -------------------------------------------------------------------------------- 1 | ## Part-of-Speech Tagging 2 | Trains a simple sequence-based part-of-speech tagger. The following sentence 3 | shows an example. 4 | 5 | ``` 6 | From|ADP the|DT AP|PROPN comes|VBZ this|DT story|NN :|: 7 | ``` 8 | 9 | ### Requirements 10 | * Universal Dependency data sets: https://universaldependencies.org/#download. 
11 | 12 | Download via command line: 13 | 14 | ``` 15 | curl -# -o ud-treebanks-v2.0.tgz https://lindat.mff.cuni.cz/repository/xmlui/bitstream/handle/11234/1-1976/ud-treebanks-v2.0.tgz 16 | tar xzf ud-treebanks-v2.0.tgz 17 | ``` 18 | 19 | ### Supported setups 20 | The model should run with other configurations and hardware, but it was explicitly tested on the following setup. 21 | 22 | | Hardware | Batch size | Learning rate | Training time | Accuracy | TensorBoard.dev | 23 | |:---:|:---:|:---:|:---:|:---:|:---:| 24 | | Nvidia Titan V (12GB) | 64 | 0.05 | 5:58h | 68.6% | [2022-05-01](https://tensorboard.dev/experiment/F5ULHlyzQlieVJn5PG8mRQ/) | 25 | 26 | ### Running 27 | ``` 28 | python main.py --batch_size=64 --model_dir=./ancient_greek \ 29 | --dev=ud-treebanks-v2.0/UD_Ancient_Greek/grc-ud-dev.conllu \ 30 | --train=ud-treebanks-v2.0/UD_Ancient_Greek/grc-ud-train.conllu 31 | ``` 32 | -------------------------------------------------------------------------------- /examples/imagenet/configs/tpu.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Hyperparameter configuration to run the example on TPUs.""" 15 | 16 | from configs import default as default_lib 17 | 18 | 19 | def get_config(): 20 | """Get the hyperparameter configuration to train on TPUs.""" 21 | config = default_lib.get_config() 22 | 23 | # Consider setting the batch size to max(tpu_chips * 256, 8 * 1024) if you 24 | # train on a larger pod slice. 25 | config.batch_size = 1024 26 | config.shuffle_buffer_size = 16 * 1024 27 | config.cache = True 28 | config.half_precision = True 29 | 30 | return config 31 | 32 | 33 | metrics = default_lib.metrics 34 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | # What does this PR do? 2 | 3 | 18 | 19 | Fixes # (issue) 20 | 21 | ## Checklist 22 | - [ ] This PR fixes a minor issue (e.g.: typo or small bug) or improves the docs (you can dismiss the other checks if that's the case). 23 | - [ ] This change is discussed in a GitHub issue/[discussion](https://github.com/google/flax/discussions) (please add a link). 24 | - [ ] The documentation and docstrings adhere to the [documentation guidelines](https://github.com/google/flax/blob/main/docs/README.md#how-to-write-code-documentation). 25 | - [ ] This change includes necessary high-coverage tests. (No quality testing = no merge!) 26 | -------------------------------------------------------------------------------- /examples/imagenet/configs/v100_x8_mixed_precision.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Hyperparameter configuration to run the example on 8 x Nvidia V100 GPUs.""" 16 | 17 | from configs import default as default_lib 18 | 19 | 20 | def get_config(): 21 | """Get the hyperparameter configuration to train on 8 x Nvidia V100 GPUs.""" 22 | # Override default configuration to avoid duplication of field definition. 23 | config = default_lib.get_config() 24 | 25 | config.batch_size = 2048 26 | config.shuffle_buffer_size = 16 * 2048 27 | config.cache = True 28 | config.half_precision = True 29 | 30 | return config 31 | 32 | 33 | metrics = default_lib.metrics 34 | -------------------------------------------------------------------------------- /examples/imagenet/configs/fake_data_benchmark.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Hyperparameter configuration for Fake data benchmark.""" 16 | 17 | import jax 18 | 19 | from configs import default as default_lib 20 | 21 | 22 | def get_config(): 23 | """Get the hyperparameter configuration for Fake data benchmark.""" 24 | # Override default configuration to avoid duplication of field definition. 25 | config = default_lib.get_config() 26 | config.batch_size = 256 * jax.device_count() 27 | config.half_precision = True 28 | config.num_epochs = 5 29 | 30 | # Run for a single step: 31 | config.num_train_steps = 1 32 | config.steps_per_eval = 1 33 | 34 | return config 35 | -------------------------------------------------------------------------------- /examples/vae/README.md: -------------------------------------------------------------------------------- 1 | # Basic VAE Example 2 | 3 | This is an implementation of the paper [Auto-Encoding Variational Bayes](http://arxiv.org/abs/1312.6114) by D. P. Kingma and M. Welling. 4 | This code follows [pytorch/examples/vae](https://github.com/pytorch/examples/blob/master/vae/README.md). 5 | 6 | ```bash 7 | pip install -r requirements.txt 8 | python main.py --workdir=/tmp/mnist --config=configs/default.py 9 | ``` 10 | 11 | ## Overriding Hyperparameter configurations 12 | 13 | This VAE example allows specifying a hyperparameter configuration by means of 14 | the `--config` flag. The configuration flag is defined using 15 | [config_flags](https://github.com/google/ml_collections/tree/master#config-flags). 16 | `config_flags` allows overriding configuration fields. 
This can be done as 17 | follows: 18 | 19 | ```shell 20 | python main.py \ 21 | --workdir=/tmp/mnist --config=configs/default.py \ 22 | --config.learning_rate=0.01 --config.num_epochs=10 23 | ``` 24 | 25 | 26 | ## Examples 27 | 28 | If you run the code with the above command, you can get some generated images: 29 | 30 | ![generated_mnist](./sample.png) 31 | 32 | and reconstructions of test set digits: 33 | 34 | ![reconstruction_mnist](./reconstruction.png) 35 | 36 | The test set loss after 10 epochs should be around `104`. 37 | -------------------------------------------------------------------------------- /flax/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Flax API.""" 16 | 17 | # pylint: disable=g-import-not-at-top 18 | # pyformat: disable 19 | 20 | from flax import configurations 21 | config: configurations.Config = configurations.config 22 | del configurations 23 | 24 | from flax import core 25 | from flax import jax_utils 26 | from flax import linen 27 | from flax import serialization 28 | from flax import traverse_util 29 | 30 | from flax import version 31 | __version__: str = version.__version__ 32 | del version 33 | 34 | # DO NOT REMOVE - Marker for internal deprecated API. 35 | 36 | # DO NOT REMOVE - Marker for internal logging. 37 | 38 | # pyformat: enable 39 | # pylint: enable=g-import-not-at-top 40 | -------------------------------------------------------------------------------- /flax/core/tracers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Functionality for inspecting jax tracers.""" 16 | 17 | import jax 18 | import jax.core 19 | 20 | 21 | def current_trace(): 22 | """Returns the current JAX state tracer.""" 23 | if jax.__version_info__ <= (0, 4, 33): 24 | top = jax.core.find_top_trace(()) 25 | if top: 26 | return top.level 27 | else: 28 | return float('-inf') 29 | 30 | return jax.core.get_opaque_trace_state(convention="flax") 31 | 32 | def check_trace_level(base_level): 33 | # TODO(cgarciae): skipping for now as it breaks 34 | # too many internal tests. 
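# The commented-out lines below sketch the intended check: compare the ambient
# trace against `base_level` and raise `errors.JaxTransformError` on a mismatch.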
35 | # level = current_trace() 36 | # if level != base_level: 37 | # raise errors.JaxTransformError() 38 | pass 39 | -------------------------------------------------------------------------------- /examples/linen_design_test/dense.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jax import lax 16 | from flax.linen import initializers 17 | from collections.abc import Callable 18 | from flax.linen import Module, compact 19 | 20 | 21 | class Dense(Module): 22 | features: int 23 | kernel_init: Callable = initializers.lecun_normal() 24 | bias_init: Callable = initializers.zeros_init() 25 | use_bias: bool = True 26 | 27 | @compact 28 | def __call__(self, inputs): 29 | kernel = self.param( 30 | 'kernel', self.kernel_init, (inputs.shape[-1], self.features) 31 | ) 32 | y = lax.dot_general( 33 | inputs, 34 | kernel, 35 | (((inputs.ndim - 1,), (0,)), ((), ())), 36 | ) 37 | if self.use_bias: 38 | bias = self.param('bias', self.bias_init, (self.features,)) 39 | y = y + bias 40 | return y 41 | -------------------------------------------------------------------------------- /examples/sst2/configs/default.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Default hyperparameter configuration for SST-2.""" 16 | 17 | import ml_collections 18 | 19 | 20 | def get_config(): 21 | """Get the default hyperparameter configuration.""" 22 | config = ml_collections.ConfigDict() 23 | 24 | config.embedding_size = 300 25 | config.hidden_size = 256 26 | config.vocab_size = None 27 | config.output_size = 1 28 | 29 | config.vocab_path = 'vocab.txt' 30 | config.max_input_length = 60 31 | 32 | config.dropout_rate = 0.5 33 | config.word_dropout_rate = 0.1 34 | config.unk_idx = 1 35 | 36 | config.learning_rate = 0.1 37 | config.momentum = 0.9 38 | config.weight_decay = 3e-6 39 | 40 | config.batch_size = 64 41 | config.bucket_size = 8 42 | config.num_epochs = 10 43 | 44 | config.seed = 0 45 | 46 | return config 47 | -------------------------------------------------------------------------------- /tests/pickle_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for flax.errors.""" 16 | 17 | from absl.testing import absltest 18 | from flax.errors import FlaxError, ScopeVariableNotFoundError 19 | import pickle 20 | 21 | class ErrorsTest(absltest.TestCase): 22 | def test_exception_can_be_pickled(self): 23 | # Tests that the new __reduce__ method fixes the bug reported in issue #4000. 24 | ex = ScopeVariableNotFoundError('varname', 'collection', 'scope') 25 | pickled_ex = pickle.dumps(ex) 26 | unpickled_ex = pickle.loads(pickled_ex) 27 | self.assertIsInstance(unpickled_ex, FlaxError) 28 | self.assertIn('varname', str(unpickled_ex)) 29 | self.assertIn('#flax.errors.ScopeVariableNotFoundError', str(unpickled_ex)) 30 | self.assertNotIn('#flax.errors.FlaxError', str(unpickled_ex)) 31 | 32 | 33 | if __name__ == '__main__': 34 | absltest.main() 35 | -------------------------------------------------------------------------------- /examples/nlp_seq/configs/default.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Default hyperparameters for NLP sequence tagging.""" 16 | 17 | import ml_collections 18 | 19 | 20 | def get_config(): 21 | """Get the default hyperparameter configuration.""" 22 | config = ml_collections.ConfigDict() 23 | 24 | # Model directory for checkpoints and logs 25 | config.model_dir = '' 26 | 27 | # Experiment name 28 | config.experiment = 'xpos' 29 | 30 | # Training hyperparameters 31 | config.batch_size = 64 32 | config.num_train_steps = 75000 33 | config.eval_frequency = 100 34 | 35 | # Optimizer hyperparameters 36 | config.learning_rate = 0.05 37 | config.weight_decay = 1e-1 38 | 39 | # Model hyperparameters 40 | config.max_length = 256 41 | 42 | # Random seed 43 | config.random_seed = 0 44 | 45 | # Data paths 46 | config.train = '' 47 | config.dev = '' 48 | 49 | return config 50 | -------------------------------------------------------------------------------- /docs/api_reference/flax.training.rst: -------------------------------------------------------------------------------- 1 | 2 | flax.training package 3 | ===================== 4 | 5 | Checkpoints 6 | ------------------------ 7 | 8 | .. currentmodule:: flax.training.checkpoints 9 | 10 | .. automodule:: flax.training.checkpoints 11 | 12 | .. autofunction:: save_checkpoint 13 | 14 | .. 
autofunction:: save_checkpoint_multiprocess 15 | 16 | .. autofunction:: latest_checkpoint 17 | 18 | .. autofunction:: restore_checkpoint 19 | 20 | .. autofunction:: convert_pre_linen 21 | 22 | Learning rate schedules 23 | ------------------------ 24 | 25 | .. currentmodule:: flax.training.lr_schedule 26 | 27 | .. automodule:: flax.training.lr_schedule 28 | 29 | .. autofunction:: create_constant_learning_rate_schedule 30 | 31 | .. autofunction:: create_stepped_learning_rate_schedule 32 | 33 | .. autofunction:: create_cosine_learning_rate_schedule 34 | 35 | Train state 36 | ------------------------ 37 | 38 | .. currentmodule:: flax.training.train_state 39 | 40 | .. autoclass:: TrainState 41 | :members: apply_gradients, create 42 | 43 | Early Stopping 44 | ------------------------ 45 | 46 | .. currentmodule:: flax.training.early_stopping 47 | 48 | .. autoclass:: EarlyStopping 49 | :members: reset, update 50 | 51 | Common Utilities 52 | ------------------------ 53 | 54 | .. currentmodule:: flax.training.common_utils 55 | 56 | .. autofunction:: shard 57 | 58 | .. autofunction:: shard_prng_key 59 | 60 | .. autofunction:: stack_forest 61 | 62 | .. autofunction:: get_metrics 63 | 64 | .. autofunction:: onehot 65 | -------------------------------------------------------------------------------- /flax/nnx/bridge/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from .wrappers import functional as functional 17 | from .wrappers import Functional as Functional 18 | from .wrappers import ToNNX as ToNNX 19 | from .wrappers import lazy_init as lazy_init 20 | from .wrappers import ToLinen as ToLinen 21 | from .wrappers import to_linen as to_linen 22 | from .variables import NNXMeta as NNXMeta 23 | from .variables import with_partitioning as with_partitioning 24 | from .module import Module as Module 25 | from .module import Scope as Scope 26 | from .module import AttrPriority as AttrPriority 27 | from .module import compact as compact 28 | from .module import current_context as current_context 29 | from .module import current_module as current_module 30 | from .interop import nnx_in_bridge_mdl as nnx_in_bridge_mdl 31 | from .interop import linen_in_bridge_mdl as linen_in_bridge_mdl 32 | from flax.nnx.nn import initializers as initializers 33 | -------------------------------------------------------------------------------- /docs/flip/README.md: -------------------------------------------------------------------------------- 1 | # FLIP: Flax Improvement Process 2 | 3 | Most changes can be discussed with simple issues/discussions and pull requests. 4 | 5 | Some changes though are a bit larger in scope or require more discussion, and 6 | these should be implemented as FLIPs. This allows for writing longer documents 7 | that can be discussed in a pull request themselves. 
8 | 9 | The structure of FLIPs is kept as lightweight as possible to start and might 10 | be extended later on. 11 | 12 | ## When you should use a FLIP 13 | 14 | - When your change requires a design doc. We prefer collecting the designs as 15 | FLIPs for better discoverability and further reference. 16 | 17 | - When your change requires extensive discussion. It's fine to have relatively 18 | short discussions on issues or pull requests, but when the discussion gets 19 | longer this becomes impractical for later digestion. FLIPs allow updating the 20 | main document with a summary of the discussion, and these updates can be 21 | discussed themselves in the pull request adding the FLIP. 22 | 23 | ## How to start a FLIP 24 | 25 | First, create an issue with the [FLIP label]. All pull requests that relate to 26 | the FLIP (i.e. adding the FLIP itself as well as any implementing pull requests) 27 | should be linked to this issue. 28 | 29 | Then create a pull request that consists of a copy of the `0000-template.md` 30 | renamed to `%04d-{short-title}.md` - with the number being the issue number. 31 | 32 | [FLIP label]: https://github.com/google/flax/issues?q=label%3AFLIP 33 | -------------------------------------------------------------------------------- /docs_nnx/flip/README.md: -------------------------------------------------------------------------------- 1 | # FLIP: Flax Improvement Process 2 | 3 | Most changes can be discussed with simple issues/discussions and pull requests. 4 | 5 | Some changes though are a bit larger in scope or require more discussion, and 6 | these should be implemented as FLIPs. This allows for writing longer documents 7 | that can be discussed in a pull request themselves. 8 | 9 | The structure of FLIPs is kept as lightweight as possible to start and might 10 | be extended later on. 11 | 12 | ## When you should use a FLIP 13 | 14 | - When your change requires a design doc. We prefer collecting the designs as 15 | FLIPs for better discoverability and further reference. 16 | 17 | - When your change requires extensive discussion. It's fine to have relatively 18 | short discussions on issues or pull requests, but when the discussion gets 19 | longer this becomes impractical for later digestion. FLIPs allow updating the 20 | main document with a summary of the discussion, and these updates can be 21 | discussed themselves in the pull request adding the FLIP. 22 | 23 | ## How to start a FLIP 24 | 25 | First, create an issue with the [FLIP label]. All pull requests that relate to 26 | the FLIP (i.e. adding the FLIP itself as well as any implementing pull requests) 27 | should be linked to this issue. 28 | 29 | Then create a pull request that consists of a copy of the `0000-template.md` 30 | renamed to `%04d-{short-title}.md` - with the number being the issue number. 31 | 32 | [FLIP label]: https://github.com/google/flax/issues?q=label%3AFLIP 33 | -------------------------------------------------------------------------------- /tests/nnx/filters_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from absl.testing import absltest 16 | 17 | from flax import nnx 18 | 19 | 20 | class TestFilters(absltest.TestCase): 21 | def test_path_contains(self): 22 | class Model(nnx.Module): 23 | def __init__(self, rngs): 24 | self.backbone1 = nnx.Linear(2, 3, rngs=rngs) 25 | self.backbone2 = nnx.Linear(3, 3, rngs=rngs) 26 | self.head = nnx.Linear(3, 10, rngs=rngs) 27 | 28 | model = Model(nnx.Rngs(0)) 29 | 30 | head_state = nnx.state(model, nnx.PathContains('head')) 31 | backbones_state = nnx.state(model, nnx.Any(nnx.PathContains('backbone1'), nnx.PathContains('backbone2'))) 32 | 33 | self.assertIn('head', head_state) 34 | self.assertNotIn('backbone1', head_state) 35 | self.assertIn('backbone1', backbones_state) 36 | self.assertIn('backbone2', backbones_state) 37 | self.assertNotIn('head', backbones_state) 38 | 39 | if __name__ == '__main__': 40 | absltest.main() 41 | -------------------------------------------------------------------------------- /examples/linen_design_test/linear_regression.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import jax 16 | from jax import numpy as jnp, jit 17 | from dense import Dense 18 | 19 | 20 | X = jnp.ones((1, 10)) 21 | Y = jnp.ones((5,)) 22 | 23 | model = Dense(features=5) 24 | 25 | 26 | @jit 27 | def predict(params): 28 | return model.apply({"params": params}, X) 29 | 30 | 31 | @jit 32 | def loss_fn(params): 33 | return jnp.mean(jnp.abs(Y - predict(params))) 34 | 35 | 36 | @jit 37 | def init_params(rng): 38 | mlp_variables = model.init({"params": rng}, X) 39 | return mlp_variables["params"] 40 | 41 | 42 | # Get initial parameters 43 | params = init_params(jax.random.key(42)) 44 | print("initial params", params) 45 | 46 | # Run SGD. 47 | for i in range(50): 48 | loss, grad = jax.value_and_grad(loss_fn)(params) 49 | print(i, "loss = ", loss, "Yhat = ", predict(params)) 50 | lr = 0.03 51 | params = jax.tree_util.tree_map(lambda x, d: x - lr * d, params, grad) 52 | -------------------------------------------------------------------------------- /docs_nnx/hijax/index.rst: -------------------------------------------------------------------------------- 1 | Hijax (experimental) 2 | ==================== 3 | 4 | 5 | 6 | ---- 7 | 8 | Basic usage 9 | ^^^^^^^^^^^^ 10 | 11 | .. testsetup:: 12 | 13 | import jax 14 | import jax.numpy as jnp 15 | 16 | current_mode = nnx.using_hijax() 17 | 18 | .. 
testcode:: 19 | 20 | from flax import nnx 21 | import optax 22 | 23 | nnx.use_hijax(True) 24 | 25 | class Model(nnx.Module): 26 | def __init__(self, din, dmid, dout, rngs: nnx.Rngs): 27 | self.linear = nnx.Linear(din, dmid, rngs=rngs) 28 | self.bn = nnx.BatchNorm(dmid, rngs=rngs) 29 | self.dropout = nnx.Dropout(0.2) 30 | self.linear_out = nnx.Linear(dmid, dout, rngs=rngs) 31 | 32 | def __call__(self, x, rngs): 33 | x = nnx.relu(self.dropout(self.bn(self.linear(x)), rngs=rngs)) 34 | return self.linear_out(x) 35 | 36 | model = Model(2, 64, 3, rngs=nnx.Rngs(0)) # eager initialization 37 | optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param) 38 | 39 | @jax.jit 40 | def train_step(model, optimizer, rngs, x, y): 41 | graphdef, params, nondiff = nnx.split(model, nnx.Param, ...) 42 | def loss_fn(params): 43 | model = nnx.merge(graphdef, params, nondiff) 44 | return ((model(x, rngs) - y) ** 2).mean() 45 | loss, grads = jax.value_and_grad(loss_fn)(nnx.as_immutable_vars(params)) 46 | optimizer.update(model, grads) # in-place updates 47 | return loss 48 | 49 | nnx.use_hijax(current_mode) # clean up for CI tests 50 | 51 | 52 | ---- 53 | 54 | .. toctree:: 55 | :hidden: 56 | :maxdepth: 2 57 | 58 | hijax 59 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # Install the pre-commit hooks below with 2 | # 'pre-commit install' 3 | 4 | # Auto-update the version of the hooks with 5 | # 'pre-commit autoupdate' 6 | 7 | # Run the hooks on all files with 8 | # 'pre-commit run --all' 9 | 10 | repos: 11 | - repo: https://github.com/mwouts/jupytext 12 | rev: v1.13.8 13 | hooks: 14 | - id: jupytext 15 | args: [--sync] 16 | # disable pyink for now 17 | # - repo: https://github.com/google/pyink 18 | # rev: 23.5.0 19 | # hooks: 20 | # - id: pyink 21 | - repo: https://github.com/pre-commit/pre-commit-hooks 22 | rev: v5.0.0 23 | hooks: 24 | - id: check-toml 25 | - id: trailing-whitespace 26 | exclude: ^docs.*\.md$ 27 | - repo: https://github.com/kynan/nbstripout 28 | rev: 0.6.1 29 | hooks: 30 | - id: nbstripout 31 | exclude: ^examples/.* 32 | args: [ 33 | --keep-output, 34 | --keep-count, 35 | --extra-keys, 36 | "cell.metadata.executionInfo cell.metadata.id metadata.kernelspec metadata.vscode metadata.colab cell.metadata.executionInfo.user cell.metadata.executionInfo.user_tz cell.metadata.colab", 37 | ] 38 | - repo: https://github.com/astral-sh/ruff-pre-commit 39 | # Ruff version. 40 | rev: v0.1.3 41 | hooks: 42 | # Run the Ruff linter. 43 | - id: ruff 44 | args: [--fix, --exit-non-zero-on-fix] 45 | # Disable Ruff formatter for now 46 | # # Run the Ruff formatter. 47 | # - id: ruff-format 48 | - repo: https://github.com/asottile/pyupgrade 49 | rev: v3.16.0 50 | hooks: 51 | - id: pyupgrade 52 | args: [--py310-plus] 53 | -------------------------------------------------------------------------------- /examples/ogbg_molpcba/configs/test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Defines a CPU-friendly test configuration.""" 16 | 17 | import ml_collections 18 | 19 | 20 | def get_config(): 21 | """Get the default hyperparameter configuration.""" 22 | config = ml_collections.ConfigDict() 23 | 24 | # Optimizer. 25 | config.optimizer = 'adam' 26 | config.learning_rate = 1e-3 27 | 28 | # Training hyperparameters. 29 | config.batch_size = 32 30 | config.num_train_steps = 10 31 | config.log_every_steps = 5 32 | config.eval_every_steps = 5 33 | config.checkpoint_every_steps = 5 34 | config.add_virtual_node = True 35 | config.add_undirected_edges = True 36 | config.add_self_loops = True 37 | 38 | # GNN hyperparameters. 39 | config.model = 'GraphConvNet' 40 | config.message_passing_steps = 5 41 | config.latent_size = 256 42 | config.dropout_rate = 0.1 43 | config.num_mlp_layers = 1 44 | config.num_classes = 128 45 | config.skip_connections = False 46 | config.layer_norm = False 47 | 48 | return config 49 | -------------------------------------------------------------------------------- /examples/vae/input_pipeline.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Input pipeline for VAE dataset.""" 16 | 17 | import jax 18 | import jax.numpy as jnp 19 | import tensorflow as tf 20 | import tensorflow_datasets as tfds 21 | 22 | 23 | def build_train_set(batch_size, ds_builder): 24 | """Builds train dataset.""" 25 | 26 | train_ds = ds_builder.as_dataset(split=tfds.Split.TRAIN) 27 | train_ds = train_ds.map(prepare_image) 28 | train_ds = train_ds.cache() 29 | train_ds = train_ds.repeat() 30 | train_ds = train_ds.shuffle(50000) 31 | train_ds = train_ds.batch(batch_size) 32 | train_ds = iter(tfds.as_numpy(train_ds)) 33 | return train_ds 34 | 35 | 36 | def build_test_set(ds_builder): 37 | """Builds test dataset.""" 38 | test_ds = ds_builder.as_dataset(split=tfds.Split.TEST) 39 | test_ds = test_ds.map(prepare_image).batch(10000) 40 | test_ds = jnp.array(list(test_ds)[0]) 41 | test_ds = jax.device_put(test_ds) 42 | return test_ds 43 | 44 | 45 | def prepare_image(x): 46 | x = tf.cast(x['image'], tf.float32) 47 | x = tf.reshape(x, (-1,)) 48 | return x 49 | -------------------------------------------------------------------------------- /flax/core/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .axes_scan import broadcast as broadcast 16 | from .frozen_dict import ( 17 | FrozenDict as FrozenDict, 18 | copy as copy, 19 | freeze as freeze, 20 | pop as pop, 21 | pretty_repr as pretty_repr, 22 | unfreeze as unfreeze, 23 | ) 24 | from .lift import ( 25 | custom_vjp as custom_vjp, 26 | jit as jit, 27 | jvp as jvp, 28 | remat_scan as remat_scan, 29 | remat as remat, 30 | scan as scan, 31 | vjp as vjp, 32 | vmap as vmap, 33 | while_loop as while_loop, 34 | ) 35 | from .meta import ( 36 | AxisMetadata as AxisMetadata, 37 | map_axis_meta as map_axis_meta, 38 | unbox as unbox, 39 | ) 40 | from .scope import ( 41 | DenyList as DenyList, 42 | Scope as Scope, 43 | apply as apply, 44 | bind as bind, 45 | init as init, 46 | lazy_init as lazy_init, 47 | ) 48 | from .tracers import ( 49 | check_trace_level as check_trace_level, 50 | current_trace as current_trace, 51 | ) 52 | 53 | from flax.typing import ( 54 | Array as Array, 55 | ) 56 | -------------------------------------------------------------------------------- /examples/lm1b/temperature_sampler_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
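# The toy tokens_to_logits defined in the test below assigns all probability mass to (previous token + 1), so the sampled continuation of [5, 0, 0, 0] is deterministically [5, 6, 7, 8] regardless of temperature or top-k.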
14 | 15 | from absl.testing import absltest 16 | import jax 17 | import jax.numpy as jnp 18 | import numpy as np 19 | 20 | from temperature_sampler import temperature_sample 21 | 22 | 23 | jax.config.update('jax_disable_most_optimizations', True) 24 | 25 | 26 | class TestTemperatureSampler(absltest.TestCase): 27 | 28 | def test_temperature_sampler(self): 29 | tokens = jnp.array([[5, 0, 0, 0]], dtype=jnp.int32) 30 | cache = None 31 | key = jax.random.PRNGKey(0) 32 | 33 | def tokens_to_logits(tokens, cache): 34 | jax.debug.print('tokens: {}', tokens) 35 | logits = jax.nn.one_hot(tokens[..., -1:] + 1, 10) 36 | logits = jnp.where(logits < 0.5, float('-inf'), logits) 37 | logits = logits.squeeze(axis=1) 38 | return logits, cache 39 | 40 | new_tokens = temperature_sample( 41 | tokens, cache, tokens_to_logits, key, topk=5 42 | ) 43 | 44 | np.testing.assert_array_equal(new_tokens, [[5, 6, 7, 8]]) 45 | 46 | 47 | if __name__ == '__main__': 48 | absltest.main() 49 | -------------------------------------------------------------------------------- /flax/core/nn/stochastic.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Stochastic modules.""" 16 | 17 | import jax.numpy as jnp 18 | from jax import lax, random 19 | 20 | 21 | def dropout(scope, inputs, rate, deterministic=False, rng=None): 22 | """Applies a random dropout mask to the input. 23 | Args: 24 | inputs: the inputs that should be randomly masked. 25 | rate: the probability of masking out a value. 26 | deterministic: if false the inputs are scaled by `1 / (1 - rate)` and 27 | masked, whereas if true, no mask is applied and the inputs are returned as 28 | is. 29 | rng: an optional `jax.random.PRNGKey`. By default `scope.make_rng('dropout')` 30 | will be used. 31 | Returns: 32 | The masked inputs. 33 | """ 34 | if rate == 0.0: 35 | return inputs 36 | keep_prob = 1.0 - rate 37 | 38 | if deterministic: 39 | return inputs 40 | else: 41 | if rng is None: 42 | rng = scope.make_rng('dropout') 43 | mask = random.bernoulli(rng, p=keep_prob, shape=inputs.shape) 44 | return lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs)) 45 | -------------------------------------------------------------------------------- /flax/core/variables.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """A variable dict is a normal Python dictionary, which is a container for one 16 | or more "variable collections", each of which is a nested dictionary whose 17 | leaves are ``jax.numpy`` arrays. 18 | 19 | The different variable collections share the same nested tree structure. 20 | 21 | For example, consider the following variable dictionary:: 22 | 23 | { 24 | "params": { 25 | "Conv1": { "weight": ..., "bias": ... }, 26 | "BatchNorm1": { "scale": ..., "mean": ... }, 27 | "Conv2": {...} 28 | }, 29 | "batch_stats": { 30 | "BatchNorm1": { "moving_mean": ..., "moving_average": ...} 31 | } 32 | } 33 | 34 | In this case, the ``"BatchNorm1"`` key lives in both the ``"params"`` and 35 | ``"batch_stats"`` collections. This reflects the fact that the submodule 36 | named ``"BatchNorm1"`` has both trainable parameters (the ``"params"`` collection), 37 | as well as other non-trainable variables (the ``"batch_stats"`` collection). 38 | 39 | TODO: Make "variable dict" design note, and link to it from here. 40 | """ 41 | 42 | from .scope import Variable 43 | -------------------------------------------------------------------------------- /examples/seq2seq/main.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Main script for seq2seq example.""" 16 | 17 | from absl import app 18 | from absl import flags 19 | from absl import logging 20 | import train 21 | from ml_collections import config_flags 22 | 23 | FLAGS = flags.FLAGS 24 | 25 | config_flags.DEFINE_config_file( 26 | 'config', 27 | None, 28 | 'File path to the training hyperparameter configuration.', 29 | lock_config=True, 30 | ) 31 | 32 | 33 | def main(argv): 34 | del argv 35 | 36 | config = FLAGS.config 37 | 38 | # Set train.FLAGS values from config 39 | train.FLAGS.workdir = config.workdir 40 | train.FLAGS.learning_rate = config.learning_rate 41 | train.FLAGS.batch_size = config.batch_size 42 | train.FLAGS.hidden_size = config.hidden_size 43 | train.FLAGS.num_train_steps = config.num_train_steps 44 | train.FLAGS.decode_frequency = config.decode_frequency 45 | train.FLAGS.max_len_query_digit = config.max_len_query_digit 46 | 47 | logging.info('Starting training with config: %s', config) 48 | _ = train.train_and_evaluate(train.FLAGS.workdir) 49 | 50 | 51 | if __name__ == '__main__': 52 | app.run(main) 53 | -------------------------------------------------------------------------------- /examples/ogbg_molpcba/configs/default.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Defines the default hyperparameters and training configuration. 16 | 17 | Uses a Graph Convolutional Network model (https://arxiv.org/abs/1609.02907). 18 | """ 19 | 20 | import ml_collections 21 | 22 | 23 | def get_config(): 24 | """Get the default hyperparameter configuration.""" 25 | config = ml_collections.ConfigDict() 26 | 27 | # Optimizer. 28 | config.optimizer = 'adam' 29 | config.learning_rate = 1e-3 30 | 31 | # Training hyperparameters. 32 | config.batch_size = 256 33 | config.num_train_steps = 100_000 34 | config.log_every_steps = 100 35 | config.eval_every_steps = 1_000 36 | config.checkpoint_every_steps = 10_000 37 | config.add_virtual_node = False 38 | config.add_undirected_edges = True 39 | config.add_self_loops = True 40 | 41 | # GNN hyperparameters. 42 | config.model = 'GraphConvNet' 43 | config.message_passing_steps = 5 44 | config.latent_size = 256 45 | config.dropout_rate = 0.1 46 | config.num_mlp_layers = 2 47 | config.num_classes = 128 48 | config.skip_connections = True 49 | config.layer_norm = True 50 | return config 51 | -------------------------------------------------------------------------------- /tests/linen/linen_activation_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for flax.linen.activation.""" 16 | 17 | from absl.testing import absltest 18 | from flax import linen as nn 19 | import jax 20 | from jax import random 21 | import jax.numpy as jnp 22 | import numpy as np 23 | 24 | 25 | # Parse absl flags test_srcdir and test_tmpdir. 
26 | jax.config.parse_flags_with_absl() 27 | 28 | 29 | class ActivationTest(absltest.TestCase): 30 | 31 | def test_prelu(self): 32 | rng = random.key(0) 33 | key, skey_1, skey_2 = jax.random.split(rng, 3) 34 | x = jax.random.uniform(skey_1, (4, 6, 5)) - 0.5 35 | act = nn.PReLU() 36 | y, params = act.init_with_output(skey_2, x) 37 | expected_y = jnp.where(x < 0, x * act.negative_slope_init, x) 38 | init_negative_slope = params['params']['negative_slope'] 39 | expected_negative_slope = jnp.array( 40 | act.negative_slope_init, dtype=jnp.float32 41 | ) 42 | 43 | self.assertEqual(y.shape, x.shape) 44 | np.testing.assert_array_almost_equal(expected_y, y) 45 | np.testing.assert_array_equal(init_negative_slope, expected_negative_slope) 46 | 47 | 48 | if __name__ == '__main__': 49 | absltest.main() 50 | -------------------------------------------------------------------------------- /examples/ogbg_molpcba/configs/default_graph_net.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Defines the default hyperparameters and training configuration. 16 | 17 | Uses a GraphNetwork model (https://arxiv.org/abs/1806.01261). 18 | """ 19 | 20 | import ml_collections 21 | 22 | 23 | def get_config(): 24 | """Get the hyperparameter configuration for the GraphNetwork model.""" 25 | config = ml_collections.ConfigDict() 26 | 27 | # Optimizer. 28 | config.optimizer = 'adam' 29 | config.learning_rate = 1e-3 30 | 31 | # Training hyperparameters. 32 | config.batch_size = 256 33 | config.num_train_steps = 100_000 34 | config.log_every_steps = 100 35 | config.eval_every_steps = 10_000 36 | config.checkpoint_every_steps = 10_000 37 | config.add_virtual_node = True 38 | config.add_undirected_edges = True 39 | config.add_self_loops = True 40 | 41 | # GNN hyperparameters. 42 | config.model = 'GraphNet' 43 | config.message_passing_steps = 5 44 | config.latent_size = 256 45 | config.dropout_rate = 0.1 46 | config.num_mlp_layers = 1 47 | config.num_classes = 128 48 | config.use_edge_model = True 49 | config.skip_connections = True 50 | config.layer_norm = True 51 | return config 52 | -------------------------------------------------------------------------------- /tests/linen/linen_dtypes_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for flax.linen.dtypes.""" 16 | 17 | 18 | from absl.testing import absltest 19 | from jax import numpy as jnp 20 | 21 | from flax.linen import dtypes 22 | 23 | try: 24 | # JAX v0.8.0 and newer 25 | from jax import enable_x64 26 | except ImportError: 27 | from jax.experimental import enable_x64 28 | 29 | default_float_dtype = jnp.result_type(1.0) 30 | 31 | 32 | class DtypesTest(absltest.TestCase): 33 | def test_no_inexact_dtype(self): 34 | i32 = jnp.int32(1.0) 35 | self.assertEqual(dtypes.canonicalize_dtype(i32, inexact=False), jnp.int32) 36 | 37 | def test_inexact_dtype(self): 38 | with enable_x64(): 39 | i64 = jnp.int64(1) 40 | self.assertEqual(dtypes.canonicalize_dtype(i64), jnp.float32) 41 | i32 = jnp.int32(1) 42 | self.assertEqual(dtypes.canonicalize_dtype(i32), jnp.float32) 43 | i16 = jnp.int16(1.0) 44 | self.assertEqual(dtypes.canonicalize_dtype(i16), jnp.float32) 45 | 46 | def test_explicit_downcast(self): 47 | f32 = jnp.float32(1.0) 48 | (x,) = dtypes.promote_dtype(f32, dtype=jnp.float16) 49 | self.assertEqual(x.dtype, jnp.float16) 50 | 51 | 52 | if __name__ == '__main__': 53 | absltest.main() 54 | -------------------------------------------------------------------------------- /examples/ppo/ppo_main.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # See issue #620. 16 | # pytype: disable=wrong-keyword-args 17 | 18 | from absl import app 19 | from absl import flags 20 | from ml_collections import config_flags 21 | import tensorflow as tf 22 | 23 | import env_utils 24 | import models 25 | import ppo_lib 26 | import gymnasium as gym 27 | import ale_py 28 | 29 | gym.register_envs(ale_py) 30 | 31 | 32 | FLAGS = flags.FLAGS 33 | 34 | flags.DEFINE_string( 35 | 'workdir', 36 | default='/tmp/ppo_training', 37 | help='Directory to save checkpoints and logging info.', 38 | ) 39 | 40 | config_flags.DEFINE_config_file( 41 | 'config', 42 | 'configs/default.py', 43 | 'File path to the default configuration file.', 44 | lock_config=True, 45 | ) 46 | 47 | 48 | def main(argv): 49 | # Make sure tf does not allocate gpu memory. 50 | tf.config.experimental.set_visible_devices([], 'GPU') 51 | config = FLAGS.config 52 | game = config.game + 'NoFrameskip-v4' 53 | num_actions = env_utils.get_num_actions(game) 54 | print(f'Playing {game} with {num_actions} actions') 55 | model = models.ActorCritic(num_outputs=num_actions) 56 | ppo_lib.train(model, config, FLAGS.workdir) 57 | 58 | 59 | if __name__ == '__main__': 60 | app.run(main) 61 | -------------------------------------------------------------------------------- /benchmarks/tracing/mnist.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The JAX Authors. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """MNIST helper functions.""" 15 | 16 | from typing import Any 17 | 18 | from flax.examples.mnist import train 19 | import jax 20 | import jax.numpy as jnp 21 | import ml_collections 22 | 23 | 24 | def get_fake_batch(batch_size: int) -> tuple[Any, Any]: 25 | """Returns fake data for the given batch size. 26 | 27 | Args: 28 | batch_size: The global batch size to generate. 29 | 30 | Returns: 31 | A tuple of (images, labels) with fake data. 32 | """ 33 | rng = jax.random.PRNGKey(0) 34 | images = jax.random.normal(rng, (batch_size, 28, 28, 1), jnp.float32) 35 | labels = jax.random.randint(rng, (batch_size,), 0, 10, jnp.int32) 36 | return images, labels 37 | 38 | 39 | def get_apply_fn_and_args( 40 | config: ml_collections.ConfigDict, 41 | ) -> tuple[Any, tuple[Any, ...], dict[str, Any]]: 42 | """Returns the apply function and args for the given config. 43 | 44 | Args: 45 | config: The training configuration. 46 | 47 | Returns: 48 | A tuple of (apply_fn, args, kwargs). 49 | """ 50 | rng = jax.random.PRNGKey(0) 51 | state = train.create_train_state(rng, config) 52 | images, labels = get_fake_batch(config.batch_size) 53 | return train.apply_model, (state, images, labels), {} 54 | -------------------------------------------------------------------------------- /.github/workflows/jax_nightly.yml: -------------------------------------------------------------------------------- 1 | name: CI - with JAX nightly 2 | 3 | concurrency: 4 | group: ${{ github.workflow }}-${{ github.ref }} 5 | cancel-in-progress: true 6 | 7 | on: 8 | schedule: 9 | - cron: "0 12 * * *" # Daily at 12:00 UTC 10 | workflow_dispatch: # allows triggering the workflow run manually 11 | pull_request: # Automatically trigger on pull requests affecting this file 12 | branches: 13 | - main 14 | paths: 15 | - '**workflows/jax_nightly.yml' 16 | 17 | jobs: 18 | jax-nightly: 19 | runs-on: ubuntu-latest 20 | permissions: 21 | contents: read 22 | issues: write # for failed-build-issue 23 | strategy: 24 | fail-fast: false 25 | matrix: 26 | python-version: ["3.11"] 27 | steps: 28 | - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 29 | - name: Set up Python ${{ matrix.python-version }} 30 | id: setup_python 31 | uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 32 | with: 33 | python-version: ${{ matrix.python-version }} 34 | - name: Setup uv 35 | uses: astral-sh/setup-uv@887a942a15af3a7626099df99e897a18d9e5ab3a # v5.1.0 36 | with: 37 | version: "0.8.13" 38 | - name: Install dependencies 39 | run: | 40 | uv sync --extra testing --extra docs 41 | - name: Install JAX 42 | run: | 43 | uv pip install -U --pre jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ 44 | - name: Run test suite 45 | if: success() 46 | run: | 47 | uv run tests/run_all_tests.sh --only-pytest 48 | - name: Notify failed build 49 | uses: 
jayqi/failed-build-issue-action@1a893bbf43ef1c2a8705e2b115cd4f0fe3c5649b # v1.2.0 50 | if: failure() && github.event.pull_request == null 51 | with: 52 | github-token: ${{ secrets.GITHUB_TOKEN }} -------------------------------------------------------------------------------- /examples/mnist/README.md: -------------------------------------------------------------------------------- 1 | ## MNIST classification 2 | 3 | Trains a simple convolutional network on the MNIST dataset. 4 | 5 | You can run this code and even modify it directly in Google Colab, no 6 | installation required: 7 | 8 | https://colab.research.google.com/github/google/flax/blob/main/examples/mnist/mnist.ipynb 9 | 10 | ### Requirements 11 | * TensorFlow dataset `mnist` will be downloaded and prepared automatically, if necessary 12 | 13 | ### Example output 14 | 15 | | Name | Epochs | Walltime | Top-1 accuracy | Metrics | Workdir | 16 | | :------ | -----: | :------- | :------------- | :---------- | :---------------------------------------- | 17 | | default | 10 | 7.7m | 99.17% | [tfhub.dev] | [gs://flax_public/examples/mnist/default] | 18 | 19 | [tfhub.dev]: https://tensorboard.dev/experiment/1G9SvrW5RQyojRtMKNmMuQ/#scalars&_smoothingWeight=0&regexInput=default 20 | [gs://flax_public/examples/mnist/default]: https://console.cloud.google.com/storage/browser/flax_public/examples/mnist/default 21 | 22 | ``` 23 | I0828 08:51:41.821526 139971964110656 train.py:130] train epoch: 10, loss: 0.0097, accuracy: 99.69 24 | I0828 08:51:42.248714 139971964110656 train.py:180] eval epoch: 10, loss: 0.0299, accuracy: 99.14 25 | ``` 26 | 27 | ### How to run 28 | 29 | `python main.py --workdir=/tmp/mnist --config=configs/default.py` 30 | 31 | #### Overriding Hyperparameter configurations 32 | 33 | The MNIST example allows specifying a hyperparameter configuration by means of 34 | setting the `--config` flag. The configuration flag is defined using 35 | [config_flags](https://github.com/google/ml_collections/tree/master#config-flags). 36 | `config_flags` allows overriding configuration fields. This can be done as 37 | follows: 38 | 39 | ```shell 40 | python main.py \ 41 | --workdir=/tmp/mnist --config=configs/default.py \ 42 | --config.learning_rate=0.05 --config.num_epochs=5 43 | ``` 44 | -------------------------------------------------------------------------------- /docs/api_reference/flax.cursor.rst: -------------------------------------------------------------------------------- 1 | 2 | flax.cursor package 3 | ============================= 4 | 5 | The Cursor API allows for mutability of pytrees. This API provides a more 6 | ergonomic solution to making partial updates of deeply nested immutable 7 | data structures, compared to making many nested ``dataclasses.replace`` calls. 8 | 9 | To illustrate, consider the example below:: 10 | 11 | >>> from flax.cursor import cursor 12 | >>> import dataclasses 13 | >>> from typing import Any 14 | 15 | >>> @dataclasses.dataclass(frozen=True) 16 | ... class A: 17 | ... x: Any 18 | 19 | >>> a = A(A(A(A(A(A(A(0))))))) 20 | 21 | To replace the int ``0`` using ``dataclasses.replace``, we would have to write many nested calls:: 22 | 23 | >>> a2 = dataclasses.replace( 24 | ... a, 25 | ... x=dataclasses.replace( 26 | ... a.x, 27 | ... x=dataclasses.replace( 28 | ... a.x.x, 29 | ... x=dataclasses.replace( 30 | ... a.x.x.x, 31 | ... x=dataclasses.replace( 32 | ... a.x.x.x.x, 33 | ... x=dataclasses.replace( 34 | ... a.x.x.x.x.x, 35 | ... x=dataclasses.replace(a.x.x.x.x.x.x, x=1), 36 | ... ), 37 | ... ), 38 | ... 
), 39 | ... ), 40 | ... ), 41 | ... ) 42 | 43 | The equivalent can be achieved much more simply using the Cursor API:: 44 | 45 | >>> a3 = cursor(a).x.x.x.x.x.x.x.set(1) 46 | >>> assert a2 == a3 47 | 48 | The Cursor object keeps track of changes made to it and, when ``.build`` is called, 49 | generates a new object with the accumulated changes. Basic usage involves 50 | wrapping the object in a Cursor, making changes to the Cursor object and 51 | generating a new copy of the original object with the accumulated changes. 52 | 53 | .. currentmodule:: flax.cursor 54 | 55 | .. autofunction:: cursor 56 | 57 | .. autoclass:: Cursor 58 | :members: apply_update, build, find, find_all, set 59 | 60 | 61 | -------------------------------------------------------------------------------- /examples/imagenet/configs/default.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Default Hyperparameter configuration.""" 15 | 16 | import ml_collections 17 | 18 | 19 | def get_config(): 20 | """Get the default hyperparameter configuration.""" 21 | config = ml_collections.ConfigDict() 22 | 23 | # As defined in the `models` module. 24 | config.model = 'ResNet50' 25 | # `name` argument of tensorflow_datasets.builder() 26 | config.dataset = 'imagenet2012:5.*.*' 27 | 28 | config.learning_rate = 0.1 29 | config.warmup_epochs = 5.0 30 | config.momentum = 0.9 31 | config.batch_size = 128 32 | config.shuffle_buffer_size = 16 * 128 33 | config.prefetch = 10 34 | 35 | config.num_epochs = 100.0 36 | config.log_every_steps = 100 37 | 38 | config.cache = False 39 | config.half_precision = False 40 | 41 | # If num_train_steps==-1 then the number of training steps is calculated from 42 | # num_epochs using the entire dataset. Similarly for steps_per_eval. 43 | config.num_train_steps = -1 44 | config.steps_per_eval = -1 45 | 46 | # whether to profile the training loop 47 | config.profile = True 48 | 49 | return config 50 | 51 | 52 | def metrics(): 53 | return [ 54 | 'train_loss', 55 | 'eval_loss', 56 | 'train_accuracy', 57 | 'eval_accuracy', 58 | 'steps_per_second', 59 | 'train_learning_rate', 60 | ] 61 | -------------------------------------------------------------------------------- /.github/analytics/issue_activity_since_date.gql: -------------------------------------------------------------------------------- 1 | { 2 | # Queries all the issues in a repo. For each issue, we get some basic data such as 3 | # the number, state, labels, and title. The most important part is the 'timelineItems', 4 | # which are the events that happened to the issue; we can use the information about 5 | # the datetime of certain key events to define some metrics. Note that we are 6 | # getting more information than is probably needed, but it's fine for now. 
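# Note: _REPO_OWNER_ and _REPO_NAME_ are placeholders; they are presumably substituted with the real owner/repo by the analytics scripts before the query is issued.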
7 | repository(owner: "_REPO_OWNER_", name: "_REPO_NAME_") { 8 | issues(first: 100) { 9 | totalCount 10 | edges { 11 | cursor 12 | node { 13 | number 14 | title 15 | createdAt 16 | state 17 | closedAt 18 | updatedAt 19 | url 20 | labels(first: 100) { 21 | edges { 22 | node { 23 | name 24 | } 25 | } 26 | } 27 | timelineItems(first: 100, itemTypes: [LABELED_EVENT, CONVERTED_TO_DISCUSSION_EVENT, ISSUE_COMMENT, CLOSED_EVENT]) { 28 | totalCount 29 | edges { 30 | node { 31 | __typename 32 | ... on ConvertedToDiscussionEvent { 33 | createdAt 34 | } 35 | ... on IssueComment { 36 | author { 37 | login 38 | } 39 | createdAt 40 | } 41 | ... on ClosedEvent { 42 | actor { 43 | login 44 | } 45 | createdAt 46 | } 47 | ... on LabeledEvent { 48 | label { 49 | name 50 | } 51 | createdAt 52 | } 53 | } 54 | } 55 | } 56 | } 57 | } 58 | } 59 | } 60 | } 61 | 62 | -------------------------------------------------------------------------------- /tests/nnx/containers_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from flax import nnx 17 | from absl.testing import absltest 18 | import jax.numpy as jnp 19 | 20 | 21 | class TestContainers(absltest.TestCase): 22 | def test_unbox(self): 23 | x = nnx.Param( 24 | jnp.array(1), 25 | on_get_value=lambda c, x: x + 3, # type: ignore 26 | ) 27 | 28 | assert x[...] == 4 29 | 30 | def test_on_set_value(self): 31 | x = nnx.Param( 32 | jnp.array(1), # type: ignore 33 | on_set_value=lambda c, x: x + 7, # type: ignore 34 | ) 35 | x[...] = 5 36 | 37 | assert x.get_raw_value() == 12 38 | 39 | def test_module_unbox(self): 40 | class Foo(nnx.Module): 41 | def __init__(self) -> None: 42 | self.x = nnx.Param(1, on_get_value=lambda c, x: x + 3) 43 | 44 | module = Foo() 45 | 46 | assert module.x.get_value() == 4 47 | assert vars(module)['x'].get_raw_value() == 1 48 | 49 | def test_module_box(self): 50 | class Foo(nnx.Module): 51 | def __init__(self) -> None: 52 | self.x = nnx.Param( 53 | jnp.array(1), 54 | on_set_value=lambda c, x: x + 7, # type: ignore 55 | ) 56 | 57 | module = Foo() 58 | module.x[...] = 5 59 | 60 | assert module.x[...] == 12 61 | assert vars(module)['x'][...] 
== 12 62 | 63 | 64 | if __name__ == '__main__': 65 | absltest.main() 66 | -------------------------------------------------------------------------------- /examples/cloud/startup_script.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Note that all __XYZ__ strings are replaced by launch_gce.py 4 | 5 | WORKDIR="/train/workdir_base/__EXAMPLE__/__NAME__/__TIMESTAMP__" 6 | 7 | mkdir -p /train 8 | cd /train 9 | 10 | # Login directly with: 11 | # gcloud compute ssh $VM -- /sudo_tmux_a.sh 12 | echo -e '#!/bin/bash\nsudo /tmux_a.sh' > /sudo_tmux_a.sh 13 | chmod a+x /sudo_tmux_a.sh 14 | echo -e '#!/bin/bash\ntmux a' > /tmux_a.sh 15 | chmod a+x /tmux_a.sh 16 | 17 | # Main script running in bottom left tmux pane. 18 | cat >/install_train_stop.sh <&1 | tee -a $WORKDIR/setup_train_log_${TIMESTAMP}.txt 45 | 46 | if [ __SHUTDOWN_SECS__ -gt 0 ]; then 47 | echo 48 | echo WILL SHUT DOWN IN $((__SHUTDOWN_SECS__/60)) MIN ... 49 | sleep __SHUTDOWN_SECS__ && shutdown now 50 | fi 51 | 52 | EOF 53 | 54 | 55 | # Set up TMUX panes: 56 | tmux new-session -s flax -d 57 | # - top left: htop 58 | tmux send 'htop 59 | ' 60 | tmux split-window 61 | tmux selectp -U 62 | tmux split-window -h 63 | # - top right: htop 64 | tmux send 'watch nvidia-smi 65 | ' 66 | tmux selectp -D 67 | # - bottom left: main script 68 | tmux send '. /install_train_stop.sh 69 | ' 70 | tmux split-window -h 71 | # - bottom right: rsync files to GCS bucket. 72 | tmux send " 73 | while true; do 74 | gsutil rsync -r workdir_base __GCS_WORKDIR_BASE__ 75 | sleep 60 76 | done 2>&1 | tee -a $WORKDIR/gcs_rsync_'__TIMESTAMP__'.txt 77 | " 78 | -------------------------------------------------------------------------------- /examples/nlp_seq/main.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Main file for running the NLP sequence tagging example. 16 | 17 | This file is intentionally kept short to allow config-based execution. 
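Hyperparameters come from the --config flag (see configs/default.py); main() simply copies them onto train.py's absl FLAGS.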
18 | """ 19 | 20 | from absl import app 21 | from absl import flags 22 | import train 23 | from ml_collections import config_flags 24 | 25 | 26 | FLAGS = flags.FLAGS 27 | 28 | config_flags.DEFINE_config_file( 29 | 'config', 30 | 'configs/default.py', 31 | 'File path to the training hyperparameter configuration.', 32 | lock_config=True, 33 | ) 34 | 35 | 36 | def main(argv): 37 | if len(argv) > 1: 38 | raise app.UsageError('Too many command-line arguments.') 39 | 40 | # Convert config to FLAGS for train.py compatibility 41 | config = FLAGS.config 42 | 43 | # Override FLAGS with config values 44 | FLAGS.model_dir = config.model_dir 45 | FLAGS.experiment = config.experiment 46 | FLAGS.batch_size = config.batch_size 47 | FLAGS.eval_frequency = config.eval_frequency 48 | FLAGS.num_train_steps = config.num_train_steps 49 | FLAGS.learning_rate = config.learning_rate 50 | FLAGS.weight_decay = config.weight_decay 51 | FLAGS.max_length = config.max_length 52 | FLAGS.random_seed = config.random_seed 53 | FLAGS.train = config.train 54 | FLAGS.dev = config.dev 55 | 56 | # Run the training 57 | train.main(argv) 58 | 59 | 60 | if __name__ == '__main__': 61 | app.run(main) 62 | -------------------------------------------------------------------------------- /flax/nnx/ids.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """UUIDs for Flax internals.""" 15 | 16 | import threading 17 | 18 | 19 | class UUIDManager: 20 | """Globally unique counter-based id manager. 21 | 22 | We need globally unique key ids for Module and Variable object instances 23 | to preserve and recreate sharing-by-reference relationship when lifting 24 | transforms and adopting outside Modules. 25 | - Use of id() is unacceptable because these identifiers are literally 26 | pointers which can be recycled, so we rely on a globally unique counter id 27 | instead. 28 | - We need to handle copy/deepcopy uniqueness via a wrapped type. 29 | """ 30 | 31 | def __init__(self): 32 | self._lock = threading.Lock() 33 | self._id = 0 34 | 35 | def __call__(self): 36 | with self._lock: 37 | self._id += 1 38 | return UUID(self._id) 39 | 40 | 41 | uuid = UUIDManager() 42 | 43 | 44 | class UUID: 45 | """Hashable wrapper for ids that handles uniqueness of copies.""" 46 | 47 | def __init__(self, rawid): 48 | self.id = rawid 49 | 50 | def __eq__(self, other): 51 | return isinstance(other, UUID) and other.id == self.id 52 | 53 | def __hash__(self): 54 | return hash(self.id) 55 | 56 | def __repr__(self): 57 | return f'UUID({self.id})' 58 | 59 | def __deepcopy__(self, memo): 60 | del memo 61 | return uuid() 62 | 63 | def __copy__(self): 64 | return uuid() 65 | -------------------------------------------------------------------------------- /flax/ids.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """UUIDs for Flax internals.""" 16 | 17 | import threading 18 | 19 | 20 | class UUIDManager: 21 | """Globally unique counter-based id manager. 22 | 23 | We need globally unique key ids for Module and Variable object instances 24 | to preserve and recreate sharing-by-reference relationship when lifting 25 | transforms and adopting outside Modules. 26 | - Use of id() is unacceptable because these identifiers are literally 27 | pointers which can be recycled, so we rely on a globally unique counter id 28 | instead. 29 | - We need to handle copy/deepcopy uniqueness via a wrapped type. 30 | """ 31 | 32 | def __init__(self): 33 | self._lock = threading.Lock() 34 | self._id = 0 35 | 36 | def __call__(self): 37 | with self._lock: 38 | self._id += 1 39 | return FlaxId(self._id) 40 | 41 | 42 | uuid = UUIDManager() 43 | 44 | 45 | class FlaxId: 46 | """Hashable wrapper for ids that handles uniqueness of copies.""" 47 | 48 | def __init__(self, rawid): 49 | self.id = rawid 50 | 51 | def __eq__(self, other): 52 | return isinstance(other, FlaxId) and other.id == self.id 53 | 54 | def __hash__(self): 55 | return hash(self.id) 56 | 57 | def __repr__(self): 58 | return f'FlaxId({self.id})' 59 | 60 | def __deepcopy__(self, memo): 61 | del memo 62 | return uuid() 63 | 64 | def __copy__(self): 65 | return uuid() 66 | -------------------------------------------------------------------------------- /tests/import_test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "metadata": {}, 7 | "source": [ 8 | "# Test Import in Colab\n", 9 | "\n", 10 | "\"Run all\" to test that all the Flax imports work in head.\n", 11 | "\n", 12 | "Change runtime type as needed." 
13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "# Colab runtimes are pre-built with JAX/Flax:\n", 22 | "!pip freeze | egrep 'jax|flax'" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "metadata": { 29 | "tags": [ 30 | "skip-execution" 31 | ] 32 | }, 33 | "outputs": [], 34 | "source": [ 35 | "# Install from head\n", 36 | "!pip install git+https://github.com/google/flax.git" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": null, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "# Check versions after installing Flax from Github:\n", 46 | "!pip freeze | egrep 'jax|flax'" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 9, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "# Verify we can import everything.\n", 56 | "import flax\n", 57 | "from flax.training import (checkpoints, dynamic_scale, early_stopping, lr_schedule,\n", 58 | " orbax_utils, prefetch_iterator, train_state, common_utils)\n", 59 | "from flax.metrics import tensorboard" 60 | ] 61 | } 62 | ], 63 | "metadata": { 64 | "language_info": { 65 | "codemirror_mode": { 66 | "name": "ipython", 67 | "version": 3 68 | }, 69 | "file_extension": ".py", 70 | "mimetype": "text/x-python", 71 | "name": "python", 72 | "nbconvert_exporter": "python", 73 | "pygments_lexer": "ipython3", 74 | "version": "3.9.15" 75 | } 76 | }, 77 | "nbformat": 4, 78 | "nbformat_minor": 2 79 | } 80 | -------------------------------------------------------------------------------- /flax/core/nn/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Flax Neural Network API.""" 16 | 17 | # pylint: disable=g-multiple-import 18 | # re-export commonly used modules and functions 19 | from flax.linen import activation as activation 20 | from flax.linen import initializers as initializers 21 | from flax.linen.activation import ( 22 | celu as celu, 23 | elu as elu, 24 | gelu as gelu, 25 | glu as glu, 26 | leaky_relu as leaky_relu, 27 | log_sigmoid as log_sigmoid, 28 | log_softmax as log_softmax, 29 | relu as relu, 30 | sigmoid as sigmoid, 31 | silu as silu, 32 | soft_sign as soft_sign, 33 | softmax as softmax, 34 | softplus as softplus, 35 | swish as swish, 36 | tanh as tanh, 37 | ) 38 | from flax.linen.pooling import (avg_pool as avg_pool, max_pool as max_pool) 39 | from .attention import ( 40 | dot_product_attention as dot_product_attention, 41 | multi_head_dot_product_attention as multi_head_dot_product_attention, 42 | ) 43 | from .linear import ( 44 | Embedding as Embedding, 45 | conv_transpose as conv_transpose, 46 | conv as conv, 47 | dense_general as dense_general, 48 | dense as dense, 49 | embedding as embedding, 50 | ) 51 | from .normalization import ( 52 | batch_norm as batch_norm, 53 | group_norm as group_norm, 54 | layer_norm as layer_norm, 55 | ) 56 | from .stochastic import dropout as dropout 57 | 58 | # pylint: enable=g-multiple-import 59 | -------------------------------------------------------------------------------- /examples/sst2/README.md: -------------------------------------------------------------------------------- 1 | ## SST-2 classification 2 | 3 | Trains a simple text classifier on the SST-2 sentiment classification dataset. 4 | 5 | You can run this code and even modify it directly in Google Colab, no 6 | installation required: 7 | 8 | https://colab.research.google.com/github/google/flax/blob/main/examples/sst2/sst2.ipynb 9 | 10 | ### Requirements 11 | * TensorFlow dataset `glue/sst2` will be downloaded and prepared automatically, if necessary. 12 | 13 | ### Example output 14 | 15 | | Name | Platform | Epochs | Walltime | Accuracy | Metrics | Workdir | 16 | |:--------|:--------|--------:|:-----------|:-----------------|:----------------------------------------------------------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------| 17 | | default | TPU | 10 | 4.3m | 85.21% | [tensorboard.dev](https://tensorboard.dev/experiment/yTQjjRY9RlGRrZzg8h9PJw/) | | 18 | 19 | ``` 20 | INFO:absl:train epoch 010 loss 0.1918 accuracy 92.41 21 | INFO:absl:eval epoch 010 loss 0.4144 accuracy 85.21 22 | ``` 23 | 24 | ### How to run 25 | 26 | ```bash 27 | python main.py --workdir=/tmp/sst2 --config=configs/default.py 28 | ``` 29 | 30 | #### Overriding Hyperparameter configurations 31 | 32 | The SST2 example allows specifying a hyperparameter configuration by means of 33 | setting the `--config` flag. The configuration flag is defined using 34 | [config_flags](https://github.com/google/ml_collections/tree/master#config-flags). 35 | `config_flags` allows overriding configuration fields. 
This can be done as 36 | follows: 37 | 38 | ```shell 39 | python main.py \ 40 | --workdir=/tmp/sst2 --config=configs/default.py \ 41 | --config.learning_rate=0.05 --config.num_epochs=5 42 | ``` 43 | -------------------------------------------------------------------------------- /tests/colab_tpu_jax_version.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# JAX/jaxlib should both be 0.3.25,\n", 10 | "# because newer JAX versions are *not* supported on TPU runtimes.\n", 11 | "# Flax should be included in a fresh kernel.\n", 12 | "!pip freeze | egrep 'jax|flax'" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "# should show 8 TPU devices\n", 22 | "import jax, jax.tools.colab_tpu\n", 23 | "jax.tools.colab_tpu.setup_tpu()\n", 24 | "jax.devices()" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "# sometimes it's necessary to install additional packages, but we need to keep\n", 34 | "# JAX/jaxlib versions pinned to what is supported by the TPU runtime...\n", 35 | "!pip install jax==0.3.25 jaxlib==0.3.25 flax" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "# in case the JAX version has changed after the `!pip install`, the command\n", 45 | "# below should show the offending packages\n", 46 | "!pip install -qq pipdeptree\n", 47 | "!pipdeptree -w silence -r -p jax" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": null, 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "# it's possible to get the dependency tree without installing packages, but this\n", 57 | "# usually takes some 2-3 minutes...\n", 58 | "!pip install -qq pipgrip\n", 59 | "!pipgrip --tree flax==0.6.4" 60 | ] 61 | } 62 | ], 63 | "metadata": { 64 | "accelerator": "TPU", 65 | "gpuClass": "standard", 66 | "language_info": { 67 | "name": "python" 68 | } 69 | }, 70 | "nbformat": 4, 71 | "nbformat_minor": 0 72 | } 73 | -------------------------------------------------------------------------------- /examples/vae/models.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | """VAE model definitions.""" 16 | 17 | from flax import linen as nn 18 | from jax import random 19 | import jax.numpy as jnp 20 | 21 | 22 | class Encoder(nn.Module): 23 | """VAE Encoder.""" 24 | 25 | latents: int 26 | 27 | @nn.compact 28 | def __call__(self, x): 29 | x = nn.Dense(500, name='fc1')(x) 30 | x = nn.relu(x) 31 | mean_x = nn.Dense(self.latents, name='fc2_mean')(x) 32 | logvar_x = nn.Dense(self.latents, name='fc2_logvar')(x) 33 | return mean_x, logvar_x 34 | 35 | 36 | class Decoder(nn.Module): 37 | """VAE Decoder.""" 38 | 39 | @nn.compact 40 | def __call__(self, z): 41 | z = nn.Dense(500, name='fc1')(z) 42 | z = nn.relu(z) 43 | z = nn.Dense(784, name='fc2')(z) 44 | return z 45 | 46 | 47 | class VAE(nn.Module): 48 | """Full VAE model.""" 49 | 50 | latents: int = 20 51 | 52 | def setup(self): 53 | self.encoder = Encoder(self.latents) 54 | self.decoder = Decoder() 55 | 56 | def __call__(self, x, z_rng): 57 | mean, logvar = self.encoder(x) 58 | z = reparameterize(z_rng, mean, logvar) 59 | recon_x = self.decoder(z) 60 | return recon_x, mean, logvar 61 | 62 | def generate(self, z): 63 | return nn.sigmoid(self.decoder(z)) 64 | 65 | 66 | def reparameterize(rng, mean, logvar): 67 | std = jnp.exp(0.5 * logvar) 68 | eps = random.normal(rng, logvar.shape) 69 | return mean + eps * std 70 | 71 | 72 | def model(latents): 73 | return VAE(latents=latents) 74 | -------------------------------------------------------------------------------- /examples/linen_design_test/mlp_inline.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import jax 16 | from jax import numpy as jnp 17 | from flax import linen as nn 18 | from collections.abc import Iterable 19 | from flax.linen import Module, compact 20 | from dense import Dense 21 | 22 | 23 | # Many NN layers and blocks are best described by a single function with inline variables. 24 | # In this case, variables are initialized during the first call. 25 | class MLP(Module): 26 | sizes: Iterable[int] 27 | 28 | @compact 29 | def __call__(self, x): 30 | for size in self.sizes[:-1]: 31 | x = Dense(size)(x) 32 | x = nn.relu(x) 33 | return Dense(self.sizes[-1])(x) 34 | 35 | 36 | # Return an initialized instance of MLP by calling `__call__` with an input batch, 37 | # initializing all variables. 38 | # 39 | # Variable shapes depend on the input shape passed in. 
40 | rngkey = jax.random.key(10) 41 | model = MLP((2, 1)) 42 | x = jnp.ones((1, 3)) 43 | mlp_variables = model.init(rngkey, x) 44 | print(mlp_variables) 45 | # {'params': {'Dense_0': {'bias': DeviceArray([0.], dtype=float32), 46 | # 'kernel': DeviceArray([[-0.04267037], 47 | # [-0.51097125]], dtype=float32)}, 48 | # 'Dense_1': {'bias': DeviceArray([0., 0.], dtype=float32), 49 | # 'kernel': DeviceArray([[-6.3845289e-01, 6.0373604e-01], 50 | # [-5.9814966e-01, 5.1718324e-01], 51 | # [-6.2220657e-01, 5.8988278e-04]], dtype=float32)}}} 52 | print(model.apply(mlp_variables, x)) 53 | -------------------------------------------------------------------------------- /examples/linen_design_test/mlp_lazy.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import jax 16 | from jax import numpy as jnp 17 | from flax import linen as nn 18 | from flax.linen import Module 19 | from pprint import pprint 20 | from dense import Dense 21 | 22 | 23 | # Here submodules are explicitly defined during `setup`, but are still materialized 24 | # lazily, only once the first input is passed through and shapes are known. 25 | class MLP(Module): 26 | 27 | def setup(self): 28 | self.dense1 = Dense(features=2) 29 | self.dense2 = Dense(features=1) 30 | 31 | # shapes aren't yet known, so variables aren't materialized 32 | print(self.dense2.variables) 33 | # FrozenDict({}) 34 | 35 | def __call__(self, x): 36 | return self.dense2(nn.relu(self.dense1(x))) 37 | 38 | 39 | # Initialize the MLP by calling `init` with an RNG key and an input batch; 40 | # this runs `__call__` once and materializes all variables. 41 | # 42 | # Variable shapes depend on the input shape passed in. 43 | rngkey = jax.random.key(10) 44 | mlp_variables = MLP().init(rngkey, jnp.zeros((1, 3))) 45 | 46 | pprint(mlp_variables) 47 | # {'params': {'dense1': {'bias': DeviceArray([0., 0.], dtype=float32), 48 | # 'kernel': DeviceArray([[ 0.18307537, -0.38739476], 49 | # [-0.902451 , -0.5190721 ], 50 | # [ 0.51552075, 1.1169153 ]], dtype=float32)}, 51 | # 'dense2': {'bias': DeviceArray([0.], dtype=float32), 52 | # 'kernel': DeviceArray([[ 0.6704609 ], 53 | # [-0.90477365]], dtype=float32)}}} 54 | -------------------------------------------------------------------------------- /tests/configurations_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from unittest import mock 16 | 17 | from absl.testing import absltest 18 | 19 | from flax.configurations import bool_flag, config 20 | 21 | 22 | class MyTestCase(absltest.TestCase): 23 | def setUp(self): 24 | super().setUp() 25 | self.enter_context(mock.patch.object(config, '_values', {})) 26 | self._flag = bool_flag('flax_test', default=False, help='Just a test flag.') 27 | 28 | def test_duplicate_flag(self): 29 | with self.assertRaisesRegex(RuntimeError, 'already defined'): 30 | bool_flag(self._flag.name, default=False, help='Another test flag.') 31 | 32 | def test_default(self): 33 | self.assertFalse(self._flag.value) 34 | self.assertFalse(config.flax_test) 35 | 36 | def test_typed_update(self): 37 | config.update(self._flag, True) 38 | self.assertTrue(self._flag.value) 39 | self.assertTrue(config.flax_test) 40 | 41 | def test_untyped_update(self): 42 | config.update(self._flag.name, True) 43 | self.assertTrue(self._flag.value) 44 | self.assertTrue(config.flax_test) 45 | 46 | def test_update_unknown_flag(self): 47 | with self.assertRaisesRegex(LookupError, 'Unrecognized config option'): 48 | config.update('unknown', True) 49 | 50 | def test_temp_flip(self): 51 | self.assertFalse(self._flag.value) 52 | with config.temp_flip_flag('test', True): 53 | self.assertTrue(self._flag.value) 54 | self.assertFalse(self._flag.value) 55 | 56 | 57 | if __name__ == '__main__': 58 | absltest.main() 59 | -------------------------------------------------------------------------------- /examples/nnx_toy_examples/06_scan_over_layers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import jax 17 | import jax.numpy as jnp 18 | 19 | from flax import nnx 20 | 21 | 22 | class Block(nnx.Module): 23 | def __init__(self, dim: int, *, rngs: nnx.Rngs): 24 | self.linear = nnx.Linear(dim, dim, rngs=rngs) 25 | self.bn = nnx.BatchNorm(dim, rngs=rngs) 26 | self.dropout = nnx.Dropout(0.5, rngs=rngs) 27 | 28 | def __call__(self, x: jax.Array): 29 | return jax.nn.gelu(self.dropout(self.bn(self.linear(x)))) 30 | 31 | 32 | class ScanMLP(nnx.Module): 33 | """ 34 | An MLP that uses `vmap` during `__init__` to create a Block instance 35 | with an additional `layer` axis, and `scan` during `__call__` to apply 36 | the sequence of layers iteratively over the input / output `x`. 
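Because each Block is created inside `vmap`, the parameters stored in `self.layers` carry a leading `layer` axis of size `n_layers`; `scan` then consumes that axis, applying one layer per step while threading `x` through as the carry.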
37 | """ 38 | 39 | def __init__(self, dim: int, *, n_layers: int, rngs: nnx.Rngs): 40 | self.n_layers = n_layers 41 | 42 | @nnx.split_rngs(splits=n_layers) 43 | @nnx.vmap(axis_size=n_layers) 44 | def create_block(rngs: nnx.Rngs): 45 | return Block(dim, rngs=rngs) 46 | 47 | self.layers = create_block(rngs) 48 | 49 | def __call__(self, x: jax.Array) -> jax.Array: 50 | @nnx.scan 51 | def scan_fn(x: jax.Array, block: Block): 52 | x = block(x) 53 | return x, None 54 | 55 | x, _ = scan_fn(x, self.layers) 56 | 57 | return x 58 | 59 | 60 | model = ScanMLP(10, n_layers=5, rngs=nnx.Rngs(0)) 61 | 62 | x = jnp.ones((3, 10)) 63 | y = model(x) 64 | 65 | print(jax.tree.map(jnp.shape, nnx.state(model))) 66 | print(y.shape) 67 | -------------------------------------------------------------------------------- /examples/imagenet/models_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for flax.examples.imagenet.models.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | 20 | import jax 21 | from jax import numpy as jnp 22 | 23 | import models 24 | 25 | 26 | jax.config.update('jax_disable_most_optimizations', True) 27 | 28 | 29 | class ResNetTest(parameterized.TestCase): 30 | """Test cases for ResNet v1.5 model definition.""" 31 | 32 | def test_resnet_model(self): 33 | """Tests ResNet V1.5 model definition and output (variables).""" 34 | rng = jax.random.key(0) 35 | model_def = models.ResNet50(num_classes=10, dtype=jnp.float32) 36 | variables = model_def.init(rng, jnp.ones((8, 224, 224, 3), jnp.float32)) 37 | 38 | self.assertLen(variables, 2) 39 | # Resnet50 model will create parameters for the following layers: 40 | # conv + batch_norm = 2 41 | # BottleneckResNetBlock in stages: [3, 4, 6, 3] = 16 42 | # Followed by a Dense layer = 1 43 | self.assertLen(variables['params'], 19) 44 | 45 | @parameterized.product(model=(models.ResNet18, models.ResNet18Local)) 46 | def test_resnet_18_model(self, model): 47 | """Tests ResNet18 V1.5 model definition and output (variables).""" 48 | rng = jax.random.key(0) 49 | model_def = model(num_classes=2, dtype=jnp.float32) 50 | variables = model_def.init(rng, jnp.ones((1, 64, 64, 3), jnp.float32)) 51 | 52 | self.assertLen(variables, 2) 53 | self.assertLen(variables['params'], 11) 54 | 55 | 56 | if __name__ == '__main__': 57 | absltest.main() 58 | -------------------------------------------------------------------------------- /examples/ppo/configs/default.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Definitions of default hyperparameters.""" 16 | 17 | import ml_collections 18 | 19 | 20 | def get_config(): 21 | """Get the default configuration. 22 | 23 | The default hyperparameters originate from the PPO paper (arXiv:1707.06347) 24 | and the OpenAI Baselines ppo2 defaults: 25 | https://github.com/openai/baselines/blob/master/baselines/ppo2/defaults.py 26 | """ 27 | config = ml_collections.ConfigDict() 28 | # The Atari game used. 29 | config.game = 'Pong' 30 | # Total number of frames seen during training. 31 | config.total_frames = 40000000 32 | # The learning rate for the Adam optimizer. 33 | config.learning_rate = 2.5e-4 34 | # Batch size used in training. 35 | config.batch_size = 256 36 | # Number of agents playing in parallel. 37 | config.num_agents = 8 38 | # Number of steps each agent performs in one policy unroll. 39 | config.actor_steps = 128 40 | # Number of training epochs per unroll of the policy. 41 | config.num_epochs = 3 42 | # RL discount parameter. 43 | config.gamma = 0.99 44 | # Generalized Advantage Estimation parameter. 45 | config.lambda_ = 0.95 46 | # The PPO clipping parameter used to clamp ratios in the loss function. 47 | config.clip_param = 0.1 48 | # Weight of value function loss in the total loss. 49 | config.vf_coeff = 0.5 50 | # Weight of entropy bonus in the total loss. 51 | config.entropy_coeff = 0.01 52 | # Linearly decay learning rate and clipping parameter to zero during 53 | # training. 54 | config.decaying_lr_and_clip_param = True 55 | return config 56 | -------------------------------------------------------------------------------- /examples/ppo/test_episodes.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Test policy by playing a full Atari game.""" 16 | 17 | import itertools 18 | from typing import Any 19 | from collections.abc import Callable 20 | 21 | import flax 22 | import numpy as np 23 | 24 | import agent 25 | import env_utils 26 | 27 | 28 | def policy_test( 29 | n_episodes: int, 30 | apply_fn: Callable[..., Any], 31 | params: flax.core.frozen_dict.FrozenDict, 32 | game: str, 33 | ): 34 | """Perform a test of the policy in an Atari environment.
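At each step the action is sampled from the policy's output distribution, and episode rewards are accumulated unclipped, since the test environment is created with `clip_rewards=False`.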
35 | 36 | Args: 37 | n_episodes: number of full Atari episodes to test on 38 | apply_fn: the actor-critic apply function 39 | params: actor-critic model parameters, they define the policy being tested 40 | game: defines the Atari game to test on 41 | 42 | Returns: 43 | total_reward: obtained score 44 | """ 45 | test_env = env_utils.create_env(game, clip_rewards=False) 46 | for _ in range(n_episodes): 47 | obs = test_env.reset() 48 | state = obs[None, ...] # add batch dimension 49 | total_reward = 0.0 50 | for t in itertools.count(): 51 | log_probs, _ = agent.policy_action(apply_fn, params, state) 52 | probs = np.exp(np.array(log_probs, dtype=np.float32)) 53 | probabilities = probs[0] / probs[0].sum() 54 | action = np.random.choice(probs.shape[1], p=probabilities) 55 | obs, reward, done, _ = test_env.step(action) 56 | total_reward += reward 57 | next_state = obs[None, ...] if not done else None 58 | state = next_state 59 | if done: 60 | break 61 | return total_reward 62 | -------------------------------------------------------------------------------- /flaxlib_src/src/flaxlib/flaxlib_cpp.pyi: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import typing as tp 16 | 17 | RefMap = tp.MutableMapping[tp.Any, int] 18 | IndexMap = dict[int, tp.Any] 19 | 20 | class NodeDef: 21 | type: type 22 | index: int | None 23 | outer_index: int | None 24 | num_attributes: int 25 | metadata: tp.Any 26 | 27 | def with_no_outer_index(self) -> NodeDef: ... 28 | def with_same_outer_index(self) -> NodeDef: ... 29 | def __eq__(self, other: tp.Any) -> bool: ... 30 | def __hash__(self) -> int: ... 31 | def __getstate__( 32 | self, 33 | ) -> tuple[tp.Any, tp.Any, tp.Any, tp.Any, tp.Any]: ... 34 | @staticmethod 35 | def __setstate__( 36 | nodedef: NodeDef, state: tuple[tp.Any, tp.Any, tp.Any, tp.Any, tp.Any] 37 | ) -> None: ... 38 | 39 | class VariableDef: 40 | type: type 41 | index: int 42 | outer_index: int | None 43 | metadata: tp.Any 44 | 45 | def with_no_outer_index(self) -> VariableDef: ... 46 | def with_same_outer_index(self) -> VariableDef: ... 47 | def __eq__(self, other: tp.Any) -> bool: ... 48 | def __hash__(self) -> int: ... 49 | def __getstate__( 50 | self, 51 | ) -> tuple[tp.Any, int, tp.Any, tp.Any]: ... 52 | @staticmethod 53 | def __setstate__( 54 | variabledef: 'VariableDef', state: tuple[tp.Any, int, tp.Any, tp.Any] 55 | ) -> None: ... 56 | 57 | class NodeRef: 58 | index: int 59 | 60 | def __eq__(self, other: tp.Any) -> bool: ... 61 | def __hash__(self) -> int: ... 62 | def __getstate__(self) -> tuple[int]: ... 63 | @staticmethod 64 | def __setstate__(noderef: NodeRef, state: tuple[int]) -> None: ... 
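# Note: when a graph is flattened, NodeDef and VariableDef describe a node or
# Variable the first time it is encountered, while NodeRef points back (via its
# `index`) to a node that was already visited; this is how shared or aliased
# submodules are encoded without duplication.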
65 | 66 | -------------------------------------------------------------------------------- /examples/nnx_toy_examples/09_parameter_surgery.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import jax 17 | 18 | from flax import nnx 19 | 20 | 21 | # let's pretend this function loads a pretrained model from a checkpoint 22 | def load_pretrained(): 23 | return nnx.Linear(784, 128, rngs=nnx.Rngs(0)) 24 | 25 | 26 | # create a simple linear classifier using a pretrained backbone 27 | class Classifier(nnx.Module): 28 | def __init__(self, *, rngs: nnx.Rngs): 29 | self.backbone = nnx.Linear(784, 128, rngs=nnx.Rngs(0)) 30 | self.head = nnx.Linear(128, 10, rngs=rngs) 31 | 32 | def __call__(self, x): 33 | x = self.backbone(x) 34 | x = nnx.relu(x) 35 | x = self.head(x) 36 | return x 37 | 38 | 39 | # create the classifier using the pretrained backbone; here we are technically 40 | # doing "parameter surgery". Compared to Haiku/Flax Linen, where you must manually 41 | # construct the parameter structure, in NNX this is done automatically 42 | model = Classifier(rngs=nnx.Rngs(42)) 43 | model.backbone = load_pretrained() 44 | 45 | 46 | # create a filter to select all the parameters that are not part of the 47 | # backbone, i.e. the classifier parameters 48 | is_trainable = lambda path, node: ( 49 | 'backbone' not in path and isinstance(node, nnx.Param) 50 | ) 51 | 52 | # split the parameters into trainable and non-trainable parameters 53 | graphdef, trainable_params, non_trainable = nnx.split(model, is_trainable, ...) 54 | 55 | print( 56 | 'trainable_params =', 57 | jax.tree.map(jax.numpy.shape, trainable_params), 58 | ) 59 | print('non_trainable =', jax.tree.map(jax.numpy.shape, non_trainable)) 60 | -------------------------------------------------------------------------------- /examples/vae/main.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Main file for running the VAE example. 16 | 17 | This file is intentionally kept short. The majority of the logic is in libraries 18 | that can be easily tested and imported in Colab.
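Example invocation, following the flags defined below: python main.py --workdir=/tmp/vae --config=configs/default.py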
19 | """ 20 | 21 | from absl import app 22 | from absl import flags 23 | from absl import logging 24 | from clu import platform 25 | import jax 26 | from ml_collections import config_flags 27 | import tensorflow as tf 28 | 29 | import train 30 | 31 | 32 | FLAGS = flags.FLAGS 33 | 34 | flags.DEFINE_string('workdir', None, 'Directory to store model data.') 35 | config_flags.DEFINE_config_file( 36 | 'config', 37 | None, 38 | 'File path to the training hyperparameter configuration.', 39 | lock_config=True, 40 | ) 41 | flags.mark_flags_as_required(['config', 'workdir']) 42 | 43 | 44 | def main(argv): 45 | if len(argv) > 1: 46 | raise app.UsageError('Too many command-line arguments.') 47 | 48 | # Make sure tf does not allocate gpu memory. 49 | tf.config.experimental.set_visible_devices([], 'GPU') 50 | 51 | logging.info('JAX process: %d / %d', jax.process_index(), jax.process_count()) 52 | logging.info('JAX local devices: %r', jax.local_devices()) 53 | 54 | # Add a note so that we can tell which task is which JAX host. 55 | # (Depending on the platform task 0 is not guaranteed to be host 0) 56 | platform.work_unit().set_task_status( 57 | f'process_index: {jax.process_index()}, ' 58 | f'process_count: {jax.process_count()}' 59 | ) 60 | 61 | train.train_and_evaluate(FLAGS.config, FLAGS.workdir) 62 | 63 | 64 | if __name__ == '__main__': 65 | app.run(main) 66 | -------------------------------------------------------------------------------- /examples/ogbg_molpcba/configs/hparam_sweep.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Defines a sweep for the hyperparameters for the GNN.""" 16 | 17 | import ml_collections 18 | 19 | 20 | def get_config(): 21 | """Get the default hyperparameter configuration.""" 22 | config = ml_collections.ConfigDict() 23 | 24 | # Optimizer. 25 | config.optimizer = 'adam' 26 | config.learning_rate = 1e-3 27 | 28 | # Training hyperparameters. 29 | config.batch_size = 256 30 | config.num_train_steps = 500_000 31 | config.log_every_steps = 50 32 | config.eval_every_steps = 1_000 33 | config.checkpoint_every_steps = 10_000 34 | config.add_virtual_node = True 35 | config.add_undirected_edges = True 36 | config.add_self_loops = True 37 | 38 | # GNN hyperparameters. 
39 | config.model = 'GraphConvNet' 40 | config.message_passing_steps = 5 41 | config.latent_size = 256 42 | config.dropout_rate = 0.1 43 | config.num_mlp_layers = 2 44 | config.num_classes = 128 45 | config.skip_connections = True 46 | config.layer_norm = True 47 | 48 | return config 49 | 50 | 51 | def sweep(add): 52 | for add_virtual_node in (True, False): 53 | for add_undirected_edges in (True, False): 54 | for add_self_loops in (True, False): 55 | for layer_norm in (True, False): 56 | for skip_connections in (True, False): 57 | add( 58 | add_virtual_node=add_virtual_node, 59 | add_undirected_edges=add_undirected_edges, 60 | add_self_loops=add_self_loops, 61 | layer_norm=layer_norm, 62 | skip_connections=skip_connections, 63 | ) 64 | -------------------------------------------------------------------------------- /flax/linen/README.md: -------------------------------------------------------------------------------- 1 | # Linen: A comfortable evolution of Flax 2 | 3 | Linen is a neural network API developed based on learning from our users and the broader JAX community. Linen improves on much of the former `flax.nn` API (removed since v0.4.0), such as submodule sharing and better support for non-trainable variables. 4 | Moreover, Linen builds on a "functional core", enabling direct usage of JAX transformations such as `vmap`, `remat` or `scan` inside your modules. 5 | 6 | In Linen, Modules behave much closer to vanilla Python objects, while still letting you opt-in to the concise single-method pattern many of our users love. 7 | 8 | The Linen Module API is stable and currently recommended for new projects. We are already supporting users in the OSS community and within Google. Minor changes may come to the top-level `apply` and `init` patterns, which we will communicate clearly. We plan a few improvements, including writing up short design notes, adding more design tests (see last link below), and an API for interactive module instances. 9 | 10 | Please open a [discussion](https://github.com/google/flax/discussions) if you have any questions or thoughts. 
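For a quick taste of the API, here is a minimal sketch of the single-method pattern mentioned above (the module, layer sizes, and shapes are illustrative only, not taken from any particular example):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class MLP(nn.Module):
  features: int

  @nn.compact  # the "single-method pattern": submodules are defined inline
  def __call__(self, x):
    x = nn.Dense(self.features)(x)
    x = nn.relu(x)
    return nn.Dense(1)(x)


model = MLP(features=16)
x = jnp.ones((4, 8))
variables = model.init(jax.random.key(0), x)  # create the variable collections
y = model.apply(variables, x)                 # pure forward pass
```

Because `apply` is a pure function of `variables` and the inputs, it composes directly with JAX transformations such as `jax.jit` and `jax.grad`.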
11 | 12 | **See the [Linen API reference docs](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/index.html)**, or take a look at our additional material: 13 | 14 | * 2-page intro to the [Linen Design Principles](https://docs.google.com/document/d/1ZlL_4bXCw5Xl0WstQw1GpnZqfb9JFOeUGAPcBVk-kn8) 15 | * [Slides from a talk to the JAX core team](https://docs.google.com/presentation/d/1ngKWUwsSqAwPRvATG8sAxMzu9ujv4N__cKsUofdNno0) 16 | * [Brief Intro to Linen](https://colab.research.google.com/github/google/flax/blob/main/docs/linen_intro.ipynb) in Colab 17 | * An [upgrade guide](https://docs.google.com/document/d/1hYavTVPaKVVe9Be8pCB7yW7r6dDv3RALVNit8NZca4c) + some additional questions we're considering 18 | * Ported [examples](https://github.com/google/flax/tree/main/examples) 19 | * "Design tests" used to ensure that our "functional core" supports [various advanced use-cases](https://github.com/google/flax/tree/main/tests/core/design), and that the mostly-syntactic-sugar Module abstraction 20 | [doesn't get in the way](https://github.com/google/flax/tree/main/examples/linen_design_test) 21 | -------------------------------------------------------------------------------- /docs/examples/repositories_that_use_flax.rst: -------------------------------------------------------------------------------- 1 | Repositories that use Flax 2 | ========================== 3 | 4 | The following code bases use Flax and provide training frameworks and a wealth 5 | of examples. In many cases, you can also find pre-trained weights: 6 | 7 | 8 | 🤗 Hugging Face 9 | *************** 10 | 11 | `🤗 Hugging Face <https://huggingface.co/>`__ is a 12 | very popular library for building, training, and deploying state-of-the-art 13 | machine learning models. 14 | These models can be applied to text, images, and audio. After organizing the 15 | `JAX/Flax community week `__, 16 | they now have over 5,000 17 | `Flax/JAX models `__ in 18 | their repository. 19 | 20 | 🥑 DALLE Mini 21 | ************* 22 | 23 | `🥑 DALLE Mini <https://github.com/borisdayma/dalle-mini>`__ is a Transformer-based 24 | text-to-image model implemented in JAX/Flax that follows the ideas from the 25 | original `DALLE <https://openai.com/blog/dall-e/>`__ paper by OpenAI. 26 | 27 | Scenic 28 | ****** 29 | 30 | `Scenic <https://github.com/google-research/scenic>`__ is a codebase/library 31 | for computer vision research and beyond. Scenic's main focus is on 32 | attention-based models. Scenic has been successfully used to develop 33 | classification, segmentation, and detection models for multiple modalities 34 | including images, video, audio, and multimodal combinations of them. 35 | 36 | Big Vision 37 | ********** 38 | 39 | `Big Vision <https://github.com/google-research/big_vision>`__ is a codebase 40 | designed for training large-scale vision models using Cloud TPU VMs or GPU 41 | machines. It is based on the JAX/Flax libraries, and uses tf.data and TensorFlow 42 | Datasets for scalable and reproducible input pipelines. This is the original 43 | codebase of ViT, MLP-Mixer, LiT, UViM, and many more models. 44 | 45 | T5X 46 | *** 47 | 48 | `T5X <https://github.com/google-research/t5x>`__ is a modular, composable, 49 | research-friendly framework for high-performance, configurable, self-service 50 | training, evaluation, and inference of sequence models (starting with 51 | language) at many scales. -------------------------------------------------------------------------------- /examples/nnx_toy_examples/08_save_load_checkpoints.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors.
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from tempfile import TemporaryDirectory 16 | 17 | import jax 18 | import jax.numpy as jnp 19 | import orbax.checkpoint as orbax 20 | 21 | from flax import nnx 22 | 23 | 24 | class MLP(nnx.Module): 25 | def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs): 26 | self.dense1 = nnx.Linear(din, dmid, rngs=rngs) 27 | self.dense2 = nnx.Linear(dmid, dout, rngs=rngs) 28 | 29 | def __call__(self, x: jax.Array) -> jax.Array: 30 | x = self.dense1(x) 31 | x = jax.nn.relu(x) 32 | x = self.dense2(x) 33 | return x 34 | 35 | 36 | def create_model(seed: int): 37 | return MLP(10, 20, 30, rngs=nnx.Rngs(seed)) 38 | 39 | 40 | def create_and_save(seed: int, path: str): 41 | model = create_model(seed) 42 | state = nnx.state(model) 43 | # Save the parameters 44 | checkpointer = orbax.PyTreeCheckpointer() 45 | checkpointer.save(f'{path}/state', state) 46 | 47 | 48 | def load_model(path: str) -> MLP: 49 | # create that model with abstract shapes 50 | model = nnx.eval_shape(lambda: create_model(0)) 51 | state = nnx.state(model) 52 | # Load the parameters 53 | checkpointer = orbax.PyTreeCheckpointer() 54 | state = checkpointer.restore(f'{path}/state', item=state) 55 | # update the model with the loaded state 56 | nnx.update(model, state) 57 | return model 58 | 59 | 60 | with TemporaryDirectory() as tmpdir: 61 | # create a checkpoint 62 | create_and_save(42, tmpdir) 63 | # load model from checkpoint 64 | model = load_model(tmpdir) 65 | # run the model 66 | y = model(jnp.ones((1, 10))) 67 | print(model) 68 | print(y) 69 | -------------------------------------------------------------------------------- /tests/linen/initializers_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for flax.linen.initializers.""" 16 | 17 | import jax 18 | import jax.numpy as jnp 19 | import numpy as np 20 | from absl.testing import absltest, parameterized 21 | from jax import random 22 | 23 | from flax import linen as nn 24 | from flax.linen import initializers 25 | 26 | # Parse absl flags test_srcdir and test_tmpdir. 
27 | jax.config.parse_flags_with_absl() 28 | 29 | 30 | class InitializersTest(parameterized.TestCase): 31 | @parameterized.parameters( 32 | { 33 | 'builder_fn': initializers.zeros_init, 34 | 'params_shape': (2, 3), 35 | 'expected_params': jnp.zeros((2, 3)), 36 | }, 37 | { 38 | 'builder_fn': initializers.ones_init, 39 | 'params_shape': (3, 2), 40 | 'expected_params': jnp.ones((3, 2)), 41 | }, 42 | ) 43 | def test_call_builder(self, builder_fn, params_shape, expected_params): 44 | params = builder_fn()(random.key(42), params_shape, jnp.float32) 45 | np.testing.assert_allclose(params, expected_params) 46 | 47 | @parameterized.parameters( 48 | { 49 | 'builder_fn': initializers.zeros_init, 50 | 'expected_params': jnp.zeros((2, 5)), 51 | }, 52 | { 53 | 'builder_fn': initializers.ones_init, 54 | 'expected_params': jnp.ones((2, 5)), 55 | }, 56 | ) 57 | def test_kernel_builder(self, builder_fn, expected_params): 58 | layer = nn.Dense(5, kernel_init=builder_fn()) 59 | params = layer.init(random.key(42), jnp.empty((3, 2)))['params'] 60 | np.testing.assert_allclose(params['kernel'], expected_params) 61 | 62 | 63 | if __name__ == '__main__': 64 | absltest.main() 65 | -------------------------------------------------------------------------------- /examples/sst2/build_vocabulary.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """A vocabulary builder that generates vocab.txt to be used for training.""" 16 | 17 | import time 18 | from collections.abc import Iterable, Sequence 19 | 20 | from absl import logging 21 | import tensorflow as tf 22 | import tensorflow_datasets as tfds 23 | import tensorflow_text as tftext 24 | 25 | import vocabulary 26 | 27 | 28 | def get_tokenized_sequences( 29 | dataset: tf.data.Dataset, 30 | tokenizer: tftext.Tokenizer = tftext.WhitespaceTokenizer(), 31 | input_key: str = 'sentence', 32 | ) -> Iterable[Sequence[bytes]]: 33 | """Returns tokenized sequences for vocabulary building.""" 34 | dataset = dataset.map( 35 | lambda example: tokenizer.tokenize(example[input_key]), 36 | num_parallel_calls=tf.data.experimental.AUTOTUNE, 37 | ) 38 | yield from tfds.as_numpy(dataset) 39 | 40 | 41 | if __name__ == '__main__': 42 | logging.set_verbosity(logging.INFO) 43 | start_time = time.time() 44 | 45 | # Loads the dataset to build the vocabulary from. We use the train split. 46 | dataset = tfds.load('glue/sst2', split='train') 47 | 48 | # Tokenizes the sequences in the dataset and keeps only those. 49 | tokenized_sequences = get_tokenized_sequences(dataset) 50 | 51 | # Builds the vocabulary from the tokenized sequences. 52 | # A token needs to appear at least 3 times to be in the vocabulary. You can 53 | # play with this. It is there to make sure we don't overfit on rare words. 
54 | vocab = vocabulary.Vocabulary( 55 | tokenized_sequences=tokenized_sequences, min_freq=3 56 | ) 57 | vocab.save('vocab.txt') 58 | 59 | logging.info('Total time elapsed: %f s', time.time() - start_time) 60 | -------------------------------------------------------------------------------- /benchmarks/tracing/vae.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The JAX Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """VAE helper functions.""" 15 | 16 | from typing import Any 17 | 18 | from flax.examples.vae import models 19 | from flax.examples.vae import train 20 | from flax.training import train_state 21 | import jax 22 | import jax.numpy as jnp 23 | import ml_collections 24 | import optax 25 | 26 | 27 | def get_fake_batch(batch_size: int) -> Any: 28 | """Returns fake data for the given batch size. 29 | 30 | Args: 31 | batch_size: The global batch size to generate. 32 | 33 | Returns: 34 | A global batch of fake data. 35 | """ 36 | return jnp.ones((batch_size, 784), jnp.float32) 37 | 38 | 39 | def get_apply_fn_and_args( 40 | config: ml_collections.ConfigDict, 41 | ) -> tuple[Any, tuple[Any, ...], dict[str, Any]]: 42 | """Returns the apply function and args for the given config. 43 | 44 | Args: 45 | config: The training configuration. 46 | 47 | Returns: 48 | A tuple of the apply function, args and kwargs for the apply function, and 49 | any metadata the training loop needs. 50 | """ 51 | rng = jax.random.key(0) 52 | rng, key = jax.random.split(rng) 53 | batch = get_fake_batch(config.batch_size) 54 | params = models.model(config.latents).init(key, batch, rng)['params'] 55 | state = train_state.TrainState.create( 56 | apply_fn=models.model(config.latents).apply, 57 | params=params, 58 | tx=optax.adam(config.learning_rate), 59 | ) 60 | # Wrap with jit, making latents static since it's a Python int. 61 | train_step_jit = jax.jit(train.train_step, static_argnames=('latents',)) 62 | return ( 63 | train_step_jit, 64 | (state, batch, rng, config.latents), 65 | dict(), 66 | ) 67 | -------------------------------------------------------------------------------- /flax/traceback_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | """Flax specific traceback_util functions.""" 16 | 17 | from jax._src import traceback_util as jax_traceback_util 18 | from jax.extend import source_info_util 19 | 20 | from flax import config 21 | 22 | # pylint: disable=protected-access 23 | 24 | # Globals: 25 | # Whether to filter flax frames from traceback. 26 | _flax_filter_tracebacks = config.flax_filter_frames 27 | # Flax specific set of paths to exclude from tracebacks. 28 | _flax_exclusions = set() 29 | 30 | 31 | # re-import JAX symbol for convenience. 32 | api_boundary = jax_traceback_util.api_boundary 33 | 34 | 35 | def register_exclusion(path): 36 | """Marks a Flax source file for exclusion.""" 37 | global _flax_exclusions, _flax_filter_tracebacks 38 | # Record flax exclusions so we can dynamically add and remove them. 39 | _flax_exclusions.add(path) 40 | if _flax_filter_tracebacks: 41 | jax_traceback_util.register_exclusion(path) 42 | source_info_util.register_exclusion(path) 43 | 44 | def hide_flax_in_tracebacks(): 45 | """Hides Flax internal stack frames in tracebacks.""" 46 | global _flax_exclusions, _flax_filter_tracebacks 47 | _flax_filter_tracebacks = True 48 | for exclusion in _flax_exclusions: 49 | if exclusion not in jax_traceback_util._exclude_paths: 50 | jax_traceback_util._exclude_paths.append(exclusion) 51 | 52 | 53 | def show_flax_in_tracebacks(): 54 | """Shows Flax internal stack frames in tracebacks.""" 55 | global _flax_exclusions, _flax_filter_tracebacks 56 | _flax_filter_tracebacks = False 57 | for exclusion in _flax_exclusions: 58 | if exclusion in jax_traceback_util._exclude_paths: 59 | jax_traceback_util._exclude_paths.remove(exclusion) 60 | -------------------------------------------------------------------------------- /.github/analytics/pr_data_query.gql: -------------------------------------------------------------------------------- 1 | query { 2 | # Queries all the Pull Requests in a repo. For each issue, we get some basic data such as 3 | # the number, state, reviews, and title. The most important part is the 'timelineItems' 4 | # which are the events that happened to the issue, we can use the information about 5 | # the datetime about certain key events to define some metrics. We also use the 'reviews' 6 | # as indicators for certain metrics. Note that we are getting more information than is 7 | # probably needed but its fine for now. 8 | repository(owner:"_REPO_OWNER_", name:"_REPO_NAME_") { 9 | pullRequests(first:100) { 10 | totalCount 11 | edges { 12 | cursor 13 | node { 14 | number 15 | state 16 | title 17 | createdAt 18 | author{ 19 | login 20 | } 21 | mergedAt 22 | reviews(first: 100){ 23 | nodes { 24 | createdAt 25 | } 26 | } 27 | timelineItems(first: 100, itemTypes: [LABELED_EVENT, ASSIGNED_EVENT, MERGED_EVENT, READY_FOR_REVIEW_EVENT, CLOSED_EVENT]) { 28 | edges { 29 | node { 30 | __typename 31 | ... on ClosedEvent { 32 | actor { 33 | login 34 | } 35 | createdAt 36 | } 37 | ... on LabeledEvent { 38 | label { 39 | name 40 | } 41 | actor { 42 | login 43 | } 44 | createdAt 45 | } 46 | ... on MergedEvent { 47 | actor { 48 | login 49 | } 50 | createdAt 51 | } 52 | ... on ReadyForReviewEvent { 53 | actor { 54 | login 55 | } 56 | createdAt 57 | } 58 | ... 
on AssignedEvent { 59 | actor { 60 | login 61 | } 62 | createdAt 63 | } 64 | } 65 | } 66 | } 67 | } 68 | } 69 | } 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /benchmarks/tracing/sst2.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The JAX Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """SST2 helper functions.""" 15 | 16 | from typing import Any 17 | 18 | from flax.examples.sst2 import train 19 | import jax 20 | import jax.numpy as jnp 21 | import ml_collections 22 | 23 | 24 | def get_fake_batch(batch_size: int) -> dict[str, Any]: 25 | """Returns fake data for the given batch size. 26 | 27 | Args: 28 | batch_size: The global batch size to generate. 29 | 30 | Returns: 31 | A fake batch dictionary with token_ids, length, and label. 32 | """ 33 | rng = jax.random.key(0) 34 | max_length = 60 35 | token_ids = jax.random.randint( 36 | rng, (batch_size, max_length), 0, 1000, jnp.int32 37 | ) 38 | lengths = jnp.full((batch_size,), max_length, jnp.int32) 39 | labels = jax.random.uniform(rng, (batch_size,), jnp.float32) 40 | return { 41 | 'token_ids': token_ids, 42 | 'length': lengths, 43 | 'label': labels, 44 | } 45 | 46 | 47 | def get_apply_fn_and_args( 48 | config: ml_collections.ConfigDict, 49 | ) -> tuple[Any, tuple[Any, ...], dict[str, Any]]: 50 | """Returns the apply function and args for the given config. 51 | 52 | Args: 53 | config: The training configuration. 54 | 55 | Returns: 56 | A tuple of the apply function, args, kwargs, and any metadata. 57 | """ 58 | rng = jax.random.key(0) 59 | config = config.copy_and_resolve_references() 60 | if config.vocab_size is None: 61 | config.vocab_size = 1000 62 | model = train.model_from_config(config) 63 | state = train.create_train_state(rng, config, model) 64 | batch = get_fake_batch(config.batch_size) 65 | _, dropout_rng = jax.random.split(rng) 66 | rngs = {'dropout': dropout_rng} 67 | train_step_jit = jax.jit(train.train_step) 68 | return train_step_jit, (state, batch, rngs), {} 69 | -------------------------------------------------------------------------------- /examples/sst2/train_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for sst2.train.""" 16 | import sys 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | import jax 21 | import jax.test_util 22 | import numpy as np 23 | 24 | from configs import default as default_config 25 | import train 26 | 27 | 28 | # Parse absl flags test_srcdir and test_tmpdir. 29 | jax.config.parse_flags_with_absl() 30 | 31 | 32 | class TrainTest(parameterized.TestCase): 33 | 34 | def test_train_step_updates_parameters(self): 35 | """Tests if the train step updates the parameters in train state.""" 36 | # Create model and a state that contains the parameters. 37 | config = default_config.get_config() 38 | config.vocab_size = 13 39 | rng = jax.random.key(config.seed) 40 | model = train.model_from_config(config) 41 | state = train.create_train_state(rng, config, model) 42 | 43 | token_ids = np.array([[2, 4, 3], [2, 6, 3]], dtype=np.int32) 44 | lengths = np.array([2, 3], dtype=np.int32) 45 | labels = np.zeros_like(lengths) 46 | batch = {'token_ids': token_ids, 'length': lengths, 'label': labels} 47 | rngs = {'dropout': rng} 48 | train_step_fn = jax.jit(train.train_step) 49 | new_state, metrics = train_step_fn(state, batch, rngs) 50 | self.assertIsInstance(new_state, train.TrainState) 51 | self.assertIsInstance(metrics, train.Metrics) 52 | old_param_values = jax.tree_util.tree_leaves(state.params) 53 | new_param_values = jax.tree_util.tree_leaves(new_state.params) 54 | for old_array, new_array in zip(old_param_values, new_param_values): 55 | # Make sure parameters were updated. 56 | self.assertFalse(np.allclose(old_array, new_array)) 57 | 58 | 59 | if __name__ == '__main__': 60 | absltest.main() 61 | -------------------------------------------------------------------------------- /examples/sst2/main.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Main file for running the SST2 example. 16 | This file is intentionally kept short. The majority for logic is in libraries 17 | that can be easily tested and imported in Colab. 18 | """ 19 | 20 | from absl import app 21 | from absl import flags 22 | from absl import logging 23 | from clu import platform 24 | import jax 25 | from ml_collections import config_flags 26 | import tensorflow as tf 27 | 28 | import train 29 | 30 | 31 | FLAGS = flags.FLAGS 32 | 33 | flags.DEFINE_string('workdir', None, 'Directory to store model data.') 34 | config_flags.DEFINE_config_file( 35 | 'config', 36 | None, 37 | 'File path to the training hyperparameter configuration.', 38 | lock_config=True, 39 | ) 40 | flags.mark_flags_as_required(['config', 'workdir']) 41 | 42 | 43 | def main(argv): 44 | if len(argv) > 1: 45 | raise app.UsageError('Too many command-line arguments.') 46 | 47 | # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make 48 | # it unavailable to JAX. 
49 | tf.config.experimental.set_visible_devices([], 'GPU') 50 | 51 | logging.info('JAX process: %d / %d', jax.process_index(), jax.process_count()) 52 | logging.info('JAX local devices: %r', jax.local_devices()) 53 | 54 | # Add a note so that we can tell which task is which JAX host. 55 | # (Depending on the platform task 0 is not guaranteed to be host 0) 56 | platform.work_unit().set_task_status( 57 | f'process_index: {jax.process_index()}, ' 58 | f'process_count: {jax.process_count()}' 59 | ) 60 | platform.work_unit().create_artifact( 61 | platform.ArtifactType.DIRECTORY, FLAGS.workdir, 'workdir' 62 | ) 63 | 64 | train.train_and_evaluate(FLAGS.config, FLAGS.workdir) 65 | 66 | 67 | if __name__ == '__main__': 68 | app.run(main) 69 | -------------------------------------------------------------------------------- /examples/mnist/main.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Main file for running the MNIST example. 16 | 17 | This file is intentionally kept short. The majority of the logic is in libraries 18 | that can be easily tested and imported in Colab. 19 | """ 20 | 21 | from absl import app 22 | from absl import flags 23 | from absl import logging 24 | from clu import platform 25 | import jax 26 | from ml_collections import config_flags 27 | import tensorflow as tf 28 | 29 | import train 30 | 31 | 32 | FLAGS = flags.FLAGS 33 | 34 | flags.DEFINE_string('workdir', None, 'Directory to store model data.') 35 | config_flags.DEFINE_config_file( 36 | 'config', 37 | None, 38 | 'File path to the training hyperparameter configuration.', 39 | lock_config=True, 40 | ) 41 | 42 | 43 | def main(argv): 44 | if len(argv) > 1: 45 | raise app.UsageError('Too many command-line arguments.') 46 | 47 | # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make 48 | # it unavailable to JAX. 49 | tf.config.experimental.set_visible_devices([], 'GPU') 50 | 51 | logging.info('JAX process: %d / %d', jax.process_index(), jax.process_count()) 52 | logging.info('JAX local devices: %r', jax.local_devices()) 53 | 54 | # Add a note so that we can tell which task is which JAX host. 55 | # (Depending on the platform task 0 is not guaranteed to be host 0) 56 | platform.work_unit().set_task_status( 57 | f'process_index: {jax.process_index()}, ' 58 | f'process_count: {jax.process_count()}' 59 | ) 60 | platform.work_unit().create_artifact( 61 | platform.ArtifactType.DIRECTORY, FLAGS.workdir, 'workdir' 62 | ) 63 | 64 | train.train_and_evaluate(FLAGS.config, FLAGS.workdir) 65 | 66 | 67 | if __name__ == '__main__': 68 | flags.mark_flags_as_required(['config', 'workdir']) 69 | app.run(main) 70 | -------------------------------------------------------------------------------- /examples/imagenet/main.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Flax Authors.
--------------------------------------------------------------------------------
/examples/imagenet/main.py:
--------------------------------------------------------------------------------
# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Main file for running the ImageNet example.

This file is intentionally kept short. The majority of logic is in libraries
that can be easily tested and imported in Colab.
"""

from absl import app
from absl import flags
from absl import logging
from clu import platform
import jax
from ml_collections import config_flags
import tensorflow as tf

import train


FLAGS = flags.FLAGS

flags.DEFINE_string('workdir', None, 'Directory to store model data.')
config_flags.DEFINE_config_file(
    'config',
    None,
    'File path to the training hyperparameter configuration.',
    lock_config=True,
)


def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
  # it unavailable to JAX.
  tf.config.experimental.set_visible_devices([], 'GPU')

  logging.info('JAX process: %d / %d', jax.process_index(), jax.process_count())
  logging.info('JAX local devices: %r', jax.local_devices())

  # Add a note so that we can tell which task is which JAX host.
  # (Depending on the platform, task 0 is not guaranteed to be host 0.)
  platform.work_unit().set_task_status(
      f'process_index: {jax.process_index()}, '
      f'process_count: {jax.process_count()}'
  )
  platform.work_unit().create_artifact(
      platform.ArtifactType.DIRECTORY, FLAGS.workdir, 'workdir'
  )

  train.train_and_evaluate(FLAGS.config, FLAGS.workdir)


if __name__ == '__main__':
  flags.mark_flags_as_required(['config', 'workdir'])
  app.run(main)
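
All three launchers hide the GPUs from TensorFlow before any input-pipeline
code runs; by default TF reserves nearly all device memory, which JAX then
cannot use. A quick sanity check of the effect (illustrative, to be run in a
fresh process before TensorFlow has initialized its GPU runtime):

import jax
import tensorflow as tf

# Must happen before TF touches the GPUs, or it raises a RuntimeError.
tf.config.experimental.set_visible_devices([], 'GPU')

print('TF sees: ', tf.config.get_visible_devices('GPU'))  # -> []
print('JAX sees:', jax.devices())  # accelerators are still visible to JAX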
--------------------------------------------------------------------------------
/examples/lm1b/train_test.py:
--------------------------------------------------------------------------------
# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pathlib
import sys
import tempfile

from absl import logging
from absl.testing import absltest
import jax
import tensorflow as tf
import tensorflow_datasets as tfds

from configs import default
import train


# Skip most XLA optimization passes to keep test compile times low.
jax.config.update('jax_disable_most_optimizations', True)


class TrainTest(absltest.TestCase):
  """Test cases for LM library."""

  def setUp(self):
    super().setUp()
    if sys.version_info >= (3, 13):
      self.skipTest('Test (and tensorflow-text) does not support Python 3.13+')
    tf.config.experimental.set_visible_devices([], 'GPU')

  def test_train_and_evaluate(self):
    config = default.get_config()
    config.max_corpus_chars = 1000
    config.vocab_size = 32
    config.per_device_batch_size = 2
    config.num_train_steps = 1
    config.num_eval_steps = 1
    config.num_predict_steps = 1

    config.num_layers = 1
    config.qkv_dim = 128
    config.emb_dim = 128
    config.mlp_dim = 512
    config.num_heads = 2

    config.max_target_length = 32
    config.max_eval_target_length = 32
    config.max_predict_length = 32

    workdir = tempfile.mkdtemp()

    # Go two directories up to the root of the flax directory.
    flax_root_dir = pathlib.Path(__file__).parents[2]
    data_dir = str(flax_root_dir) + '/.tfds/metadata'

    with tfds.testing.mock_data(num_examples=128, data_dir=data_dir):
      train.train_and_evaluate(config, workdir)
    logging.info('workdir content: %s', tf.io.gfile.listdir(workdir))


if __name__ == '__main__':
  absltest.main()
--------------------------------------------------------------------------------
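
The tfds.testing.mock_data context above is what keeps this test hermetic:
inside it, every tfds.load call yields deterministic generated examples
instead of downloading the real lm1b corpus, with data_dir pointing at the
dataset metadata checked into the repository. A minimal illustration of the
same mechanism (assuming metadata for the chosen dataset is available
locally, as it is for the test above):

import tensorflow_datasets as tfds

# Nothing is downloaded inside the context; examples are generated on the fly.
with tfds.testing.mock_data(num_examples=4):
  ds = tfds.load('mnist', split='train')
  for example in ds.take(2):
    print(example['image'].shape, example['label'].numpy())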