├── CHANGELOG.md ├── examples ├── image │ ├── requirements.txt │ ├── training │ │ ├── data_transform.py │ │ ├── edm_time_discretization.py │ │ ├── grad_scaler.py │ │ ├── load_and_save.py │ │ ├── distributed_mode.py │ │ └── train_loop.py │ ├── models │ │ ├── ema.py │ │ ├── discrete_unet.py │ │ ├── model_configs.py │ │ └── nn.py │ ├── README.md │ ├── submitit_train.py │ └── train_arg_parser.py ├── text │ ├── logic │ │ ├── __init__.py │ │ ├── generate.py │ │ ├── state.py │ │ ├── flow.py │ │ ├── training.py │ │ └── evaluate.py │ ├── utils │ │ ├── __init__.py │ │ ├── checkpointing.py │ │ └── logging.py │ ├── data │ │ ├── __init__.py │ │ ├── tokenizer.py │ │ ├── utils.py │ │ └── data.py │ ├── model │ │ ├── __init__.py │ │ └── rotary.py │ ├── environment.yml │ ├── run_train.py │ ├── configs │ │ └── config.yaml │ ├── scripts │ │ ├── run_eval.py │ │ └── eval.py │ └── README.md └── README.md ├── assets ├── teaser.png ├── arXiv-2412.06264-red.svg └── License-CC_BY--NC_4.0-lightgrey.svg ├── docs ├── source │ ├── _images │ │ ├── discrete.png │ │ ├── standalone.png │ │ ├── riemannian_sphere.png │ │ └── riemannian_torus.png │ ├── references.rst │ ├── dummy.rst │ ├── _static │ │ └── css │ │ │ └── custom.css │ ├── modules.rst │ ├── _templates │ │ └── classtemplate.rst │ ├── flow_matching.loss.rst │ ├── flow_matching.solver.rst │ ├── flow_matching.utils.model_wrapper.rst │ ├── flow_matching.utils.manifolds.rst │ ├── flow_matching.path.scheduler.rst │ ├── installation.rst │ ├── flow_matching.path.rst │ ├── index.rst │ ├── notebooks.rst │ ├── conf.py │ └── refs.bib ├── deps.yml ├── _templates │ └── classtemplate.rst ├── server.py ├── README.md ├── deploy.py └── Makefile ├── tests ├── __init__.py ├── path │ ├── __init__.py │ ├── test_schedule_transform.py │ ├── test_scheduler.py │ └── test_path.py ├── solver │ ├── __init__.py │ └── test_discrete_solver.py └── utils │ ├── __init__.py │ └── test_utils.py ├── flow_matching ├── __init__.py ├── loss │ ├── __init__.py │ └── 
generalized_loss.py ├── solver │ ├── solver.py │ ├── __init__.py │ └── utils.py ├── utils │ ├── manifolds │ │ ├── __init__.py │ │ ├── torus.py │ │ ├── utils.py │ │ ├── sphere.py │ │ └── manifold.py │ ├── __init__.py │ ├── categorical_sampler.py │ ├── model_wrapper.py │ └── utils.py └── path │ ├── __init__.py │ ├── scheduler │ ├── __init__.py │ ├── schedule_transform.py │ └── scheduler.py │ ├── path_sample.py │ ├── path.py │ ├── geodesic.py │ └── mixture.py ├── RELEASE.md ├── .flake8 ├── .pre-commit-config.yaml ├── environment.yml ├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md └── workflows │ ├── notebooks.yaml │ └── ci.yaml ├── CONTRIBUTING.md ├── .gitignore ├── setup.py ├── CODE_OF_CONDUCT.md └── README.md /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Change log 2 | 3 | ## [0.1] - 2024-12-01 4 | 5 | - Initial release. -------------------------------------------------------------------------------- /examples/image/requirements.txt: -------------------------------------------------------------------------------- 1 | submitit 2 | torchmetrics[image] 3 | torchvision 4 | -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/flow_matching/HEAD/assets/teaser.png -------------------------------------------------------------------------------- /docs/source/_images/discrete.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/flow_matching/HEAD/docs/source/_images/discrete.png -------------------------------------------------------------------------------- /docs/source/_images/standalone.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/facebookresearch/flow_matching/HEAD/docs/source/_images/standalone.png -------------------------------------------------------------------------------- /docs/source/references.rst: -------------------------------------------------------------------------------- 1 | References 2 | ---------- 3 | 4 | 5 | .. bibliography:: 6 | :list: enumerated 7 | :all: 8 | :notcited: 9 | -------------------------------------------------------------------------------- /docs/source/_images/riemannian_sphere.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/flow_matching/HEAD/docs/source/_images/riemannian_sphere.png -------------------------------------------------------------------------------- /docs/source/_images/riemannian_torus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/flow_matching/HEAD/docs/source/_images/riemannian_torus.png -------------------------------------------------------------------------------- /docs/deps.yml: -------------------------------------------------------------------------------- 1 | dependencies: 2 | - pandoc 3 | - pip: 4 | - sphinx 5 | - sphinxcontrib-katex 6 | - nbsphinx 7 | - sphinxcontrib.bibtex 8 | - pydata-sphinx-theme -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | -------------------------------------------------------------------------------- /tests/path/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /tests/solver/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /tests/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /examples/text/logic/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /examples/text/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /flow_matching/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | __version__ = "1.0.10" 8 | -------------------------------------------------------------------------------- /docs/source/dummy.rst: -------------------------------------------------------------------------------- 1 | .. toctree:: 2 | :maxdepth: 0 3 | :hidden: 4 | :titlesonly: 5 | 6 | notebooks/standalone_flow_matching 7 | notebooks/2d_discrete_flow_matching 8 | notebooks/2d_riemannian_flow_matching_flat_torus 9 | notebooks/2d_riemannian_flow_matching_sphere 10 | -------------------------------------------------------------------------------- /docs/source/_static/css/custom.css: -------------------------------------------------------------------------------- 1 | /* Copyright (c) Meta Platforms, Inc. and affiliates. 2 | All rights reserved. 3 | 4 | This source code is licensed under the CC-by-NC license found in the 5 | LICENSE file in the root directory of this source tree. */ 6 | 7 | div.math { justify-content: center } 8 | -------------------------------------------------------------------------------- /docs/source/modules.rst: -------------------------------------------------------------------------------- 1 | API Reference 2 | =============================== 3 | 4 | .. 
toctree:: 5 | :maxdepth: 2 6 | 7 | flow_matching.loss 8 | flow_matching.path 9 | flow_matching.path.scheduler 10 | flow_matching.solver 11 | flow_matching.utils.model_wrapper 12 | flow_matching.utils.manifolds 13 | -------------------------------------------------------------------------------- /examples/text/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .data import DataState 8 | 9 | __all__ = ["DataState"] 10 | -------------------------------------------------------------------------------- /docs/_templates/classtemplate.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | .. currentmodule:: {{ module }} 4 | 5 | 6 | {{ name | underline}} 7 | 8 | .. autoclass:: {{ name }} 9 | :members: 10 | 11 | 12 | .. 13 | autogenerated from source/_templates/classtemplate.rst 14 | note it does not have :inherited-members: -------------------------------------------------------------------------------- /docs/source/_templates/classtemplate.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | .. currentmodule:: {{ module }} 4 | 5 | 6 | {{ name | underline}} 7 | 8 | .. autoclass:: {{ name }} 9 | :members: 10 | 11 | 12 | .. 13 | autogenerated from source/_templates/classtemplate.rst 14 | note it does not have :inherited-members: -------------------------------------------------------------------------------- /examples/text/model/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .transformer import Transformer 8 | 9 | __all__ = [ 10 | "Transformer", 11 | ] 12 | -------------------------------------------------------------------------------- /RELEASE.md: -------------------------------------------------------------------------------- 1 | # Release Instructions 2 | 3 | Build a wheel: 4 | 5 | ``` 6 | pip wheel --no-deps . --wheel-dir dist 7 | ``` 8 | 9 | In your home directory, create `~/.pypirc` with the following: 10 | 11 | ``` 12 | [pypi] 13 | username = __token__ 14 | password = 15 | ``` 16 | 17 | Upload the wheel: 18 | 19 | ``` 20 | twine upload dist/* 21 | ``` 22 | -------------------------------------------------------------------------------- /docs/source/flow_matching.loss.rst: -------------------------------------------------------------------------------- 1 | ``flow_matching.loss`` 2 | ============================= 3 | 4 | .. currentmodule:: flow_matching.loss 5 | 6 | 7 | MixturePathGeneralizedKL 8 | -------------------------------- 9 | 10 | .. autosummary:: 11 | :toctree: generated 12 | :nosignatures: 13 | :template: classtemplate.rst 14 | 15 | MixturePathGeneralizedKL 16 | 17 | -------------------------------------------------------------------------------- /flow_matching/loss/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from .generalized_loss import MixturePathGeneralizedKL 8 | 9 | __all__ = [ 10 | "MixturePathGeneralizedKL", 11 | ] 12 | -------------------------------------------------------------------------------- /examples/text/environment.yml: -------------------------------------------------------------------------------- 1 | name: discrete_flow_matching 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - nvidia 6 | dependencies: 7 | - python=3.10 8 | - numpy 9 | - pip 10 | - tqdm 11 | - pip: 12 | - torch>=2.5.0 13 | - hydra-core 14 | - hydra-submitit-launcher 15 | - datasets 16 | - transformers 17 | - wandb 18 | - einops 19 | - flow_matching -------------------------------------------------------------------------------- /docs/source/flow_matching.solver.rst: -------------------------------------------------------------------------------- 1 | ``flow_matching.solver`` 2 | ============================= 3 | 4 | .. currentmodule:: flow_matching.solver 5 | 6 | Solvers 7 | ------- 8 | 9 | .. autosummary:: 10 | :toctree: generated 11 | :nosignatures: 12 | :template: classtemplate.rst 13 | 14 | Solver 15 | ODESolver 16 | MixtureDiscreteEulerSolver 17 | RiemannianODESolver 18 | 19 | -------------------------------------------------------------------------------- /docs/source/flow_matching.utils.model_wrapper.rst: -------------------------------------------------------------------------------- 1 | ``flow_matching.utils.model_wrapper`` 2 | ===================================== 3 | 4 | .. currentmodule:: flow_matching.utils.model_wrapper 5 | 6 | 7 | ModelWrapper 8 | -------------------------------- 9 | 10 | .. 
autosummary:: 11 | :toctree: generated 12 | :nosignatures: 13 | :template: classtemplate.rst 14 | 15 | ModelWrapper 16 | 17 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | extend-ignore = 3 | B006 4 | B007 5 | B008 6 | B010 7 | B023 8 | B028 9 | B601 10 | C403 11 | C405 12 | C408 13 | C416 14 | C417 15 | C419 16 | E203 17 | E402 18 | E501 19 | E731 20 | W391 21 | W605 22 | exclude=build,notebooks,protobuf 23 | 24 | # ignore unused imports in __init__.py files 25 | per-file-ignores = 26 | __init__.py:F401 -------------------------------------------------------------------------------- /docs/server.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 
# ASGI application whose only job is to serve the Sphinx HTML output.
app = FastAPI()

# Everything under "/" is answered from the static build directory; the
# relative path means the server is expected to be started from ``docs/``
# after the HTML has been built (see docs/README.md).
app.mount(
    "/",
    StaticFiles(
        directory="build/html",
        html=True,
    ),
    name="static",
)
class Solver(ABC, nn.Module):
    """Interface shared by every solver; concrete solvers must override :meth:`sample`.

    The class is both an ``ABC`` and an ``nn.Module``, so it cannot be
    instantiated directly and subclasses behave like regular torch modules.
    """

    @abstractmethod
    def sample(self, x_0: Tensor = None) -> Tensor:
        """Generate samples; the base implementation does nothing.

        Args:
            x_0 (Tensor): Optional input forwarded to the solver (presumably the
                initial state -- the exact meaning is defined by each subclass).
                Defaults to None.

        Returns:
            Tensor: The samples produced by the concrete solver.
        """
def get_train_transform():
    """Build the transform pipeline applied to training images.

    Returns:
        Compose: ``ToImage`` -> ``RandomHorizontalFlip`` ->
        ``ToDtype(torch.float32, scale=True)``, applied in that order.
    """
    return Compose(
        [
            ToImage(),
            RandomHorizontalFlip(),
            ToDtype(torch.float32, scale=True),
        ]
    )
21 | 22 | **Additional context** 23 | Add any other context about the problem here. 24 | -------------------------------------------------------------------------------- /docs/source/flow_matching.utils.manifolds.rst: -------------------------------------------------------------------------------- 1 | ``flow_matching.utils.manifolds`` 2 | ================================= 3 | 4 | .. currentmodule:: flow_matching.utils.manifolds 5 | 6 | 7 | Manifold 8 | ----------------- 9 | 10 | Manifold classes for logarithmic and exponential map projections 11 | 12 | .. autosummary:: 13 | :toctree: generated 14 | :nosignatures: 15 | :template: classtemplate.rst 16 | 17 | Manifold 18 | Sphere 19 | FlatTorus 20 | 21 | Utility Functions 22 | ----------------- 23 | 24 | .. autosummary:: 25 | :toctree: generated 26 | :nosignatures: 27 | :template: classtemplate.rst 28 | 29 | geodesic -------------------------------------------------------------------------------- /flow_matching/solver/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 
def get_nearest_times(time_grid: Tensor, t_discretization: Tensor) -> Tensor:
    """Snap each entry of ``time_grid`` to its closest value in ``t_discretization``.

    Args:
        time_grid (Tensor): Times to snap. Each entry is treated as a
            one-dimensional point, so a 1-D tensor is expected.
        t_discretization (Tensor): Candidate times to choose from (1-D).

    Returns:
        Tensor: For every entry of ``time_grid``, the element of
        ``t_discretization`` at minimal distance from it.
    """
    queries = time_grid.unsqueeze(1)
    candidates = t_discretization.unsqueeze(1)

    # Row i holds the distances from time_grid[i] to every candidate. The
    # matmul-based shortcut is explicitly disabled via ``compute_mode``.
    pairwise = torch.cdist(
        queries,
        candidates,
        compute_mode="donot_use_mm_for_euclid_dist",
    )
    closest = torch.argmin(pairwise, dim=1)

    return t_discretization[closest]
def categorical(probs: Tensor) -> Tensor:
    r"""Categorical sampler according to weights in the last dimension of ``probs`` using :func:`torch.multinomial`.

    Args:
        probs (Tensor): probabilities, shape ``(..., K)``. The last dimension
            indexes the ``K`` categories; every leading dimension is treated as
            a batch dimension. A plain 1-D vector of shape ``(K,)`` is accepted too.

    Returns:
        Tensor: Samples, shape ``probs.shape[:-1]`` (a 0-d tensor for 1-D
        input), holding one sampled category index per distribution.
    """
    num_categories = probs.shape[-1]

    # ``reshape(-1, K)`` instead of ``flatten(0, -2)``: the latter raises
    # "Dimension out of range" for 1-D input, which has no dimension -2.
    flat_probs = probs.reshape(-1, num_categories)
    samples = torch.multinomial(flat_probs, 1, replacement=True)

    # Pass the shape as one object so an empty batch shape yields a scalar.
    return samples.view(probs.shape[:-1])
30 | 31 | ## Deploy 32 | 33 | To deploy the docs (in the current branch) to github pages, run `make deploy` 34 | -------------------------------------------------------------------------------- /docs/source/flow_matching.path.rst: -------------------------------------------------------------------------------- 1 | ``flow_matching.path`` 2 | ============================= 3 | 4 | .. currentmodule:: flow_matching.path 5 | 6 | 7 | Probability Paths 8 | -------------------------------- 9 | 10 | .. autosummary:: 11 | :toctree: generated 12 | :nosignatures: 13 | :template: classtemplate.rst 14 | 15 | ProbPath 16 | AffineProbPath 17 | CondOTProbPath 18 | MixtureDiscreteProbPath 19 | GeodesicProbPath 20 | 21 | 22 | Path Sample 23 | -------------------------------- 24 | 25 | Corresponds to an instance of a sample drawn from the probability path. 26 | 27 | .. autosummary:: 28 | :toctree: generated 29 | :nosignatures: 30 | :template: classtemplate.rst 31 | 32 | path_sample.PathSample 33 | path_sample.DiscretePathSample 34 | 35 | -------------------------------------------------------------------------------- /flow_matching/path/scheduler/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from .schedule_transform import ScheduleTransformedModel 8 | from .scheduler import ( 9 | CondOTScheduler, 10 | ConvexScheduler, 11 | CosineScheduler, 12 | LinearVPScheduler, 13 | PolynomialConvexScheduler, 14 | Scheduler, 15 | SchedulerOutput, 16 | VPScheduler, 17 | ) 18 | 19 | __all__ = [ 20 | "CondOTScheduler", 21 | "CosineScheduler", 22 | "ConvexScheduler", 23 | "PolynomialConvexScheduler", 24 | "ScheduleTransformedModel", 25 | "Scheduler", 26 | "VPScheduler", 27 | "LinearVPScheduler", 28 | "SchedulerOutput", 29 | ] 30 | -------------------------------------------------------------------------------- /.github/workflows/notebooks.yaml: -------------------------------------------------------------------------------- 1 | name: Notebooks 2 | on: push 3 | 4 | jobs: 5 | run-notebooks: 6 | runs-on: 4-core-ubuntu-gpu-t4 7 | steps: 8 | - uses: actions/checkout@v2 9 | - uses: mamba-org/setup-micromamba@v1.8.1 10 | with: 11 | environment-file: environment.yml 12 | cache-environment: true 13 | 14 | - name: Check GPU availability 15 | shell: bash -l {0} 16 | run: | 17 | python -c "import torch; print('Is CUDA available:', torch.cuda.is_available())" 18 | 19 | - name: Run notebooks 20 | shell: bash -l {0} 21 | run: | 22 | set -e 23 | export PYTHONPATH=${PWD}:$PYTHONPATH 24 | for file in examples/*.ipynb; do 25 | jupyter nbconvert --to notebook --execute --ExecutePreprocessor.timeout=300 "$file" 26 | done -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | ============= 2 | Flow Matching 3 | ============= 4 | 5 | `flow_matching` is a PyTorch library for implementing flow matching algorithms, featuring state-of-the-art continuous and discrete implementations. It includes practical examples for both text and image modalities. This repository is part of `Flow Matching Guide and Codebase <https://arxiv.org/abs/2412.06264>`_. 6 | 7 | .. 
class FlatTorus(Manifold):
    r"""Flat torus whose coordinates are angles on the :math:`[0, 2\pi]^D` subspace.

    Isometric to the product of 1-D spheres; every operation acts
    independently on each coordinate.
    """

    def expmap(self, x: Tensor, u: Tensor) -> Tensor:
        r"""Step from ``x`` along ``u`` and wrap the result modulo :math:`2\pi`."""
        return torch.remainder(x + u, 2 * math.pi)

    def logmap(self, x: Tensor, y: Tensor) -> Tensor:
        r"""Return the signed angular difference from ``x`` to ``y``.

        ``atan2(sin, cos)`` folds the raw difference into :math:`[-\pi, \pi]`.
        """
        delta = y - x
        return torch.atan2(delta.sin(), delta.cos())

    def projx(self, x: Tensor) -> Tensor:
        r"""Wrap ``x`` onto the torus by reducing it modulo :math:`2\pi`."""
        return torch.remainder(x, 2 * math.pi)

    def proju(self, x: Tensor, u: Tensor) -> Tensor:
        """Return ``u`` unchanged; ``x`` is not used."""
        return u
def get_time_discretization(
    nfes: int, rho=7, *, sigma_min: float = 0.002, sigma_max: float = 80.0
):
    """Build the EDM-style (https://arxiv.org/abs/2206.00364) sampling time grid.

    Noise levels are interpolated between ``sigma_max`` and ``sigma_min`` in
    ``sigma ** (1 / rho)`` space, a final level of zero is appended, and each
    level is mapped to a time ``t = 1 - sigma / (1 + sigma)``.

    Args:
        nfes (int): number of noise levels before the appended zero level; must be at least 1.
        rho: exponent controlling how the levels are spaced.
        sigma_min (float): smallest non-zero noise level. Defaults to 0.002.
        sigma_max (float): largest noise level. Defaults to 80.0.

    Returns:
        Tensor: 1-D float64 tensor with ``nfes + 1`` times; the last entry is 1.0.

    Raises:
        ValueError: if ``nfes`` is smaller than 1.
    """
    if nfes < 1:
        raise ValueError(f"nfes must be at least 1, got {nfes}.")

    step_indices = torch.arange(nfes, dtype=torch.float64)
    # With a single level there is nothing to interpolate across; dividing by
    # (nfes - 1) == 0 would yield 0 / 0 = NaN, so the denominator is floored at 1.
    fractions = step_indices / max(nfes - 1, 1)

    inv_rho = 1 / rho
    sigma_vec = (
        sigma_max**inv_rho + fractions * (sigma_min**inv_rho - sigma_max**inv_rho)
    ) ** rho
    # Append the terminal zero-noise level, which maps to t = 1.
    sigma_vec = torch.cat([sigma_vec, torch.zeros_like(sigma_vec[:1])])

    time_vec = sigma_vec / (1 + sigma_vec)
    t_samples = 1.0 - torch.clip(time_vec, min=0.0, max=1.0)
    return t_samples
customcarditem:: 27 | :header: Riemannian Flow Matching (Flat Torus) 28 | :card_description: 2D flat torus riemannian flow matching example 29 | :image: _static/riemannian_torus.png 30 | :link: notebooks/2d_riemannian_flow_matching_flat_torus.html 31 | 32 | .. customcardend:: 33 | -------------------------------------------------------------------------------- /assets/arXiv-2412.06264-red.svg: -------------------------------------------------------------------------------- 1 | arXiv: 2412.06264arXiv2412.06264 -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to flow_matching 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | 6 | ## Pull Requests 7 | We actively welcome your pull requests. 8 | 9 | 1. Fork the repo and create your branch from `main`. 10 | 2. If you've added code that should be tested, add tests. 11 | 3. If you've changed APIs, update the documentation. 12 | 4. Ensure the test suite passes. 13 | 5. Make sure your code lints. 14 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 15 | 16 | ## Contributor License Agreement ("CLA") 17 | In order to accept your pull request, we need you to submit a CLA. You only need 18 | to do this once to work on any of Meta's open source projects. 19 | 20 | Complete your CLA here: 21 | 22 | ## Issues 23 | We use GitHub issues to track public bugs. Please ensure your description is 24 | clear and has sufficient instructions to be able to reproduce the issue. 25 | 26 | ## License 27 | By contributing to flow_matching, you agree that your contributions will be licensed 28 | under the LICENSE file in the root directory of this source tree. 
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.
"""Publish the built HTML docs by force-pushing them to the pages branch.

The script copies ``build/html`` (next to this file) into a temporary git
repository, commits the content and force-pushes it to ``branch`` on ``remote``.
"""
import os
import shutil
from subprocess import check_call
from tempfile import TemporaryDirectory

this_dir = os.path.dirname(os.path.realpath(__file__))

# NOTE(review): this targets the "fairinternal" organisation; confirm it is the
# intended destination for the published docs.
remote = "git@github.com:fairinternal/flow_matching.git"
branch = "gh-pages"


with TemporaryDirectory() as tdir:
    local = os.path.join(tdir, "repo")
    shutil.copytree(os.path.join(this_dir, "build/html"), local)

    # Create an empty .nojekyll marker file alongside the copied HTML.
    with open(os.path.join(local, ".nojekyll"), "w") as fout:
        print("", end="", file=fout)

    check_call(["git", "init", local])
    check_call(["git", "remote", "add", "origin", remote], cwd=local)
    check_call(["git", "checkout", "-b", branch], cwd=local)

    check_call(["git", "add", "--all"], cwd=local)
    check_call(["git", "commit", "-m", "Update github pages"], cwd=local)

    # Push the branch that was actually checked out above instead of a second
    # hard-coded name, so changing ``branch`` keeps checkout and push in sync.
    check_call(["git", "push", "--set-upstream", "origin", branch, "-f"], cwd=local)
6 | 7 | ## Text 8 | 9 | [Flow matching on text.](text/) Text generation using discrete flow matching. 10 | 11 | ## Notebooks 12 | 13 | | Notebook | Description | 14 | | --- | --- | 15 | | [standalone_flow_matching.ipynb](standalone_flow_matching.ipynb) | A concise flow matching example built in pure PyTorch. | 16 | | [standalone_discrete_flow_matching.ipynb](standalone_discrete_flow_matching.ipynb) | A concise discrete flow matching example built in pure PyTorch. | 17 | | [2d_flow_matching.ipynb](2d_flow_matching.ipynb) | 2D flow matching example on the checkerboard dataset using the flow_matching library. | 18 | | [2d_discrete_flow_matching.ipynb](2d_discrete_flow_matching.ipynb) | 2D discrete flow matching example on the checkerboard dataset using the flow_matching library. | 19 | | [2d_riemannian_flow_matching_flat_torus.ipynb](2d_riemannian_flow_matching_flat_torus.ipynb) | 2D Riemannian flow matching on a flat torus on the checkerboard dataset and the flow_matching library. | 20 | | [2d_riemannian_flow_matching_sphere.ipynb](2d_riemannian_flow_matching_sphere.ipynb) | 2D Riemannian flow matching on a sphere on the checkerboard dataset and the flow_matching library. | 21 | 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # VSCode 36 | *.vscode 37 | 38 | # Others 39 | examples/image_generation/data 40 | examples/*/output_dir* 41 | examples/image/scripts 42 | examples/image/outputs 43 | examples/image/data 44 | examples/image/output_dir 45 | examples/images/* 46 | examples/imagenet/* 47 | examples/image_generation/* 48 | examples/*.ignore 49 | examples/*/snapshots* 50 | examples/*/outputs 51 | 52 | examples/imagenet/scripts 53 | *.ipynb_checkpoints* 54 | 55 | make.bat 56 | docs/output 57 | docs/source/generated 58 | docs/source/notebooks 59 | docs/source/images 60 | **/*.ipynb_checkpoints/ 61 | 62 | projects/image_latent/cache 63 | projects/image_latent/vqvae_training/cache 64 | projects/image_latent/outputs 65 | */assets/ 66 | 67 | outputs/ 68 | output_dir/ 69 | 70 | *logs/ 71 | *mixture_uniform_step=320001/ 72 | *.out 73 | *.err -------------------------------------------------------------------------------- /tests/utils/test_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 
class TestUtils(unittest.TestCase):
    """Shape and autograd checks for helpers imported from ``flow_matching.utils``."""

    def test_unsqueeze_to_match_suffix(self):
        # Default mode: a (3,) source matched against (3, 4, 5) becomes (3, 1, 1).
        unsqueezed = unsqueeze_to_match(torch.randn(3), torch.randn(3, 4, 5))
        self.assertEqual(unsqueezed.shape, (3, 1, 1))

    def test_unsqueeze_to_match_prefix(self):
        # Prefix mode: a (3,) source matched against (4, 5, 3) becomes (1, 1, 3).
        unsqueezed = unsqueeze_to_match(
            torch.randn(3), torch.randn(4, 5, 3), how="prefix"
        )
        self.assertEqual(unsqueezed.shape, (1, 1, 3))

    def test_expand_tensor_like(self):
        # The expanded tensor takes the full shape of the reference tensor.
        reference = torch.randn(3, 4, 5)
        expanded = expand_tensor_like(torch.randn(3), reference)
        self.assertEqual(expanded.shape, (3, 4, 5))

    def test_gradient(self):
        # d(x^2)/dx = 2x when the upstream gradient is all ones.
        x = torch.randn(3, requires_grad=True)
        squared = x**2
        upstream = torch.ones_like(squared)
        grad = gradient(squared, x, grad_outputs=upstream)
        self.assertTrue(torch.allclose(grad, 2 * x))
class ModelWrapper(ABC, nn.Module):
    """
    Thin wrapper around another model; subclasses override ``forward`` to add custom logic.
    """

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, x: Tensor, t: Tensor, **extras) -> Tensor:
        r"""
        Pass the inputs through the wrapped model.

        The wrapped model is called with :math:`x` and :math:`t` as keyword arguments,
        together with every additional keyword argument received here.

        Subclasses may use this hook to:
            - check that t has the dimensions the model expects.
            - add custom forward pass logic.
            - call the wrapped model.

        Args:
            x (Tensor): input data to the model (batch_size, ...).
            t (Tensor): time (batch_size).
            **extras: additional information forwarded to the model, e.g., text condition.

        Returns:
            Tensor: model output.
        """
        call_kwargs = dict(x=x, t=t, **extras)
        return self.model(**call_kwargs)
class Sphere(Manifold):
    r"""Represents the unit hypersphere embedded in :math:`R^D`.

    Points are expected to be unit-norm vectors (``projx`` normalizes to unit norm) and
    tangent vectors at ``x`` are obtained by removing the component along ``x`` (``proju``).
    """

    # Per-dtype threshold below which a norm or distance is treated as zero.
    EPS = {torch.float32: 1e-4, torch.float64: 1e-7}

    def expmap(self, x: Tensor, u: Tensor) -> Tensor:
        """Exponential map at ``x`` along ``u``; falls back to a retraction for tiny ``u``."""
        norm_u = u.norm(dim=-1, keepdim=True)
        eps = self.EPS[norm_u.dtype]
        # Clamp the divisor: wherever ``exp`` is actually selected (norm_u > eps) the
        # clamp is a no-op, but it keeps the discarded branch finite so that
        # ``torch.where`` does not propagate NaN gradients when ``u`` is zero.
        exp = x * torch.cos(norm_u) + u * torch.sin(norm_u) / norm_u.clamp_min(eps)
        retr = self.projx(x + u)
        cond = norm_u > eps

        return torch.where(cond, exp, retr)

    def logmap(self, x: Tensor, y: Tensor) -> Tensor:
        """Tangent vector at ``x`` pointing towards ``y`` with length ``dist(x, y)``."""
        eps = self.EPS[x.dtype]
        u = self.proju(x, y - x)
        dist = self.dist(x, y, keepdim=True)
        cond = dist.gt(eps)
        # For nearly coincident points the projected difference is returned unscaled.
        result = torch.where(
            cond,
            u * dist / u.norm(dim=-1, keepdim=True).clamp_min(eps),
            u,
        )
        return result

    def projx(self, x: Tensor) -> Tensor:
        """Normalize ``x`` to unit norm along the last dimension."""
        return x / x.norm(dim=-1, keepdim=True)

    def proju(self, x: Tensor, u: Tensor) -> Tensor:
        """Remove from ``u`` its component along ``x``."""
        return u - (x * u).sum(dim=-1, keepdim=True) * x

    def dist(self, x: Tensor, y: Tensor, *, keepdim=False) -> Tensor:
        """Angle between ``x`` and ``y``, computed as the arccosine of their inner product."""
        inner = (x * y).sum(-1, keepdim=keepdim)
        # Rounding can push the inner product of unit vectors slightly outside
        # [-1, 1], where acos returns NaN; clamp back into the valid domain.
        return torch.acos(inner.clamp(-1.0, 1.0))
def wt_detokenizer(string):
    """Undo WikiText-style tokenization artifacts in ``string``.

    The rules are applied in a fixed order: contractions, number separators
    (``@-@``, ``@,@``, ``@.@``), punctuation spacing, padding inside brackets and
    quotes, and finally heading markers, the degree sign, line-edge spaces, the
    ``N`` placeholder and possessives.

    Args:
        string (str): tokenized text.

    Returns:
        str: the detokenized text.
    """
    # contractions
    string = string.replace("s '", "s'")
    # Re-attach a digit to the apostrophe in front of it ("' 90s" -> "'90s").
    # The previous pattern carried Perl-style "/.../" delimiters as literal
    # characters and wrote a literal "[0-9]" into the output, so it never applied.
    # NOTE(review): this changes the output for inputs containing "' <digit>";
    # confirm before comparing metrics against results from the old rule.
    string = re.sub(r"' ([0-9])", r"'\1", string)
    # number separators
    for spaced, joined in ((" @-@ ", "-"), (" @,@ ", ","), (" @.@ ", ".")):
        string = string.replace(spaced, joined)
    # punctuation
    for mark in (":", ";", ".", "!", "?", ","):
        string = string.replace(" " + mark + " ", mark + " ")
    # double brackets
    string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string)
    string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string)
    string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string)
    string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string)
    string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string)
    # miscellaneous
    # Longest heading marker first so shorter patterns do not split it.
    string = string.replace("= = = =", "====")
    string = string.replace("= = =", "===")
    string = string.replace("= =", "==")
    string = string.replace(" " + chr(176) + " ", chr(176))
    string = string.replace(" \n", "\n")
    string = string.replace("\n ", "\n")
    string = string.replace(" N ", " 1 ")
    string = string.replace(" 's", "'s")
    return string
@hydra.main(version_base=None, config_path="configs", config_name="config")
def main(cfg: DictConfig):
    """Resolve the working directory, store it on the config and launch training.

    When ``load_dir`` is present in ``cfg``, the run's saved hydra config is loaded
    from that directory and the directory is reused as the working directory.
    Otherwise the directory comes from hydra's run or sweep settings and is created.
    A single GPU trains in the current process; several GPUs spawn one process each.
    """
    if "load_dir" in cfg:
        work_dir = cfg.load_dir
        cfg = checkpointing.load_hydra_config_from_run(work_dir)
    else:
        hydra_cfg = HydraConfig.get()
        if hydra_cfg.mode == RunMode.RUN:
            work_dir = hydra_cfg.run.dir
        else:
            work_dir = os.path.join(hydra_cfg.sweep.dir, hydra_cfg.sweep.subdir)
        os.makedirs(work_dir, exist_ok=True)

    # The config is opened for writing so the new ``work_dir`` key can be added.
    with open_dict(cfg):
        cfg.work_dir = work_dir

    port = 12346
    n_gpus = cfg.compute.ngpus

    if n_gpus == 1:
        run_mp_training(rank=0, world_size=1, cfg=cfg, port=port)
        return

    mp.set_start_method("forkserver")
    mp.spawn(
        run_mp_training,
        args=(cfg.compute.ngpus, cfg, port),
        nprocs=cfg.compute.ngpus,
        join=True,
    )
run: | 24 | ufmt check flow_matching/ examples/ 25 | 26 | - name: flake8 lint 27 | shell: bash -l {0} 28 | run: | 29 | flake8 flow_matching 30 | 31 | - name: Run tests 32 | shell: bash -l {0} 33 | run: | 34 | coverage run --include='flow_matching/**/*.py' -m unittest discover tests -v 35 | 36 | - name: Docstring Lint 37 | shell: bash -l {0} 38 | run: | 39 | pydoclint --style=google flow_matching 40 | 41 | - name: Build doc pages 42 | shell: bash -l {0} 43 | working-directory: docs 44 | run: | 45 | micromamba env update --file deps.yml 46 | PYTHONPATH=../:. make html 47 | 48 | - name: coverage 49 | shell: bash -l {0} 50 | run: | 51 | pip install coverage-badge 52 | coverage html --include='flow_matching/**/*.py' -d docs/build/html/coverage 53 | coverage-badge -o docs/build/html/coverage/coverage-badge.svg 54 | rm docs/build/html/coverage/.gitignore 55 | 56 | - name: Deploy docs to GitHub Pages 57 | uses: peaceiris/actions-gh-pages@v3 58 | if: github.ref == 'refs/heads/main' || github.ref == 'refs/heads/coverage' 59 | with: 60 | github_token: ${{ secrets.GITHUB_TOKEN }} 61 | publish_dir: ./docs/build/html 62 | -------------------------------------------------------------------------------- /examples/text/configs/config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - override hydra/launcher: submitit_slurm 4 | 5 | compute: 6 | ngpus: 8 7 | nodes: 1 8 | 9 | logging: 10 | log_freq: 100 11 | log_lr_every: ${logging.log_freq} 12 | log_file_name: stdout.log 13 | enable_wandb: True 14 | entity: flows 15 | project: flow_matching 16 | group: null 17 | 18 | data: 19 | train: fineweb-edu 20 | valid: wikitext103 21 | cache_dir: /path/to/cache/dir 22 | num_workers: 8 23 | 24 | training: 25 | batch_size: 512 26 | snapshot: 2000 27 | eval_freq: 20000 28 | perplexity_freq: 20000 29 | seed: 42 30 | 31 | eval: 32 | batch_size: 512 33 | sample_batch_size: 16 34 | perplexity: True 35 | perplexity_batch_size: 16 36 | 
def _help_field(text: str):
    """Create a required dataclass field whose metadata stores ``text`` under ``help``."""
    return field(metadata={"help": text})


@dataclass
class PathSample:
    r"""Container for one sample drawn from a conditional-flow generated probability path.

    Attributes:
        x_1 (Tensor): the target sample :math:`X_1`.
        x_0 (Tensor): the source sample :math:`X_0`.
        t (Tensor): the time sample :math:`t`.
        x_t (Tensor): samples :math:`X_t \sim p_t(X_t)`, shape (batch_size, ...).
        dx_t (Tensor): conditional target :math:`\frac{\partial X}{\partial t}`, shape: (batch_size, ...).

    """

    x_1: Tensor = _help_field("target samples X_1 (batch_size, ...).")
    x_0: Tensor = _help_field("source samples X_0 (batch_size, ...).")
    t: Tensor = _help_field("time samples t (batch_size, ...).")
    x_t: Tensor = _help_field("samples x_t ~ p_t(X_t), shape (batch_size, ...).")
    dx_t: Tensor = _help_field("conditional target dX_t, shape: (batch_size, ...).")


@dataclass
class DiscretePathSample:
    r"""
    Container for one sample drawn from a conditional-flow generated discrete probability path.

    Attributes:
        x_1 (Tensor): the target sample :math:`X_1`.
        x_0 (Tensor): the source sample :math:`X_0`.
        t (Tensor): the time sample :math:`t`.
        x_t (Tensor): the sample along the path :math:`X_t \sim p_t`.
    """

    x_1: Tensor = _help_field("target samples X_1 (batch_size, ...).")
    x_0: Tensor = _help_field("source samples X_0 (batch_size, ...).")
    t: Tensor = _help_field("time samples t (batch_size, ...).")
    x_t: Tensor = _help_field("samples X_t ~ p_t(X_t), shape (batch_size, ...).")
"""Packaging script for the ``flow_matching`` library."""
import os

import setuptools

NAME = "flow_matching"
DESCRIPTION = "Flow Matching for Generative Modeling"
URL = "https://github.com/facebookresearch/flow_matching"
EMAIL = "ylipman@meta.com"
# Alphabetical
AUTHOR = ",".join(
    [
        "Brian Karrer",
        "David Lopez-Paz",
        "Heli Ben-Hamu",
        "Itai Gat",
        "Marton Havasi",
        "Matthew Le",
        "Neta Shaul",
        "Peter Holderrieth",
        "Ricky T.Q. Chen",
        "Yaron Lipman",
    ]
)
REQUIRES_PYTHON = ">=3.9.0"

# Anchor every path to this file so the build does not depend on the current
# working directory (the README lookup below already worked this way).
HERE = os.path.dirname(os.path.realpath(__file__))

# Read the version from the package without importing it. The file is opened
# in a ``with`` block so the handle is closed deterministically.
VERSION = None
with open(os.path.join(HERE, "flow_matching", "__init__.py")) as init_file:
    for line in init_file:
        line = line.strip()
        if "__version__" in line:
            context = {}
            exec(line, context)
            VERSION = context["__version__"]

if VERSION is None:
    # Fail with a clear message instead of a NameError inside setup().
    raise RuntimeError("Could not find __version__ in flow_matching/__init__.py")

readme_path = os.path.join(HERE, "README.md")

try:
    with open(readme_path) as f:
        long_description = "\n" + f.read()
except FileNotFoundError:
    long_description = DESCRIPTION

setuptools.setup(
    name=NAME,
    version=VERSION,
    description=DESCRIPTION,
    long_description=long_description,
    long_description_content_type="text/markdown",
    author=AUTHOR,
    author_email=EMAIL,
    python_requires=REQUIRES_PYTHON,
    url=URL,
    packages=setuptools.find_packages(),
    extras_require={
        "dev": [
            "pre-commit",
            "black==22.6.0",
            "usort==1.0.4",
            "ufmt==2.3.0",
            "flake8==7.0.0",
            "pydoclint",
        ],
    },
    install_requires=["numpy", "torch", "torchdiffeq"],
    license="CC-by-NC",
    classifiers=[
        # Trove classifiers
        # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
        # The package is CC-by-NC (see ``license`` above); the previous MIT
        # classifier contradicted it.
        "License :: Free for non-commercial use",
    ],
)
-------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import unittest 7 | 8 | import torch 9 | from flow_matching.path.scheduler import ( 10 | CondOTScheduler, 11 | CosineScheduler, 12 | ScheduleTransformedModel, 13 | ) 14 | from flow_matching.solver import ODESolver 15 | from flow_matching.utils import ModelWrapper 16 | 17 | 18 | class DummyModel(ModelWrapper): 19 | def __init__(self): 20 | super().__init__(None) 21 | 22 | def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor: 23 | return x * t**2 24 | 25 | 26 | class TestScheduleTransformedModel(unittest.TestCase): 27 | def setUp(self): 28 | self.batch_size = 10 29 | self.data_dim = 2 30 | self.num_steps = 1000 31 | self.x_0 = torch.randn([self.batch_size, self.data_dim]) 32 | self.model = DummyModel() 33 | self.original_scheduler = CondOTScheduler() 34 | self.new_scheduler = CosineScheduler() 35 | 36 | def test_schedule_transformation(self): 37 | solver_original = ODESolver(velocity_model=self.model) 38 | x_1_original = solver_original.sample( 39 | time_steps=torch.tensor([0.0, 1.0]), 40 | x_init=self.x_0, 41 | step_size=1 / self.num_steps, 42 | method="euler", 43 | )[1] 44 | transformed_model = ScheduleTransformedModel( 45 | velocity_model=self.model, 46 | original_scheduler=self.original_scheduler, 47 | new_scheduler=self.new_scheduler, 48 | ) 49 | 50 | solver_transformed = ODESolver(velocity_model=transformed_model) 51 | x_1_transformed = solver_transformed.sample( 52 | time_steps=torch.tensor([0.0, 1.0]), 53 | x_init=self.x_0, 54 | step_size=1 / self.num_steps, 55 | method="euler", 56 | )[1] 57 | 58 | self.assertTrue( 59 | torch.allclose(x_1_original, x_1_transformed, atol=1e-2), 60 | "The samples with and without the transformed 
scheduler should be approximately equal.", 61 | ) 62 | 63 | 64 | if __name__ == "__main__": 65 | unittest.main() 66 | -------------------------------------------------------------------------------- /examples/image/training/grad_scaler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import torch 7 | 8 | from torch import Tensor 9 | 10 | 11 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> Tensor: 12 | if isinstance(parameters, Tensor): 13 | parameters = [parameters] 14 | parameters = [p for p in parameters if p.grad is not None] 15 | norm_type = float(norm_type) 16 | if len(parameters) == 0: 17 | return Tensor(0.0) 18 | device = parameters[0].grad.device 19 | if norm_type == torch.inf: 20 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 21 | else: 22 | total_norm = torch.norm( 23 | torch.stack( 24 | [torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters] 25 | ), 26 | norm_type, 27 | ) 28 | return total_norm 29 | 30 | 31 | class NativeScalerWithGradNormCount: 32 | state_dict_key = "amp_scaler" 33 | 34 | def __init__(self): 35 | self._scaler = torch.cuda.amp.GradScaler() 36 | 37 | def __call__( 38 | self, 39 | loss, 40 | optimizer, 41 | clip_grad=None, 42 | parameters=None, 43 | create_graph=False, 44 | update_grad=True, 45 | ): 46 | self._scaler.scale(loss).backward(create_graph=create_graph) 47 | if update_grad: 48 | if clip_grad is not None: 49 | assert parameters is not None 50 | self._scaler.unscale_( 51 | optimizer 52 | ) # unscale the gradients of optimizer's assigned params in-place 53 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 54 | else: 55 | self._scaler.unscale_(optimizer) 56 | norm = get_grad_norm_(parameters) 57 | 
def cycle_loader(dataloader: DataLoader, sampler: Sampler = None) -> Tensor:
    """Yield items from ``dataloader`` forever, restarting it when exhausted.

    Before every pass, if a ``sampler`` is given, its epoch is set to a random
    integer drawn from ``[0, 100000)`` via ``np.random.randint``.

    Args:
        dataloader: iterable that is re-iterated on each pass.
        sampler: optional object exposing ``set_epoch``.

    Yields:
        Each item produced by ``dataloader``, pass after pass.
    """
    reseed_each_pass = sampler is not None
    while True:
        if reseed_each_pass:
            sampler.set_epoch(np.random.randint(0, 100000))
        for batch in dataloader:
            yield batch
class ProbPath(ABC):
    r"""Abstract base class for a probability path.

    A probability path carries the distribution :math:`p(X_0)` to :math:`p(X_1)` as :math:`t` goes from 0 to 1.

    ``ProbPath`` is meant to support model training in the flow matching framework through two capabilities: (1) sampling the conditional probability path and (2) converting between training objectives.
    A high-level usage sketch:

    .. code-block:: python

        # Instantiate a (concrete) probability path
        my_path = ProbPath(...)

        for x_0, x_1 in dataset:
            # Draw one random time in [0,1] per batch element
            t = torch.rand(x_0.shape[0])

            # Sample the conditional path X_t ~ p_t(X_t|X_0,X_1)
            path_sample = my_path.sample(x_0=x_0, x_1=x_1, t=t)

            # Optimize the model. The loss depends on the model and the path.
            loss(path_sample, my_model(path_sample.x_t, t)).backward()

    """

    @abstractmethod
    def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> PathSample:
        r"""Draw a sample from the probability path.

        | given :math:`(X_0,X_1) \sim \pi(X_0,X_1)`.
        | returns :math:`X_0, X_1, X_t \sim p_t(X_t)`, and a conditional target :math:`Y`, all bundled in a ``PathSample``.

        Args:
            x_0 (Tensor): source data point, shape (batch_size, ...).
            x_1 (Tensor): target data point, shape (batch_size, ...).
            t (Tensor): times in [0,1], shape (batch_size).

        Returns:
            PathSample: a conditional sample.
        """

    def assert_sample_shape(self, x_0: Tensor, x_1: Tensor, t: Tensor):
        """Check that ``t`` is one-dimensional and shares the batch size of ``x_0`` and ``x_1``.

        Raises:
            AssertionError: if ``t`` is not 1-D, or the leading dimensions of
                ``t``, ``x_0`` and ``x_1`` differ.
        """
        # The conditions stay inside the assert statements so nothing is
        # evaluated when assertions are disabled.
        assert t.ndim == 1, (
            f"The time vector t must have shape [batch_size]. Got {t.shape}."
        )
        assert t.shape[0] == x_0.shape[0] == x_1.shape[0], (
            f"Time t dimension must match the batch size [{x_1.shape[0]}]. Got {t.shape}"
        )
class WrappedModel(ModelWrapper):
    """Wrapper whose forward pass turns the wrapped model's output into probabilities."""

    def forward(self, x: Tensor, t: Tensor, **extras) -> Tensor:
        # Note: logit's precision is important, hence the cast before softmax.
        logits = self.model(x_t=x, time=t)
        return torch.softmax(logits.float(), -1)
def load_model_from_path(
    work_dir: str,
    source_distribution: SourceDistribution,
    device: torch.device,
    vocab_size: int,
    cfg: OmegaConf,
) -> nn.Module:
    """Build a ``Transformer`` and load its weights from a checkpoint.

    Args:
        work_dir: either a run directory, in which case
            ``<work_dir>/checkpoints/checkpoint.pth`` is loaded, or the path of
            a checkpoint file, which is loaded as given.
        source_distribution: its ``masked`` flag is forwarded to the model.
        device: device the model is moved to and the checkpoint is mapped to.
        vocab_size: vocabulary size passed to the model.
        cfg: configuration passed to the model as ``config``.

    Returns:
        nn.Module: the model wrapped in ``DistributedDataParallel`` with the
        checkpoint's ``"model"`` entry loaded into the wrapped module.
    """
    work_dir = Path(work_dir)

    # Resolve the checkpoint file once. Previously the path selected here was
    # overwritten afterwards, so an explicitly given checkpoint file was
    # silently replaced by the default ``checkpoint.pth``.
    if work_dir.is_dir():
        ckpt_path = work_dir / "checkpoints" / "checkpoint.pth"
    else:
        ckpt_path = work_dir

    model = Transformer(
        config=cfg, vocab_size=vocab_size, masked=source_distribution.masked
    ).to(device)
    model = DDP(model, device_ids=[device])

    loaded_state = torch.load(ckpt_path, map_location=device, weights_only=True)

    # Weights are loaded into the wrapped module, not into the DDP wrapper.
    model.module.load_state_dict(loaded_state["model"])

    return model
def load_model(args, model_without_ddp, optimizer, loss_scaler, lr_schedule):
    """Resume model (and optionally training) state from ``args.resume``.

    Does nothing when ``args.resume`` is empty. Otherwise the checkpoint is
    fetched from a URL (when it starts with ``https``) or read from disk, and
    its ``"model"`` entry is loaded into ``model_without_ddp``. If the
    checkpoint also holds ``"optimizer"`` and ``"epoch"`` and ``args.eval`` is
    not set, the optimizer, LR schedule and loss scaler are restored as well and
    ``args.start_epoch`` is set to the epoch after the saved one.

    Args:
        args: namespace with ``resume`` and optionally ``eval``; ``start_epoch``
            is written on it when training state is restored.
        model_without_ddp: module receiving the saved weights.
        optimizer: optimizer to restore.
        loss_scaler: object with ``load_state_dict``, or ``None``.
        lr_schedule: scheduler to restore.
    """
    if not args.resume:
        return

    if args.resume.startswith("https"):
        checkpoint = torch.hub.load_state_dict_from_url(
            args.resume, map_location="cpu", check_hash=True
        )
    else:
        # NOTE(security): checkpoints written by save_model store the argument
        # namespace under "args", which the restricted weights-only unpickler
        # rejects. Full unpickling is requested explicitly, so only resume from
        # checkpoint files you trust.
        checkpoint = torch.load(args.resume, map_location="cpu", weights_only=False)

    model_without_ddp.load_state_dict(checkpoint["model"])
    print("Resume checkpoint %s" % args.resume)

    evaluating = getattr(args, "eval", False)
    has_training_state = "optimizer" in checkpoint and "epoch" in checkpoint
    if has_training_state and not evaluating:
        optimizer.load_state_dict(checkpoint["optimizer"])
        # Checkpoints from other sources may lack a schedule entry.
        if "lr_schedule" in checkpoint:
            lr_schedule.load_state_dict(checkpoint["lr_schedule"])
        args.start_epoch = checkpoint["epoch"] + 1
        if "scaler" in checkpoint and loss_scaler is not None:
            loss_scaler.load_state_dict(checkpoint["scaler"])
        print("With optim & sched!")
def main(args: argparse.Namespace):
    """Launch evaluation on ``args.ngpus`` processes.

    ``args.batch_size`` and ``args.perplexity_n_samples`` must be divisible by
    ``args.ngpus``; each worker receives its per-process share. With a single
    GPU the worker runs in the current process, otherwise one process per GPU
    is spawned using the ``forkserver`` start method.
    """
    port = 12346

    assert args.perplexity_n_samples % args.ngpus == 0
    assert args.batch_size % args.ngpus == 0

    n_procs = args.ngpus
    # Insertion order matches the worker's positional order after
    # (rank, world_size), which the spawn call below relies on.
    shared = {
        "seed": args.seed,
        "work_dir": args.work_dir,
        "batch_size": args.batch_size // n_procs,
        "sampling_steps": args.sampling_steps,
        "eval_elbo": args.eval_elbo,
        "eval_perplexity": args.eval_perplexity,
        "elbo_data": args.elbo_data,
        "perplexity_n_samples": args.perplexity_n_samples // n_procs,
        "port": port,
    }

    if n_procs == 1:
        run_mp_eval(rank=0, world_size=1, **shared)
        return

    mp.set_start_method("forkserver")
    # The spawner supplies the rank as the first positional argument.
    mp.spawn(
        run_mp_eval,
        args=(n_procs, *shared.values()),
        nprocs=n_procs,
        join=True,
    )
def is_dist_avail_and_initialized():
    """Return True only if ``torch.distributed`` is available and initialized."""
    # Short-circuits: initialization is only queried when the package is available.
    return dist.is_available() and dist.is_initialized()
class Manifold(nn.Module, metaclass=abc.ABCMeta):
    """Abstract manifold interface: projections plus exponential and logarithm maps.

    All four operations are abstract, so only subclasses implementing every one
    of them can be instantiated.
    """

    @abc.abstractmethod
    def expmap(self, x: Tensor, u: Tensor) -> Tensor:
        r"""Exponential map :math:`\exp_x(u)`.

        Args:
            x (Tensor): point on the manifold
            u (Tensor): tangent vector at point :math:`x`

        Raises:
            NotImplementedError: if the subclass does not override this method

        Returns:
            Tensor: transported point
        """
        raise NotImplementedError

    @abc.abstractmethod
    def logmap(self, x: Tensor, y: Tensor) -> Tensor:
        r"""Logarithmic map :math:`\log_x(y)`.

        Args:
            x (Tensor): point on the manifold
            y (Tensor): point on the manifold

        Raises:
            NotImplementedError: if the subclass does not override this method

        Returns:
            Tensor: tangent vector at point :math:`x`
        """
        raise NotImplementedError

    @abc.abstractmethod
    def projx(self, x: Tensor) -> Tensor:
        r"""Projection of a point :math:`x` onto the manifold.

        Args:
            x (Tensor): point to be projected

        Raises:
            NotImplementedError: if the subclass does not override this method

        Returns:
            Tensor: projected point on the manifold
        """
        raise NotImplementedError

    @abc.abstractmethod
    def proju(self, x: Tensor, u: Tensor) -> Tensor:
        r"""Projection of a vector :math:`u` onto the tangent space at :math:`x`.

        Args:
            x (Tensor): point on the manifold
            u (Tensor): vector to be projected

        Raises:
            NotImplementedError: if the subclass does not override this method

        Returns:
            Tensor: projected tangent vector
        """
        raise NotImplementedError
7 | # 8 | # For the full list of built-in configuration values, see the documentation: 9 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 10 | 11 | # -- Project information ----------------------------------------------------- 12 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 13 | 14 | project = "Flow Matching" 15 | copyright = "2024 Meta Platforms, Inc" 16 | author = "FAIR" 17 | 18 | # -- General configuration --------------------------------------------------- 19 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 20 | 21 | extensions = [ 22 | "nbsphinx", 23 | "sphinx.ext.autodoc", 24 | "sphinx.ext.autosummary", 25 | "sphinx.ext.doctest", 26 | "sphinx.ext.intersphinx", 27 | "sphinx.ext.todo", 28 | "sphinx.ext.coverage", 29 | "sphinx.ext.napoleon", 30 | "sphinx.ext.viewcode", 31 | "sphinxcontrib.katex", 32 | "sphinx.ext.autosectionlabel", 33 | "sphinxcontrib.bibtex", 34 | ] 35 | 36 | bibtex_bibfiles = ["refs.bib"] 37 | bibtex_default_style = "unsrt" 38 | 39 | templates_path = ["_templates"] 40 | exclude_patterns = ["_build", "**.ipynb_checkpoints"] 41 | 42 | # -- Options for HTML output ------------------------------------------------- 43 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 44 | html_theme = "pydata_sphinx_theme" 45 | html_static_path = ["_static", "_images"] 46 | 47 | # katex config 48 | katex_css_path = "https://cdn.jsdelivr.net/npm/katex@0.16.10/dist/katex.min.css" 49 | katex_js_path = "katex.min.js" 50 | katex_autorender_path = "auto-render.min.js" 51 | katex_inline = [r"\(", r"\)"] 52 | katex_display = [r"\[", r"\]"] 53 | katex_prerender = False 54 | katex_options = "" 55 | 56 | # autodoc config 57 | autodoc_member_order = "bysource" 58 | autosummary_generate = True # Turn on sphinx.ext.autosummary 59 | 60 | from custom_directives import ( 61 | CustomCardEnd, 62 | CustomCardItem, 63 | CustomCardStart, 64 | 
SupportedDevices, 65 | SupportedProperties, 66 | ) 67 | 68 | # Register custom directives 69 | 70 | from docutils.parsers import rst 71 | 72 | rst.directives.register_directive("devices", SupportedDevices) 73 | rst.directives.register_directive("properties", SupportedProperties) 74 | rst.directives.register_directive("customcardstart", CustomCardStart) 75 | rst.directives.register_directive("customcarditem", CustomCardItem) 76 | rst.directives.register_directive("customcardend", CustomCardEnd) 77 | 78 | 79 | def setup(app): 80 | app.add_css_file("css/custom.css") # may also be an URL 81 | -------------------------------------------------------------------------------- /examples/text/logic/state.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 
class TrainState:
    """Bundle of everything needed to checkpoint and resume a training run.

    Holds the model, the optimizer, the current step and the data state, and
    knows how to write them to / read them from a single checkpoint file.
    """

    def __init__(
        self,
        model: nn.Module,
        optimizer: Optimizer,
        step: int,
        data_state: "DataState",
    ):
        self._model = model
        self._optimizer = optimizer
        self._step = step
        self._data_state = data_state

    @property
    def step(self) -> int:
        return self._step

    @step.setter
    def step(self, value: int) -> None:
        self._step = value

    @property
    def optimizer(self) -> Optimizer:
        return self._optimizer

    @property
    def model(self) -> nn.Module:
        return self._model

    @property
    def data_state(self) -> "DataState":
        return self._data_state

    def compile_model(self) -> None:
        """Replace the held model with its ``torch.compile``-d version."""
        self._model = torch.compile(self._model)

    def restore_checkpoint(
        self, ckpt_dir: Path, device: torch.device, rank: int
    ) -> None:
        """Load optimizer, model, step and sampler states from ``ckpt_dir`` if it exists.

        When the file is missing, its parent directory is created, the state is
        left untouched and rank 0 logs a warning.

        Args:
            ckpt_dir: path of the checkpoint file.
            device: device the checkpoint tensors are mapped to.
            rank: process rank; only rank 0 emits the warning.
        """
        if ckpt_dir.exists():
            loaded_state = torch.load(ckpt_dir, map_location=device, weights_only=True)

            self.optimizer.load_state_dict(loaded_state["optimizer"])
            # Weights live on the wrapped ``.module``, mirroring save_checkpoint.
            self.model.module.load_state_dict(loaded_state["model"])
            self.step = loaded_state["step"]
            # Both entries are sampler states (see save_checkpoint), so both
            # must be loaded into the sampler rather than into its owner.
            self._data_state.test.sampler.load_state_dict(
                loaded_state["test_sampler"]
            )
            self._data_state.train.sampler.load_state_dict(
                loaded_state["train_sampler"]
            )
        else:
            ckpt_dir.parent.mkdir(exist_ok=True, parents=True)

            if rank == 0:
                logging.warning(
                    f"No checkpoint found at {ckpt_dir}. Returned the same state as input"
                )

    def save_checkpoint(self, ckpt_dir: str, rank: int) -> None:
        """Collect the current state and write it to ``ckpt_dir`` on rank 0.

        Every rank builds the state dictionary; only rank 0 writes the file.
        """
        saved_state = {
            "optimizer": self.optimizer.state_dict(),
            "model": self.model.module.state_dict(),
            "step": self.step,
            "train_sampler": self._data_state.train.sampler.state_dict(),
            "test_sampler": self._data_state.test.sampler.state_dict(),
        }

        if rank == 0:
            torch.save(saved_state, ckpt_dir)

    def eval(self) -> None:
        """Put the model in evaluation mode."""
        self.train(training=False)

    def train(self, training: bool = True) -> None:
        """Set the model's training mode."""
        self._model.train(mode=training)
class Rotary(torch.nn.Module):
    """
    Rotary embedding tables (cosine and sine), cached between calls.

    From: https://github.com/louaaron/Score-Entropy-Discrete-Diffusion
    """

    def __init__(self, dim: int, base: int = 10_000):
        """Precompute the inverse frequencies for a rotary dimension ``dim``."""
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x: Tensor, seq_dim: int = 1) -> Tuple[Tensor, Tensor]:
        """Return the cosine and sine tables for the sequence length of ``x``.

        Args:
            x: input whose size along ``seq_dim`` is the sequence length and
                whose device the tables are placed on.
            seq_dim: index of the sequence dimension of ``x``.

        Returns:
            ``(cos, sin)``, each of shape ``(1, seq_len, 3, 1, dim)``. The slice
            at index 2 of the third axis holds ``cos = 1`` and ``sin = 0``.
            The same tensor objects are returned on later calls while the
            sequence length and device are unchanged.
        """
        seq_len = x.shape[seq_dim]
        # Rebuild when the length changes or when the cached tables sit on a
        # different device than the input; keying on length alone would hand
        # back tensors from the previous device.
        stale = seq_len != self.seq_len_cached or self.cos_cached.device != x.device
        if stale:
            self.seq_len_cached = seq_len
            t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq.clone())
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)

            # dims are: batch, seq_len, qkv, head, dim
            self.cos_cached = emb.cos()[None, :, None, None, :].repeat(1, 1, 3, 1, 1)
            self.sin_cached = emb.sin()[None, :, None, None, :].repeat(1, 1, 3, 1, 1)

            # This makes the transformation on v an identity.
            self.cos_cached[:, :, 2, :, :].fill_(1.0)
            self.sin_cached[:, :, 2, :, :].fill_(0.0)

        return self.cos_cached, self.sin_cached
class EMA(Module):
    """Wrap a model and maintain an exponential moving average of its weights.

    In train mode the wrapped model holds the live weights. Switching to eval
    mode backs the live weights up and copies the averaged ("shadow") weights
    into the model; switching back to train restores the backup.
    """

    def __init__(self, model: Module, decay: float = 0.999):
        """
        Args:
            model: module whose trainable parameters are averaged.
            decay: upper bound on the EMA decay factor.
        """
        super().__init__()
        self.model = model
        self.decay = decay

        # Put this in a buffer so that it gets included in the state dict
        self.register_buffer("num_updates", torch.tensor(0))

        self.shadow_params: ParameterList = ParameterList(
            [
                Parameter(p.clone().detach(), requires_grad=False)
                for p in model.parameters()
                if p.requires_grad
            ]
        )
        self.backup_params: List[torch.Tensor] = []

    def train(self, mode: bool = True) -> "EMA":
        """Switch between train and eval mode, swapping live and averaged weights.

        Follows the ``torch.nn.Module.train`` contract: ``mode`` defaults to
        ``True`` and the module itself is returned, so ``ema.train()`` and
        chained calls such as ``ema.eval()`` behave as for any other module.
        """
        if self.training == mode:
            # No mode change: nothing to swap.
            super().train(mode)
            return self

        if not mode:
            logger.info(
                "EMA: Switching from train to eval, backing up parameters and copying EMA params"
            )
            self.backup()
            self.copy_to_model()
        else:
            logger.info("EMA: Switching from eval to train, restoring saved parameters")
            self.restore_to_model()

        super().train(mode)
        return self

    def update_ema(self) -> None:
        """Move the shadow weights one step towards the current model weights."""
        self.num_updates += 1
        num_updates = self.num_updates.item()
        # Warm-up: the effective decay grows with the number of updates until
        # it reaches the configured maximum.
        decay = min(self.decay, (1 + num_updates) / (10 + num_updates))
        with torch.no_grad():
            params = [p for p in self.model.parameters() if p.requires_grad]
            for shadow, param in zip(self.shadow_params, params):
                shadow.sub_((1 - decay) * (shadow - param))

    def forward(self, *args, **kwargs) -> torch.Tensor:
        """Delegate to the wrapped model."""
        return self.model(*args, **kwargs)

    def copy_to_model(self) -> None:
        """Overwrite the model's trainable weights with the shadow weights."""
        params = [p for p in self.model.parameters() if p.requires_grad]
        for shadow, param in zip(self.shadow_params, params):
            param.data.copy_(shadow.data)

    def backup(self) -> None:
        """Snapshot all model parameters so they can be restored later."""
        assert (
            self.training
        ), "Backup can only be created in train mode to avoid backing-up ema weights."
        if len(self.backup_params) > 0:
            for p, b in zip(self.model.parameters(), self.backup_params):
                b.data.copy_(p.data)
        else:
            # Detach first so the snapshots are plain tensors that are not
            # recorded by autograd.
            self.backup_params = [
                param.detach().clone() for param in self.model.parameters()
            ]

    def restore_to_model(self) -> None:
        """Copy the backed-up weights back into the model."""
        for param, backup in zip(self.model.parameters(), self.backup_params):
            param.data.copy_(backup.data)
class MaskedSourceDistribution(SourceDistribution):
    """Source distribution that places all mass on a single mask token."""

    def __init__(self, mask_token: int) -> None:
        self.mask_token = mask_token

    @property
    def masked(self) -> bool:
        return True

    def sample(self, tensor_size: Tuple[int, ...], device: torch.device) -> Tensor:
        """Return a long tensor of shape ``tensor_size`` filled with the mask token."""
        # Build the integer tensor directly instead of filling a float tensor
        # and casting it afterwards.
        return torch.full(
            tensor_size, self.mask_token, dtype=torch.long, device=device
        )

    def sample_like(self, tensor_like: Tensor) -> Tensor:
        """Return a long tensor shaped like ``tensor_like`` filled with the mask token."""
        return torch.full_like(tensor_like, self.mask_token, dtype=torch.long)


class UniformSourceDistribution(SourceDistribution):
    """Source distribution that draws token ids uniformly from ``[0, vocab_size)``."""

    def __init__(self, vocab_size: int) -> None:
        self.vocab_size = vocab_size

    @property
    def masked(self) -> bool:
        return False

    def sample(self, tensor_size: Tuple[int, ...], device: torch.device) -> Tensor:
        """Return random token ids of shape ``tensor_size`` on ``device``."""
        return torch.randint(size=tensor_size, high=self.vocab_size, device=device)

    def sample_like(self, tensor_like: Tensor) -> Tensor:
        """Return random token ids with the shape and dtype of ``tensor_like``."""
        return torch.randint_like(tensor_like, high=self.vocab_size)


def get_path(scheduler_type: str, exponent: Optional[float] = None) -> ProbPath:
    """Build the discrete mixture probability path for the given scheduler type.

    Args:
        scheduler_type (str): Only ``"polynomial"`` is supported.
        exponent (Optional[float]): Passed as ``n`` to ``PolynomialConvexScheduler``.
            NOTE(review): the default ``None`` is forwarded unchanged; confirm
            the scheduler accepts it.

    Raises:
        ValueError: If ``scheduler_type`` is not supported.
    """
    if scheduler_type == "polynomial":
        scheduler = PolynomialConvexScheduler(n=exponent)
    else:
        raise ValueError(f"{scheduler_type} is not supported")

    return MixtureDiscreteProbPath(scheduler=scheduler)


def get_source_distribution(
    source_distribution: str, vocab_size: int
) -> SourceDistribution:
    """Create the source distribution named by ``source_distribution``.

    ``"mask"`` uses ``vocab_size`` itself as the mask token id; ``"uniform"``
    samples ids below ``vocab_size``.

    Raises:
        ValueError: If ``source_distribution`` is not supported.
    """
    if source_distribution == "mask":
        return MaskedSourceDistribution(mask_token=vocab_size)
    elif source_distribution == "uniform":
        return UniformSourceDistribution(vocab_size=vocab_size)
    else:
        raise ValueError(f"{source_distribution} is not supported")


def get_loss_function(loss_function: str, path: Optional[ProbPath] = None) -> _Loss:
    """Create the training loss named by ``loss_function``.

    Args:
        loss_function (str): ``"cross_entropy"`` or ``"generalized_kl"``.
        path (Optional[ProbPath]): Required for ``"generalized_kl"``.

    Raises:
        AssertionError: If ``"generalized_kl"`` is requested without a path.
        ValueError: If ``loss_function`` is not supported.
    """
    if loss_function == "cross_entropy":
        return torch.nn.CrossEntropyLoss()
    elif loss_function == "generalized_kl":
        assert path is not None

        return MixturePathGeneralizedKL(path=path)
    else:
        raise ValueError(f"{loss_function} is not supported")


# --- tests/path/test_scheduler.py ---


class TestScheduler(unittest.TestCase):
    """Shape and inverse-map checks for the schedulers."""

    def setUp(self):
        self.t = torch.tensor([0.1, 0.5, 0.9])

    def assert_output_shapes(
        self, outputs: SchedulerOutput, expected_shape: torch.Size
    ):
        """Check that every field of the scheduler output has ``expected_shape``."""
        self.assertEqual(outputs.alpha_t.shape, expected_shape)
        self.assertEqual(outputs.sigma_t.shape, expected_shape)
        self.assertEqual(outputs.d_alpha_t.shape, expected_shape)
        self.assertEqual(outputs.d_sigma_t.shape, expected_shape)

    def assert_recover_t_from_kappa(self, scheduler, t: Tensor):
        """Check that ``kappa_inverse(alpha_t)`` returns the original ``t``."""
        scheduler_output = scheduler(t)
        t_recovered = scheduler.kappa_inverse(scheduler_output.alpha_t)

        self.assertTrue(
            torch.allclose(t, t_recovered, atol=1e-5),
            f"Recovered t: {t_recovered}, Original t: {t}",
        )

    def assert_recover_t_from_snr(self, scheduler, t: Tensor):
        """Check that ``snr_inverse(alpha_t / sigma_t)`` returns the original ``t``."""
        scheduler_output = scheduler(t)
        snr = scheduler_output.alpha_t / scheduler_output.sigma_t

        t_recovered = scheduler.snr_inverse(snr)

        self.assertTrue(
            torch.allclose(t, t_recovered, atol=1e-5),
            f"Recovered t: {t_recovered}, Original t: {t}",
        )

    def test_cond_ot_scheduler(self):
        scheduler = CondOTScheduler()
        outputs = scheduler(self.t)

        self.assert_output_shapes(outputs, self.t.shape)

        self.assert_recover_t_from_kappa(scheduler, self.t)
        self.assert_recover_t_from_snr(scheduler, self.t)

    def test_cosine_scheduler(self):
        scheduler = CosineScheduler()
        outputs = scheduler(self.t)
        self.assert_output_shapes(outputs, self.t.shape)

        self.assert_recover_t_from_snr(scheduler, self.t)

    def test_scheduler_vp(self):
        scheduler = VPScheduler()
        outputs = scheduler(self.t)
        self.assert_output_shapes(outputs, self.t.shape)

        self.assert_recover_t_from_snr(scheduler, self.t)

    def test_scheduler_vp_linear(self):
        scheduler = LinearVPScheduler()
        outputs = scheduler(self.t)
        self.assert_output_shapes(outputs, self.t.shape)

        self.assert_recover_t_from_snr(scheduler, self.t)

    def test_polynomial_convex_scheduler(self):
        scheduler = PolynomialConvexScheduler(n=2)
        outputs = scheduler(self.t)
        self.assert_output_shapes(outputs, self.t.shape)

        self.assert_recover_t_from_kappa(scheduler, self.t)
        self.assert_recover_t_from_snr(scheduler, self.t)


if __name__ == "__main__":
    unittest.main()
def unsqueeze_to_match(source: Tensor, target: Tensor, how: str = "suffix") -> Tensor:
    """
    Add singleton dimensions to ``source`` until it has as many dimensions as ``target``.

    Args:
        source (Tensor): The tensor to add dimensions to.
        target (Tensor): The tensor whose number of dimensions should be matched.
        how (str, optional): ``"prefix"`` adds the new dimensions in front,
            ``"suffix"`` adds them at the end. Defaults to "suffix".

    Returns:
        Tensor: ``source`` with the extra singleton dimensions; ``source`` itself
        when it already has at least as many dimensions as ``target``.
    """
    assert how in (
        "prefix",
        "suffix",
    ), f"{how} is not supported, only 'prefix' and 'suffix' are supported."

    missing = target.dim() - source.dim()
    if missing <= 0:
        return source

    # Indexing with ``None`` inserts a singleton axis, exactly like ``unsqueeze``.
    new_axes = (None,) * missing
    if how == "prefix":
        return source[new_axes + (...,)]
    return source[(...,) + new_axes]


def expand_tensor_like(input_tensor: Tensor, expand_to: Tensor) -> Tensor:
    """Broadcast a per-batch 1d vector to the full shape of ``expand_to``.

    Args:
        input_tensor (Tensor): (batch_size,).
        expand_to (Tensor): (batch_size, ...).

    Returns:
        Tensor: (batch_size, ...), an expanded view over a copy of ``input_tensor``.
    """
    assert input_tensor.ndim == 1, "Input tensor must be a 1d vector."
    assert (
        input_tensor.shape[0] == expand_to.shape[0]
    ), f"The first (batch_size) dimension must match. Got shape {input_tensor.shape} and {expand_to.shape}."

    trailing_ones = (1,) * (expand_to.ndim - input_tensor.ndim)
    column = input_tensor.clone().reshape((-1,) + trailing_ones)
    return column.expand_as(expand_to)


def gradient(
    output: Tensor,
    x: Tensor,
    grad_outputs: Optional[Tensor] = None,
    create_graph: bool = False,
) -> Tensor:
    """
    Differentiate the inner product of ``output`` and ``grad_outputs`` w.r.t. :math:`x`.

    Args:
        output (Tensor): [N, D] Output of the function.
        x (Tensor): [N, d_1, d_2, ... ] input
        grad_outputs (Optional[Tensor]): [N, D] Gradient of outputs; a tensor of
            ones is used when this is `None`.
        create_graph (bool): If True, the graph of the derivative is built so that
            higher order derivatives can be taken. Defaults to False.
    Returns:
        Tensor: [N, d_1, d_2, ... ]. the gradient w.r.t x.
    """
    weights = grad_outputs
    if weights is None:
        weights = torch.ones_like(output).detach()

    grads = torch.autograd.grad(
        output, x, grad_outputs=weights, create_graph=create_graph
    )
    return grads[0]
class PixelEmbedding(nn.Module):
    """Embed integer pixel tokens and fold the embedding axis into the channel axis."""

    def __init__(
        self,
        n_tokens: int,
        hidden_size: int,
    ):
        super().__init__()
        self.embedding_table = nn.Embedding(n_tokens, hidden_size)

    def forward(self, x: torch.Tensor):
        """Map ``x`` of shape (B, C, H, W) to embeddings of shape (B, C * hidden_size, H, W)."""
        B, _, H, W = x.shape
        emb = self.embedding_table(x)  # (B, C, H, W, hidden_size)
        # Move the embedding axis next to the channel axis, then merge the two.
        result = emb.permute(0, 1, 4, 2, 3).reshape(B, -1, H, W)
        return result


@dataclass(eq=False)
class DiscreteUNetModel(nn.Module):
    """UNet over discrete pixel tokens that outputs per-pixel logits over the vocabulary.

    Tokens are embedded with ``PixelEmbedding`` and passed to a ``UNetModel``
    built with ``out_channels * vocab_size`` output channels; ``forward``
    reshapes that output so the vocabulary is the last axis.
    """

    vocab_size: int
    in_channels: int = 3
    model_channels: int = 128
    out_channels: int = 3
    num_res_blocks: int = 2
    attention_resolutions: Tuple[int, ...] = (1, 2, 2, 2)
    dropout: float = 0.0
    channel_mult: Tuple[int, ...] = (1, 2, 4, 8)
    conv_resample: bool = True
    dims: int = 2
    num_classes: Optional[int] = None
    use_checkpoint: bool = False
    num_heads: int = 1
    num_head_channels: int = -1
    num_heads_upsample: int = -1
    use_scale_shift_norm: bool = False
    resblock_updown: bool = False
    use_new_attention_order: bool = False
    with_fourier_features: bool = False

    def __post_init__(self):
        super().__init__()
        assert (
            self.model_channels * self.channel_mult[0] % self.in_channels == 0
        ), f"Unet input dimensions must be divisible by the number of channels. Got {self.model_channels * self.channel_mult[0]} / {self.in_channels}"
        # Per-channel embedding width, chosen so that the embedded input has
        # exactly ``model_channels * channel_mult[0]`` channels.
        self.embedding_dim = (
            self.model_channels * self.channel_mult[0] // self.in_channels
        )

        self.pixel_embedding = PixelEmbedding(
            n_tokens=self.vocab_size, hidden_size=self.embedding_dim
        )

        self.unet = UNetModel(
            in_channels=self.in_channels * self.embedding_dim,
            model_channels=self.model_channels,
            out_channels=self.out_channels * (self.vocab_size),
            num_res_blocks=self.num_res_blocks,
            attention_resolutions=self.attention_resolutions,
            dropout=self.dropout,
            channel_mult=self.channel_mult,
            conv_resample=self.conv_resample,
            dims=self.dims,
            num_classes=self.num_classes,
            use_checkpoint=self.use_checkpoint,
            num_heads=self.num_heads,
            num_head_channels=self.num_head_channels,
            num_heads_upsample=self.num_heads_upsample,
            use_scale_shift_norm=self.use_scale_shift_norm,
            resblock_updown=self.resblock_updown,
            use_new_attention_order=self.use_new_attention_order,
            with_fourier_features=self.with_fourier_features,
            ignore_time=True,
            input_projection=False,
        )

    def forward(
        self, x_t: torch.Tensor, t: torch.Tensor, extra: Mapping[str, torch.Tensor]
    ) -> torch.Tensor:
        """Return logits of shape (B, out_channels, H, W, vocab_size).

        Args:
            x_t (Tensor): Integer pixel tokens, shape (B, in_channels, H, W).
            t (Tensor): Passed through to the UNet.
            extra (Mapping[str, Tensor]): Passed through to the UNet.
        """
        B, _, H, W = x_t.shape
        # The UNet emits ``out_channels * vocab_size`` channels, so the split
        # must use ``out_channels`` rather than the input channel count.
        logits = (
            self.unet(self.pixel_embedding(x_t), t, extra)
            .reshape(B, self.out_channels, self.vocab_size, H, W)
            .permute(0, 1, 3, 4, 2)
        )
        return logits
ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 
50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /flow_matching/loss/generalized_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
class MixturePathGeneralizedKL(_Loss):
    r"""A generalized KL loss for discrete flow matching.
    A class that measures the generalized KL of a discrete flow model :math:`p_{1|t}` w.r.t. a probability path given by ``path``. Note: this class is assuming that the model is trained on the same path.

    For a model trained on a space :math:`\mathcal{S} = \mathcal{T}^d`, :math:`\mathcal{T} = [K] = \set{1,2,\ldots,K}`, the loss is given by

    .. math::
        \ell_i(x_1, x_t, t) = -\frac{\dot{\kappa}_t}{1-\kappa_t} \biggr[  p_{1|t}(x_t^i|x_t) -\delta_{x^i_1}(x_t^i) + (1-\delta_{x^i_1}(x_t^i))\left(\log p_{1|t}(x_1^i|x_t)\right)\biggr],

    where :math:`\kappa_t` is the scheduler associated with ``path``.

    Args:
        path (MixtureDiscreteProbPath): Probability path (x-prediction training).
        reduction (str, optional): Specify the reduction to apply to the output ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction is applied to the output, ``'mean'``: the output is reduced by mean over sequence elements, ``'sum'``: the output is reduced by sum over sequence elements. Defaults to 'mean'.
    """

    def __init__(
        self, path: "MixtureDiscreteProbPath", reduction: str = "mean"
    ) -> None:
        super().__init__(None, None, reduction)
        self.path = path

    def forward(self, logits: Tensor, x_1: Tensor, x_t: Tensor, t: Tensor) -> Tensor:
        r"""Evaluates the generalized KL loss.

        Args:
            logits (Tensor): posterior model output (i.e., softmax(``logits``) :math:`=p_{1|t}(x|x_t)`), shape (batch, d, K).
            x_1 (Tensor): target data point :math:`x_1 \sim q`, shape (batch, d).
            x_t (Tensor): conditional sample at :math:`x_t \sim p_t(\cdot|x_1)`, shape (batch, d).
            t (Tensor): times in :math:`[0,1]`, shape (batch).

        Raises:
            ValueError: reduction value must be one of ``'none'`` | ``'mean'`` | ``'sum'``.

        Returns:
            Tensor: Generalized KL loss.
        """
        log_p_1t = torch.log_softmax(logits, dim=-1)

        # log p_{1|t}(x_1 | x_t): pick the log-probability of the target token.
        log_p_1t_x1 = torch.gather(
            log_p_1t, dim=-1, index=x_1.unsqueeze(-1)
        ).squeeze(-1)

        # p_{1|t}(x_t | x_t): gather first, then exponentiate, so the full
        # (batch, d, K) probability tensor is never materialized.
        p_1t_xt = torch.exp(
            torch.gather(log_p_1t, dim=-1, index=x_t.unsqueeze(-1)).squeeze(-1)
        )

        scheduler_output = self.path.scheduler(t)

        # Append singleton axes so the per-sample coefficient broadcasts over
        # the remaining dimensions of x_1; no explicit repeat is needed.
        jump_coefficient = (
            scheduler_output.d_alpha_t / (1 - scheduler_output.alpha_t)
        )[(...,) + (None,) * (x_1.dim() - 1)]
        delta_x1_xt = (x_t == x_1).to(log_p_1t.dtype)

        loss = -jump_coefficient * (
            p_1t_xt - delta_x1_xt + (1 - delta_x1_xt) * log_p_1t_x1
        )

        if self.reduction == "mean":
            return torch.mean(loss)
        elif self.reduction == "sum":
            return torch.sum(loss)
        elif self.reduction == "none":
            return loss
        else:
            raise ValueError(f"{self.reduction} is not a valid value for reduction")
# Keyword arguments for the UNet constructors, keyed by architecture name.
# A "<name>_discrete" entry, when present, is used for the discrete model.
MODEL_CONFIGS = {
    "imagenet": {
        "in_channels": 3,
        "model_channels": 192,
        "out_channels": 3,
        "num_res_blocks": 3,
        "attention_resolutions": [2, 4, 8],
        "dropout": 0.1,
        "channel_mult": [1, 2, 3, 4],
        "num_classes": 1000,
        "use_checkpoint": False,
        "num_heads": 4,
        "num_head_channels": 64,
        "use_scale_shift_norm": True,
        "resblock_updown": True,
        "use_new_attention_order": True,
        "with_fourier_features": False,
    },
    "imagenet_discrete": {
        "in_channels": 3,
        "model_channels": 192,
        "out_channels": 3,
        "num_res_blocks": 4,
        "attention_resolutions": [2, 4, 8],
        "dropout": 0.2,
        "channel_mult": [2, 3, 4, 4],
        "num_classes": 1000,
        "use_checkpoint": False,
        "num_heads": -1,
        "num_head_channels": 64,
        "use_scale_shift_norm": True,
        "resblock_updown": True,
        "use_new_attention_order": True,
        "with_fourier_features": False,
    },
    "cifar10": {
        "in_channels": 3,
        "model_channels": 128,
        "out_channels": 3,
        "num_res_blocks": 4,
        "attention_resolutions": [2],
        "dropout": 0.3,
        "channel_mult": [2, 2, 2],
        "conv_resample": False,
        "dims": 2,
        "num_classes": None,
        "use_checkpoint": False,
        "num_heads": 1,
        "num_head_channels": -1,
        "num_heads_upsample": -1,
        "use_scale_shift_norm": True,
        "resblock_updown": False,
        "use_new_attention_order": True,
        "with_fourier_features": False,
    },
    "cifar10_discrete": {
        "in_channels": 3,
        "model_channels": 96,
        "out_channels": 3,
        "num_res_blocks": 5,
        "attention_resolutions": [2],
        "dropout": 0.4,
        "channel_mult": [3, 4, 4],
        "conv_resample": False,
        "dims": 2,
        "num_classes": None,
        "use_checkpoint": False,
        "num_heads": -1,
        "num_head_channels": 64,
        "num_heads_upsample": -1,
        "use_scale_shift_norm": True,
        "resblock_updown": False,
        "use_new_attention_order": True,
        "with_fourier_features": False,
    },
}


def instantiate_model(
    architechture: str, is_discrete: bool, use_ema: bool, vocab_size: int = 257
) -> "Union[UNetModel, DiscreteUNetModel, EMA]":
    """Build the model for ``architechture`` from ``MODEL_CONFIGS``.

    Args:
        architechture (str): Key into ``MODEL_CONFIGS`` (the parameter name is
            kept as spelled because callers may pass it by keyword).
        is_discrete (bool): Build a ``DiscreteUNetModel`` instead of a
            ``UNetModel``. The ``"<architechture>_discrete"`` config is used
            when it exists, otherwise the base config.
        use_ema (bool): Wrap the model in ``EMA`` before returning it.
        vocab_size (int): Vocabulary size of the discrete model; ignored for
            the continuous model. Defaults to 257, the value that was
            previously hard-coded.

    Returns:
        The model, wrapped in ``EMA`` when ``use_ema`` is set.

    Raises:
        AssertionError: If ``architechture`` has no entry in ``MODEL_CONFIGS``.
    """
    assert (
        architechture in MODEL_CONFIGS
    ), f"Model architecture {architechture} is missing its config."

    if is_discrete:
        discrete_key = architechture + "_discrete"
        config = MODEL_CONFIGS.get(discrete_key, MODEL_CONFIGS[architechture])
        model = DiscreteUNetModel(
            vocab_size=vocab_size,
            **config,
        )
    else:
        model = UNetModel(**MODEL_CONFIGS[architechture])

    if use_ema:
        return EMA(model=model)
    else:
        return model
Chen and Gabriel Synnaeve and Yossi Adi and Yaron Lipman}, 20 | year={2024}, 21 | eprint={2407.15595}, 22 | archivePrefix={arXiv}, 23 | primaryClass={cs.LG}, 24 | url={https://arxiv.org/abs/2407.15595}, 25 | } 26 | 27 | @misc{chen2024flowmatchinggeneralgeometries, 28 | title={Flow Matching on General Geometries}, 29 | author={Ricky T. Q. Chen and Yaron Lipman}, 30 | year={2024}, 31 | eprint={2302.03660}, 32 | archivePrefix={arXiv}, 33 | primaryClass={cs.LG}, 34 | url={https://arxiv.org/abs/2302.03660}, 35 | } 36 | 37 | @misc{holderrieth2024generator, 38 | title={Generator Matching: Generative modeling with arbitrary Markov processes}, 39 | author={Holderrieth, Peter and Havasi, Marton and Yim, Jason and Shaul, Neta and Gat, Itai and Jaakkola, Tommi and Karrer, Brian and Chen, Ricky TQ and Lipman, Yaron}, 40 | eprint={2410.20587}, 41 | archivePrefix={arXiv}, 42 | primaryClass={cs.LG}, 43 | url={https://arxiv.org/abs/2410.20587}, 44 | year={2024} 45 | } 46 | 47 | @misc{shaul2024flow, 48 | title={Flow Matching with General Discrete Paths: A Kinetic-Optimal Perspective}, 49 | author={Neta Shaul and Itai Gat and Marton Havasi and Daniel Severo and Anuroop Sriram and Peter Holderrieth and Brian Karrer and Yaron Lipman and Ricky T. Q. 
Chen}, 50 | eprint={2412.03487}, 51 | archivePrefix={arXiv}, 52 | primaryClass={cs.LG}, 53 | url={https://arxiv.org/abs/2412.03487}, 54 | year={2024} 55 | } 56 | 57 | @article{albergo2022building, 58 | title={Building normalizing flows with stochastic interpolants}, 59 | author={Albergo, Michael S and Vanden-Eijnden, Eric}, 60 | journal={arXiv preprint arXiv:2209.15571}, 61 | year={2022} 62 | } 63 | 64 | 65 | 66 | @article{liu2022flow, 67 | title={Flow straight and fast: Learning to generate and transfer data with rectified flow}, 68 | author={Liu, Xingchao and Gong, Chengyue and Liu, Qiang}, 69 | journal={arXiv preprint arXiv:2209.03003}, 70 | year={2022} 71 | } 72 | 73 | @article{tong2023improving, 74 | title={Improving and generalizing flow-based generative models with minibatch optimal transport}, 75 | author={Tong, Alexander and Malkin, Nikolay and Huguet, Guillaume and Zhang, Yanlei and Rector-Brooks, Jarrid and Fatras, Kilian and Wolf, Guy and Bengio, Yoshua}, 76 | journal={arXiv preprint arXiv:2302.00482}, 77 | year={2023} 78 | } 79 | 80 | @article{benhamu2022cnfm, 81 | author = {Ben-Hamu, Heli and Cohen, Samuel and Bose, Joey and Amos, Brandon and Nickel, Maximillian and Grover, Aditya and Chen, Ricky T. Q. and Lipman, Yaron}, 82 | journal = {Proceedings of the 39th International Conference on Machine Learning}, 83 | title = {Matching Normalizing Flows and Probability Paths on Manifolds}, 84 | volume = {162}, 85 | year = {2022} 86 | } 87 | 88 | @article{campbell2024generative, 89 | title={Generative Flows on Discrete State-Spaces: Enabling Multimodal Flows with Applications to Protein Co-Design}, 90 | author={Campbell, Andrew and Yim, Jason and Barzilay, Regina and Rainforth, Tom and Jaakkola, Tommi}, 91 | journal={arXiv preprint arXiv:2402.04997}, 92 | year={2024} 93 | } 94 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # Flow Matching 4 | 5 | [![arXiv](assets/arXiv-2412.06264-red.svg)](https://arxiv.org/abs/2412.06264) 6 | [![CI](https://github.com/facebookresearch/flow_matching/actions/workflows/ci.yaml/badge.svg)](https://github.com/facebookresearch/flow_matching/actions/workflows/ci.yaml) 7 | [![Coverage](https://github.com/facebookresearch/flow_matching/raw/refs/heads/gh-pages/coverage/coverage-badge.svg)](https://stunning-potato-4k4z71e.pages.github.io/coverage/) 8 | [![License: CC BY-NC 4.0](assets/License-CC_BY--NC_4.0-lightgrey.svg)](https://creativecommons.org/licenses/by-nc/4.0/) 9 | [![PyPI](https://img.shields.io/pypi/v/flow-matching)](https://pypi.org/project/flow-matching/) 10 | 11 | 12 |
13 | 14 | `flow_matching` is a PyTorch library for Flow Matching algorithms, featuring continuous and discrete implementations. It includes examples for both text and image modalities. This repository is part of [Flow Matching Guide and Codebase](https://arxiv.org/abs/2412.06264). 15 | 16 | 17 | ![](./assets/teaser.png) 18 | 19 | ## Installation 20 | 21 | This repository requires Python 3.9 and Pytorch 2.1 or greater. To install the latest version run: 22 | ``` 23 | pip install flow_matching 24 | ``` 25 | 26 | ## Repository structure 27 | 28 | The core and example folders are structured in the following way: 29 | ```bash 30 | . 31 | ├── flow_matching # Core library 32 | │   ├── loss # Loss functions 33 | │   │   └── ... 34 | │   ├── path # Path and schedulers 35 | │   │   ├── ... 36 | │   │   └── scheduler # Schedulers and transformations 37 | │   │   └── ... 38 | │   ├── solver # Solvers for continuous and discrete flows 39 | │   │   └── ... 40 | │   └── utils 41 | │   └── ... 42 | └── examples # Synthetic, image, and text examples 43 |     ├── ... 44 |     ├── image 45 |    │   └── ... 46 |     └── text 47 |        └── ... 48 | ``` 49 | 50 | ## Development 51 | 52 | To create a conda environment with all required dependencies, run: 53 | ``` 54 | conda env create -f environment.yml 55 | conda activate flow_matching 56 | ``` 57 | 58 | Install pre-commit hook. This will ensure that all linting is done on each commit 59 | ``` 60 | pre-commit install 61 | ``` 62 | 63 | Install the `flow_matching` package in an editable mode: 64 | ``` 65 | pip install -e . 66 | ``` 67 | 68 | ## FAQ 69 | 70 | #### I want to train a Flow Matching model, where can I find the training code? 71 | 72 | We provide [training examples](examples). Under this folder, you can find synthetic data for [continuous](examples/2d_flow_matching.ipynb), [discrete](examples/2d_discrete_flow_matching.ipynb), and [Riemannian](examples/2d_riemannian_flow_matching_flat_torus.ipynb) Flow Matching. 
We also provide full training [examples](examples/image) (continuous and discrete) on CIFAR10 and face-blurred ImageNet, and a scalable discrete Flow Matching example for [text modeling](examples/text). 73 | 74 | #### Do you release pre-trained models? 75 | 76 | In this version, we don't release pre-trained models. All models under [examples](examples) can be trained from scratch by a single running command. 77 | 78 | #### How to contribute to this codebase? 79 | Please follow the [contribution guide](CONTRIBUTING.md). 80 | 81 | ## License 82 | 83 | The code in this repository is CC BY-NC licensed. See the [LICENSE](LICENSE) for details. 84 | 85 | ## Citation 86 | 87 | If you found this repository useful, please cite the following. 88 | 89 | ``` 90 | @misc{lipman2024flowmatchingguidecode, 91 | title={Flow Matching Guide and Code}, 92 | author={Yaron Lipman and Marton Havasi and Peter Holderrieth and Neta Shaul and Matt Le and Brian Karrer and Ricky T. Q. Chen and David Lopez-Paz and Heli Ben-Hamu and Itai Gat}, 93 | year={2024}, 94 | eprint={2412.06264}, 95 | archivePrefix={arXiv}, 96 | primaryClass={cs.LG}, 97 | url={https://arxiv.org/abs/2412.06264}, 98 | } 99 | ``` 100 | -------------------------------------------------------------------------------- /flow_matching/path/geodesic.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 
class GeodesicProbPath(ProbPath):
    r"""The ``GeodesicProbPath`` class represents a specific type of probability path where the transformation between distributions is defined through the geodesic path.
    Mathematically, a geodesic path can be represented as:

    .. math::

        X_t = \psi_t(X_0 | X_1) = \exp_{X_1}(\kappa_t \log_{X_1}(X_0)),

    where :math:`X_t` is the transformed data point at time `t`, :math:`X_0` and :math:`X_1` are the source and target data points, respectively, and :math:`\kappa_t` is a scheduler.

    The scheduler is responsible for providing the time-dependent :math:`\kappa_t` and must be differentiable.

    Using ``GeodesicProbPath`` in the flow matching framework:

    .. code-block:: python

        # Instantiates a manifold
        manifold = FlatTorus()

        # Instantiates a scheduler
        scheduler = CondOTScheduler()

        # Instantiates a probability path
        my_path = GeodesicProbPath(scheduler, manifold)
        mse_loss = torch.nn.MSELoss()

        for x_1 in dataset:
            # Sets x_0 to random noise with the same shape as x_1
            x_0 = torch.randn_like(x_1)

            # Sets t to one random value in [0,1] per batch element
            t = torch.rand(x_1.shape[0])

            # Samples the conditional path X_t ~ p_t(X_t|X_0,X_1)
            path_sample = my_path.sample(x_0=x_0, x_1=x_1, t=t)

            # Computes the MSE loss w.r.t. the velocity
            loss = mse_loss(path_sample.dx_t, my_model(path_sample.x_t, path_sample.t))
            loss.backward()

    Args:
        scheduler (ConvexScheduler): The scheduler that provides :math:`\kappa_t`.
        manifold (Manifold): The manifold on which the probability path is defined.

    """

    def __init__(self, scheduler: ConvexScheduler, manifold: Manifold):
        self.scheduler = scheduler
        self.manifold = manifold

    def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> PathSample:
        r"""Sample from the Riemannian probability path with geodesic interpolation:

        | given :math:`(X_0,X_1) \sim \pi(X_0,X_1)` and a scheduler :math:`\kappa_t`.
        | return :math:`X_0, X_1, X_t = \exp_{X_1}(\kappa_t \log_{X_1}(X_0))`, and the conditional velocity at :math:`X_t, \dot{X}_t`.

        Args:
            x_0 (Tensor): source data point, shape (batch_size, ...).
            x_1 (Tensor): target data point, shape (batch_size, ...).
            t (Tensor): times in [0,1], shape (batch_size).

        Returns:
            PathSample: A conditional sample at :math:`X_t \sim p_t`. Its ``t``
            field holds the expanded time tensor built below, not the 1d input.
        """
        self.assert_sample_shape(x_0=x_0, x_1=x_1, t=t)
        # Expand t to the shape of x_1 with its last axis reduced to size 1, so
        # that vmap hands each sample its own time tensor.
        t = expand_tensor_like(input_tensor=t, expand_to=x_1[..., 0:1]).clone()

        def cond_u(x_0, x_1, t):
            # Per-sample geodesic between x_0 and x_1, evaluated at kappa_t.
            path = geodesic(self.manifold, x_0, x_1)
            # jvp with a tangent of ones returns the point on the path together
            # with its derivative with respect to t.
            x_t, dx_t = jvp(
                lambda t: path(self.scheduler(t).alpha_t),
                (t,),
                (torch.ones_like(t).to(t),),
            )
            return x_t, dx_t

        # Apply the per-sample computation across the batch dimension.
        x_t, dx_t = vmap(cond_u)(x_0, x_1, t)
        x_t = x_t.reshape_as(x_1)
        dx_t = dx_t.reshape_as(x_1)

        return PathSample(x_t=x_t, dx_t=dx_t, x_1=x_1, x_0=x_0, t=t)
6 | 7 | import math 8 | from contextlib import nullcontext 9 | from typing import Optional 10 | 11 | import torch 12 | from flow_matching.loss import MixturePathGeneralizedKL 13 | from flow_matching.path import ProbPath 14 | from omegaconf.dictconfig import DictConfig 15 | from torch import nn, Tensor 16 | from torch.cuda.amp import GradScaler 17 | 18 | from torch.utils.data import DataLoader 19 | from utils.logging import TrainLogger 20 | 21 | from .flow import SourceDistribution 22 | from .state import TrainState 23 | 24 | 25 | def _get_lr(lr: float, step: int, warmup: int, n_iters: int, eta_min_ratio: float): 26 | if step < warmup: 27 | # Linear warmup 28 | return lr * (step / warmup) 29 | else: 30 | # Cosine annealing 31 | total_steps = n_iters 32 | eta_min = eta_min_ratio * lr 33 | cosine_decay = 0.5 * ( 34 | 1 + math.cos(math.pi * (step - warmup) / (total_steps - warmup)) 35 | ) 36 | return eta_min + (lr - eta_min) * cosine_decay 37 | 38 | 39 | def optimization_step( 40 | state: TrainState, 41 | scaler: GradScaler, 42 | loss: Tensor, 43 | optim_params: DictConfig, 44 | logger: TrainLogger, 45 | ) -> None: 46 | scaler.scale(loss).backward() 47 | scaler.unscale_(state.optimizer) 48 | 49 | lr = _get_lr( 50 | lr=optim_params.lr, 51 | step=state.step, 52 | warmup=optim_params.warmup, 53 | n_iters=optim_params.n_iters, 54 | eta_min_ratio=optim_params.eta_min_ratio, 55 | ) 56 | 57 | # Update learning rate in optimizer 58 | for g in state.optimizer.param_groups: 59 | g["lr"] = lr 60 | 61 | if state.step % optim_params.log_lr_every == 0: 62 | logger.log_lr(value=lr, step=state.step) 63 | 64 | if optim_params.grad_clip >= 0: 65 | torch.nn.utils.clip_grad_norm_( 66 | state.model.parameters(), max_norm=optim_params.grad_clip 67 | ) 68 | 69 | scaler.step(state.optimizer) 70 | scaler.update() 71 | 72 | state.optimizer.zero_grad() 73 | 74 | 75 | def step( 76 | state: TrainState, 77 | loss_fn: nn.Module, 78 | path: ProbPath, 79 | scaler: GradScaler, 80 | iterator: 
DataLoader, 81 | device: torch.device, 82 | source_distribution: SourceDistribution, 83 | logger: TrainLogger, 84 | training: bool, 85 | optim_params: Optional[DictConfig] = None, 86 | time_epsilon: float = 0.0, 87 | ) -> Tensor: 88 | assert (training and (optim_params is not None)) or (not training) 89 | 90 | if training: 91 | state.train() 92 | else: 93 | state.eval() 94 | 95 | x_1 = next(iterator)["input_ids"].to(device) 96 | 97 | # Sample from path 98 | with torch.no_grad(): 99 | x_0 = source_distribution.sample_like(x_1) 100 | t = torch.rand(x_1.shape[0], device=x_1.device) * (1.0 - time_epsilon) 101 | path_sample = path.sample(t=t, x_0=x_0, x_1=x_1) 102 | 103 | # Forward and compute loss 104 | ctx = nullcontext() if training else torch.no_grad() 105 | 106 | with ctx: 107 | logits = state.model(x_t=path_sample.x_t, time=path_sample.t) 108 | 109 | if isinstance(loss_fn, nn.CrossEntropyLoss): 110 | loss = loss_fn(logits.flatten(0, 1), x_1.flatten(0, 1)).mean() 111 | elif isinstance(loss_fn, MixturePathGeneralizedKL): 112 | loss = loss_fn( 113 | logits=logits, x_1=x_1, x_t=path_sample.x_t, t=path_sample.t 114 | ).mean() 115 | else: 116 | raise ValueError("Invalid loss function") 117 | 118 | # Optimization step (only if training=true) 119 | if training: 120 | optimization_step( 121 | state=state, 122 | loss=loss, 123 | scaler=scaler, 124 | optim_params=optim_params, 125 | logger=logger, 126 | ) 127 | 128 | return loss.detach() 129 | -------------------------------------------------------------------------------- /examples/text/utils/logging.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 
def get_logger(log_path: str, rank: int) -> Logger:
    """Configure and return the text logger for this process.

    For ``rank != 0`` the logger named ``"dummy"`` is returned and no handlers
    are configured. For rank 0 the root logger is reset and given a file
    handler (append mode) and a console handler, both at INFO level.

    Args:
        log_path: Path of the log file to append to.
        rank: Process rank; only rank 0 gets the configured root logger.

    Returns:
        The logger to use in this process.
    """
    if rank != 0:
        return logging.getLogger("dummy")

    logger = logging.getLogger()
    default_level = logging.INFO

    # Drop handlers left over from an earlier configuration so that messages
    # are not emitted twice.
    if logger.hasHandlers():
        logger.handlers.clear()

    logger.setLevel(default_level)

    formatter = logging.Formatter(
        "%(levelname)s | %(asctime)s | %(message)s", "%Y-%m-%d %H:%M:%S"
    )

    info_file_handler = logging.FileHandler(log_path, mode="a")
    info_file_handler.setLevel(default_level)
    info_file_handler.setFormatter(formatter)
    logger.addHandler(info_file_handler)

    console_handler = logging.StreamHandler()
    console_handler.setLevel(default_level)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    return logger


class TrainLogger:
    """Text logger with optional Weights & Biases reporting.

    Args:
        log_dir: Directory holding the log file and the stored wandb run id.
        rank: Process rank; wandb is only enabled on rank 0.
        cfg: Run configuration. Despite the ``bool`` annotation and ``False``
            default (kept for interface compatibility), the constructor reads
            ``cfg.logging.log_file_name`` and ``cfg.logging.enable_wandb``, so
            a config object must be supplied.
    """

    def __init__(self, log_dir: Path, rank: int, cfg: bool = False):
        self.log_dir = log_dir
        self.cfg = cfg

        self._init_text_logger(rank=rank)

        self.enable_wandb = self.cfg.logging.enable_wandb and (rank == 0)

        if self.enable_wandb:
            self._init_wandb()

    def _init_text_logger(self, rank: int):
        """Create the text logger writing to ``log_dir / cfg.logging.log_file_name``."""
        log_path = self.log_dir / self.cfg.logging.log_file_name
        self._logger = get_logger(log_path=log_path, rank=rank)

    def _init_wandb(
        self,
    ):
        """Start a wandb run, reusing the run id stored in ``log_dir`` if present."""
        wandb_run_id_path = self.log_dir / "wandb_run.id"

        # The id is stored on disk so that a later run from the same directory
        # passes the same id to wandb.init.
        try:
            wandb_run_id = wandb_run_id_path.read_text()
        except FileNotFoundError:
            wandb_run_id = wandb.util.generate_id()
            wandb_run_id_path.write_text(wandb_run_id)

        self.wandb_logger = wandb.init(
            id=wandb_run_id,
            project=self.cfg.logging.project,
            group=self.cfg.logging.group,
            dir=self.log_dir,
            entity=self.cfg.logging.entity,
            resume="allow",
            config=OmegaConf.to_container(self.cfg, resolve=True),
        )

    def log_metric(self, value: float, name: str, stage: str, step: int) -> None:
        """Log a scalar metric to the text log and, if enabled, to wandb.

        Args:
            value: Metric value.
            name: Metric name.
            stage: Label prefixed to the metric name (``"{stage}/{name}"`` in wandb).
            step: Step the value belongs to.
        """
        self._logger.info(f"[{step}] {stage} {name}: {value:.3f}")

        if self.enable_wandb:
            self.wandb_logger.log(data={f"{stage}/{name}": value}, step=step)

    def log_lr(self, value: float, step: int) -> None:
        """Log the learning rate to wandb; does nothing when wandb is disabled."""
        if self.enable_wandb:
            self.wandb_logger.log(data={"Optimization/LR": value}, step=step)

    def info(self, msg: str, step: Optional[int] = None) -> None:
        """Log ``msg`` at INFO level, prefixed with ``[step]`` when a step is given."""
        # Compare against None: step 0 is a valid step and must keep its prefix.
        step_str = f"[{step}] " if step is not None else ""
        self._logger.info(f"{step_str}{msg}")

    def warning(self, msg: str) -> None:
        """Log ``msg`` at WARNING level."""
        self._logger.warning(msg)

    def finish(self) -> None:
        """Close the logger's file handlers and finish the wandb run if enabled."""
        for handler in self._logger.handlers:
            if isinstance(handler, logging.FileHandler):
                handler.close()

        if self.enable_wandb:
            wandb.finish()

    @staticmethod
    def log_devices(device: torch.device, logger: Logger) -> None:
        """Log the available CUDA devices (or warn about a non-CUDA device) and the CPU count."""
        if device.type == "cuda":
            logger.info(f"Found {torch.cuda.device_count()} CUDA devices.")
            for i in range(torch.cuda.device_count()):
                props = torch.cuda.get_device_properties(i)
                logger.info(
                    f"{props.name} \t Memory: {props.total_memory / (1024**3):.2f}GB"
                )
        else:
            logger.warning(f"WARNING: Using device {device}")
        logger.info(f"Found {os.cpu_count()} total number of CPUs.")
6 | # Part of this implementation is adapted from https://github.com/louaaron/Score-Entropy-Discrete-Diffusion 7 | # which is released under MIT license 8 | 9 | import math 10 | from collections import Counter 11 | from typing import List 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | from flow_matching.loss import MixturePathGeneralizedKL 16 | from flow_matching.path import MixtureDiscreteProbPath, ProbPath 17 | from flow_matching.path.scheduler import PolynomialConvexScheduler 18 | from flow_matching.utils import ModelWrapper 19 | from torch import nn, Tensor 20 | from torch.utils.data import DataLoader 21 | from tqdm import tqdm 22 | from transformers import GPT2LMHeadModel 23 | 24 | from logic.flow import SourceDistribution 25 | 26 | 27 | class WrappedModel(ModelWrapper): 28 | def forward(self, x: Tensor, t: Tensor, **extras) -> Tensor: 29 | return self.model(x_t=x, time=t).float() 30 | 31 | 32 | @torch.no_grad() 33 | def compute_perplexity(samples: Tensor, perplexity_batch_size: int) -> Tensor: 34 | eval_model = GPT2LMHeadModel.from_pretrained("gpt2-large").to(samples.device).eval() 35 | batches = samples.shape[0] // perplexity_batch_size 36 | total_perplexity = 0 37 | 38 | for i in range(batches): 39 | s = samples[i * perplexity_batch_size : (i + 1) * perplexity_batch_size] 40 | _, logits = eval_model(s, labels=s)[:2] 41 | logits = logits.transpose(-1, -2).detach() 42 | 43 | perplexity = F.cross_entropy(logits[..., :-1], s[..., 1:], reduction="none") 44 | perplexity = perplexity.mean(dim=-1).exp().mean() 45 | 46 | total_perplexity += perplexity 47 | 48 | total_perplexity /= batches 49 | 50 | return total_perplexity 51 | 52 | 53 | def _sample_entropy(sample: List) -> float: 54 | histogram = Counter(sample) 55 | total = sum(histogram.values()) 56 | entropy = 0 57 | 58 | for count in histogram.values(): 59 | p = count / total 60 | entropy -= p * math.log2(p) 61 | 62 | return entropy 63 | 64 | 65 | def compute_entropy(samples: Tensor) -> Tensor: 
66 | entropies = [_sample_entropy(sample.tolist()) for sample in samples] 67 | entropy = sum(entropies) / len(entropies) 68 | 69 | return torch.tensor(entropy, device=samples.device) 70 | 71 | 72 | @torch.no_grad() 73 | def estimate_likelihood( 74 | model: nn.Module, 75 | dataloader: DataLoader, 76 | source_distribution: SourceDistribution, 77 | path: ProbPath, 78 | n_discretization: int, 79 | device: torch.device, 80 | batch_size: int = 32, 81 | epsilon: float = 1e-3, 82 | ) -> Tensor: 83 | model = WrappedModel(model) 84 | 85 | # Generalized KL function (will use it to compute the elbo) 86 | linear_scheduler = PolynomialConvexScheduler(n=1.0) 87 | linear_path = MixtureDiscreteProbPath(scheduler=linear_scheduler) 88 | 89 | generalized_kl_fn = MixturePathGeneralizedKL(path=linear_path, reduction="none") 90 | 91 | # Time discretization 92 | discretization = ( 93 | torch.linspace(0, 1, n_discretization + 1, device=device)[:-1] 94 | .view(-1, 1) 95 | .repeat(1, batch_size) 96 | ) 97 | 98 | elbo = torch.zeros((1,), device=device) 99 | n_elements = torch.zeros((1,), device=device) 100 | 101 | for x_1 in tqdm(dataloader, total=len(dataloader)): 102 | x_1 = x_1["input_ids"].to(device) 103 | 104 | # Lower variance estimator for time discretization 105 | discretization = discretization + torch.rand( 106 | size=(1, batch_size), device=device 107 | ) 108 | discretization = discretization % 1 109 | discretization = discretization * (1 - epsilon) 110 | 111 | for k in discretization[:, : x_1.shape[0]]: 112 | x_0 = source_distribution.sample_like(x_1) 113 | x_t = linear_path.sample(t=k, x_0=x_0, x_1=x_1).x_t 114 | 115 | t = path.scheduler.kappa_inverse(k) 116 | 117 | logits = model(x=x_t, t=t) 118 | 119 | generalized_kl = generalized_kl_fn(logits=logits, x_1=x_1, x_t=x_t, t=k) 120 | n_elements += generalized_kl.numel() 121 | 122 | elbo += generalized_kl.sum() 123 | 124 | return elbo, n_elements 125 | -------------------------------------------------------------------------------- 
/examples/text/README.md: -------------------------------------------------------------------------------- 1 | # Text example 2 | 3 | This example implements training of a discrete flow matching model on text data. This repository provides the necessary tools and scripts to train and evaluate these models. 4 | 5 | **Note:** this example was tested only using PyTorch 2.5 and on a single node of H100 (8 gpus). With this setup, we achieved approximately 380k training steps in 24 hours. 6 | 7 | ## Installation 8 | 9 | To get started with this project, follow these steps to set up your environment: 10 | 11 | ```bash 12 | conda env create -f environment.yml 13 | conda activate discrete_flow_matching 14 | ``` 15 | 16 | ## Usage 17 | 18 | Specify the data cache and checkpoint directories. Data will automatically be downloaded into the cache directory. 19 | ```bash 20 | CACHE_DIR=... 21 | HYDRA_RUN_DIR=... 22 | ``` 23 | 24 | To train a discrete flow matching model on fine-web-edu, run: 25 | 26 | ```bash 27 | python run_train.py data.cache_dir=${CACHE_DIR} 28 | ``` 29 | 30 | To use `slurm`, modify the `slurm` config according to the cluster you are working on, and run: 31 | ```bash 32 | python run_train.py data.cache_dir=${CACHE_DIR} hydra_dir=${HYDRA_RUN_DIR} -m & 33 | ``` 34 | 35 | ## Results 36 | 37 | We trained models with linear scheduler (`PolynomialConvexScheduler(n=1.0)`) for one million steps on FineWeb-EDU. 38 | 39 | ```bash 40 | PYTHONPATH="." python scripts/run_eval.py --work_dir "/path/to/exp/folder" --ngpus 8 --eval_elbo --eval_perplexity 41 | ``` 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 |
SchedulerSource distributionLossGenerative perplexityELBO
LinearMaskCross-entropy
128.9
53.2
Generalized KL
132.2
47.9
UniformCross-entropy
90.9
71.7
Generalized KL
82.1
71.3
79 | 80 | ## Folder structure 81 | 82 | ```bash 83 | . 84 | ├── configs # Train configs 85 | │   └── ... 86 | ├── data # Data loading and preprocessing 87 | │   └── ... 88 | ├── logic # Logic components, such as flow related classes 89 | │   └── ... 90 | ├── model # Transformer implementation 91 | │   └── ... 92 | ├── scripts # Evaluation script 93 | │   └── ... 94 | ├── utils # Utility functions 95 | │ └── ... 96 | ├── README.md 97 | ├── environment.yml 98 | ├── train.py 99 | └── run_train.py # Run training script 100 | ``` 101 | 102 | ## Implemented methods 103 | 104 | This repository implements the following papers: 105 | - [Discrete Flow Matching](https://arxiv.org/abs/2407.15595) 106 | - [Flow Matching with General Discrete Paths: A Kinetic-Optimal Perspective](https://arxiv.org/abs/2412.03487) 107 | - [Generative Flows on Discrete State-Spaces: Enabling Multimodal Flows with Applications to Protein Co-Design](https://arxiv.org/abs/2402.04997) 108 | - [Simplified and Generalized Masked Diffusion for Discrete Data](https://arxiv.org/abs/2406.04329) 109 | 110 | 111 | ## Acknowledgements 112 | 113 | This example partially uses code from: 114 | - [Flash attention](https://github.com/Dao-AILab/flash-attention) 115 | - [Discrete Diffusion Modeling by Estimating the Ratios of the Data Distribution](https://github.com/louaaron/Score-Entropy-Discrete-Diffusion) 116 | - [GLIDE: Towards Photorealistic Image Generation and Editing with Text-Guided Diffusion Models](https://github.com/openai/glide-text2im/) 117 | - [TorchData](https://github.com/pytorch/data/tree/main) 118 | 119 | ## License 120 | 121 | The majority of the code in this example is licensed under CC-BY-NC, however portions of the project are available under separate license terms: 122 | - flash attention and TorchData are under BSD 3 license.
123 | - Discrete Diffusion Modeling by Estimating the Ratios of the Data Distribution and GLIDE: Towards Photorealistic Image Generation and Editing with Text-Guided Diffusion Models are under MIT license. 124 | -------------------------------------------------------------------------------- /tests/solver/test_discrete_solver.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import unittest 7 | 8 | import torch 9 | from flow_matching.path import MixtureDiscreteProbPath 10 | from flow_matching.path.scheduler import PolynomialConvexScheduler 11 | from flow_matching.solver import MixtureDiscreteEulerSolver 12 | 13 | 14 | class DummyModel(torch.nn.Module): 15 | def forward(self, x, t, **extras): 16 | return torch.stack( 17 | [torch.zeros_like(x), torch.zeros_like(x), torch.ones_like(x)], dim=-1 18 | ) 19 | 20 | 21 | class TestMixtureDiscreteEulerSolver(unittest.TestCase): 22 | def setUp(self): 23 | self.model = DummyModel() 24 | self.path = MixtureDiscreteProbPath(scheduler=PolynomialConvexScheduler(n=1.0)) 25 | self.vocabulary_size = 3 26 | self.source_distribution_p = torch.tensor([0.5, 0.5, 0.0]) 27 | 28 | def test_init(self): 29 | solver = MixtureDiscreteEulerSolver( 30 | model=self.model, 31 | path=self.path, 32 | vocabulary_size=self.vocabulary_size, 33 | source_distribution_p=self.source_distribution_p, 34 | ) 35 | self.assertEqual(solver.model, self.model) 36 | self.assertEqual(solver.path, self.path) 37 | self.assertEqual(solver.vocabulary_size, self.vocabulary_size) 38 | self.assertTrue( 39 | torch.allclose(solver.source_distribution_p, self.source_distribution_p) 40 | ) 41 | 42 | def test_sample(self): 43 | solver = MixtureDiscreteEulerSolver( 44 | model=self.model, 45 | path=self.path, 46 | 
vocabulary_size=self.vocabulary_size, 47 | source_distribution_p=self.source_distribution_p, 48 | ) 49 | x_init = torch.tensor([[0]]) 50 | step_size = 0.1 51 | time_grid = torch.tensor([0.0, 1.0]) 52 | result = solver.sample(x_init, step_size, time_grid=time_grid) 53 | self.assertEqual(result, torch.ones_like(result) * 2) 54 | 55 | def test_sample_with_sym_term(self): 56 | solver = MixtureDiscreteEulerSolver( 57 | model=self.model, 58 | path=self.path, 59 | vocabulary_size=self.vocabulary_size, 60 | source_distribution_p=self.source_distribution_p, 61 | ) 62 | x_init = torch.tensor([[0]]) 63 | step_size = 0.1 64 | time_grid = torch.tensor([0.0, 1.0]) 65 | div_free = 1.0 66 | result = solver.sample( 67 | x_init, step_size, time_grid=time_grid, div_free=div_free, verbose=True 68 | ) 69 | self.assertIsInstance(result, torch.Tensor) 70 | result = solver.sample( 71 | x_init, step_size, time_grid=time_grid, div_free=lambda t: 1.0, verbose=True 72 | ) 73 | self.assertIsInstance(result, torch.Tensor) 74 | 75 | def test_init_p_none(self): 76 | solver = MixtureDiscreteEulerSolver( 77 | model=self.model, 78 | path=self.path, 79 | vocabulary_size=self.vocabulary_size, 80 | ) 81 | self.assertIsNone(solver.source_distribution_p) 82 | 83 | def test_sample_time_grid(self): 84 | solver = MixtureDiscreteEulerSolver( 85 | model=self.model, 86 | path=self.path, 87 | vocabulary_size=self.vocabulary_size, 88 | source_distribution_p=self.source_distribution_p, 89 | ) 90 | x_init = torch.tensor([0]) 91 | time_grid = torch.linspace(0.0, 1.0, steps=11) 92 | result = solver.sample( 93 | x_init, step_size=None, time_grid=time_grid, return_intermediates=True 94 | ) 95 | self.assertEqual(result[-1], torch.ones_like(result[-1]) * 2) 96 | self.assertEqual(result.shape, (11, 1)) 97 | 98 | def test_sample_return_intermediate(self): 99 | solver = MixtureDiscreteEulerSolver( 100 | model=self.model, 101 | path=self.path, 102 | vocabulary_size=self.vocabulary_size, 103 | 
class MixtureDiscreteProbPath(ProbPath):
    r"""The ``MixtureDiscreteProbPath`` class defines a factorized discrete probability path.

    This path remains constant at the source data point :math:`X_0` until a random time, determined by the scheduler, when it flips to the target data point :math:`X_1`.
    The scheduler determines the flip probability using the parameter :math:`\sigma_t`, which is a function of time `t`. Specifically, :math:`\sigma_t` represents the probability of remaining at :math:`X_0`, while :math:`1 - \sigma_t` is the probability of flipping to :math:`X_1`:

    .. math::

        P(X_t = X_0) = \sigma_t \quad \text{and} \quad  P(X_t = X_1) = 1 - \sigma_t,

    where :math:`\sigma_t` is provided by the scheduler.

    Example:

    .. code-block:: python

        >>> x_0 = torch.zeros((1, 3, 3))
        >>> x_1 = torch.ones((1, 3, 3))

        >>> path = MixtureDiscreteProbPath(PolynomialConvexScheduler(n=1.0))
        >>> result = path.sample(x_0, x_1, t=torch.tensor([0.1])).x_t
        >>> result
        tensor([[[0.0, 0.0, 0.0],
                 [0.0, 0.0, 1.0],
                 [0.0, 0.0, 0.0]]])

        >>> result = path.sample(x_0, x_1, t=torch.tensor([0.5])).x_t
        >>> result
        tensor([[[1.0, 0.0, 1.0],
                 [0.0, 1.0, 0.0],
                 [0.0, 1.0, 0.0]]])

        >>> result = path.sample(x_0, x_1, t=torch.tensor([1.0])).x_t
        >>> result
        tensor([[[1.0, 1.0, 1.0],
                 [1.0, 1.0, 1.0],
                 [1.0, 1.0, 1.0]]])

    Args:
        scheduler (ConvexScheduler): The scheduler that provides :math:`\sigma_t`.
    """

    def __init__(self, scheduler: ConvexScheduler):
        # The message previously named a non-existent "ConvexProbPath".
        assert isinstance(
            scheduler, ConvexScheduler
        ), "Scheduler for MixtureDiscreteProbPath must be a ConvexScheduler."

        self.scheduler = scheduler

    def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> DiscretePathSample:
        r"""Sample from the mixture probability path:

        | given :math:`(X_0,X_1) \sim \pi(X_0,X_1)` and a scheduler providing :math:`\sigma_t`.
        | return :math:`X_0, X_1, t`, and :math:`X_t \sim p_t`.

        Args:
            x_0 (Tensor): source data point, shape (batch_size, ...).
            x_1 (Tensor): target data point, shape (batch_size, ...).
            t (Tensor): times in [0,1], shape (batch_size).

        Returns:
            DiscretePathSample: a conditional sample at :math:`X_t \sim p_t`.
        """
        self.assert_sample_shape(x_0=x_0, x_1=x_1, t=t)

        sigma_t = self.scheduler(t).sigma_t

        sigma_t = expand_tensor_like(input_tensor=sigma_t, expand_to=x_1)

        # Independently for every element: keep the source value with
        # probability sigma_t, otherwise take the target value.
        source_indices = torch.rand(size=x_1.shape, device=x_1.device) < sigma_t
        x_t = torch.where(condition=source_indices, input=x_0, other=x_1)

        return DiscretePathSample(x_t=x_t, x_1=x_1, x_0=x_0, t=t)

    def posterior_to_velocity(
        self, posterior_logits: Tensor, x_t: Tensor, t: Tensor
    ) -> Tensor:
        r"""Convert the factorized posterior to velocity.

        | given :math:`p(X_1|X_t)`. In the factorized case: :math:`\prod_i p(X_1^i | X_t)`.
        | return :math:`u_t`.

        The result is :math:`\frac{\dot{\kappa}_t}{1 - \kappa_t} (p(X_1|X_t) - \delta_{X_t})`,
        with :math:`\kappa_t` read from the scheduler's ``alpha_t``. The
        denominator is zero wherever :math:`\kappa_t = 1`, so callers must avoid
        such times.

        Args:
            posterior_logits (Tensor): logits of the x_1 posterior conditional on x_t, shape (..., vocab size).
            x_t (Tensor): path sample at time t, shape (...). It is one-hot encoded, so it must hold integer class indices.
            t (Tensor): time in [0,1].

        Returns:
            Tensor: velocity.
        """
        posterior = torch.softmax(posterior_logits, dim=-1)
        vocabulary_size = posterior.shape[-1]
        x_t = F.one_hot(x_t, num_classes=vocabulary_size)
        t = unsqueeze_to_match(source=t, target=x_t)

        scheduler_output = self.scheduler(t)

        kappa_t = scheduler_output.alpha_t
        d_kappa_t = scheduler_output.d_alpha_t

        return (d_kappa_t / (1 - kappa_t)) * (posterior - x_t)
logger = logging.getLogger(__name__)

# Index of the mask token; pixel values occupy 0..255, so the vocabulary has
# MASK_TOKEN + 1 entries.
MASK_TOKEN = 256
# Log the running loss every PRINT_FREQUENCY iterations.
PRINT_FREQUENCY = 50


def skewed_timestep_sample(
    num_samples: int,
    device: torch.device,
    p_mean: float = -1.2,
    p_std: float = 1.2,
) -> torch.Tensor:
    """Draw training times from a skewed (log-normal based) distribution.

    ``sigma`` is sampled as ``exp(p_std * N(0, 1) + p_mean)`` and mapped to
    ``time = 1 / (1 + sigma)``, clipped to ``[0.0001, 1.0]``.

    Args:
        num_samples: Number of times to draw.
        device: Device on which the samples are created.
        p_mean: Mean of ``log(sigma)``.
        p_std: Standard deviation of ``log(sigma)``.

    Returns:
        Tensor of shape ``(num_samples,)``.
    """
    rnd_normal = torch.randn((num_samples,), device=device)
    sigma = (rnd_normal * p_std + p_mean).exp()
    time = 1 / (1 + sigma)
    time = torch.clip(time, min=0.0001, max=1.0)
    return time


def train_one_epoch(
    model: torch.nn.Module,
    data_loader: Iterable,
    optimizer: torch.optim.Optimizer,
    lr_schedule: torch.optim.lr_scheduler.LRScheduler,
    device: torch.device,
    epoch: int,
    loss_scaler: NativeScalerWithGradNormCount,
    args: argparse.Namespace,
):
    """Train ``model`` for one pass over ``data_loader``.

    Args:
        model: Model to train; may be an ``EMA`` wrapper, optionally inside
            ``DistributedDataParallel``.
        data_loader: Yields ``(samples, labels)`` pairs; must support ``len()``.
        optimizer: Optimizer driven through ``loss_scaler``.
        lr_schedule: Stepped once at the end of the epoch.
        device: Device batches are moved to.
        epoch: Epoch index, used for logging only.
        loss_scaler: Runs the backward pass and, when asked, the optimizer update.
        args: Reads ``accum_iter``, ``discrete_flow_matching``, ``test_run``,
            ``class_drop_prob`` and ``skewed_timesteps``.

    Returns:
        ``{"loss": <mean loss over the epoch>}``.

    Raises:
        ValueError: If the loss becomes non-finite.
    """
    gc.collect()
    model.train(True)
    batch_loss = MeanMetric().to(device, non_blocking=True)
    epoch_loss = MeanMetric().to(device, non_blocking=True)

    accum_iter = args.accum_iter
    if args.discrete_flow_matching:
        scheduler = PolynomialConvexScheduler(n=3.0)
        path = MixtureDiscreteProbPath(scheduler=scheduler)
    else:
        path = CondOTProbPath()

    for data_iter_step, (samples, labels) in enumerate(data_loader):
        if data_iter_step % accum_iter == 0:
            optimizer.zero_grad()
            batch_loss.reset()
            # NOTE(review): reconstructed as nested under the accumulation
            # boundary, so a test run completes one full accumulation cycle
            # before stopping -- confirm against the original indentation.
            if data_iter_step > 0 and args.test_run:
                break

        samples = samples.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        # Drop the class label for the whole batch with probability class_drop_prob.
        if torch.rand(1) < args.class_drop_prob:
            conditioning = {}
        else:
            conditioning = {"label": labels}

        if args.discrete_flow_matching:
            samples = (samples * 255.0).to(torch.long)
            # Drawn on the CPU and then moved, as before, so the random stream
            # of existing seeded runs is unchanged.
            t = torch.rand(samples.shape[0]).to(device)

            # sample probability path; the source is the all-mask image
            x_0 = torch.full(
                samples.shape, MASK_TOKEN, dtype=torch.long, device=device
            )
            path_sample = path.sample(t=t, x_0=x_0, x_1=samples)

            # discrete flow matching loss
            logits = model(path_sample.x_t, t=t, extra=conditioning)
            loss = torch.nn.functional.cross_entropy(
                logits.reshape([-1, MASK_TOKEN + 1]), samples.reshape([-1])
            ).mean()
        else:
            # Scaling to [-1, 1] from [0, 1]
            samples = samples * 2.0 - 1.0
            noise = torch.randn_like(samples)
            if args.skewed_timesteps:
                t = skewed_timestep_sample(samples.shape[0], device=device)
            else:
                t = torch.rand(samples.shape[0]).to(device)
            path_sample = path.sample(t=t, x_0=noise, x_1=samples)
            x_t = path_sample.x_t
            u_t = path_sample.dx_t

            with torch.cuda.amp.autocast():
                loss = torch.pow(model(x_t, t, extra=conditioning) - u_t, 2).mean()

        loss_value = loss.item()
        batch_loss.update(loss)
        epoch_loss.update(loss)

        if not math.isfinite(loss_value):
            raise ValueError(f"Loss is {loss_value}, stopping training")

        loss /= accum_iter

        # Loss scaler applies the optimizer when update_grad is set to true.
        # Otherwise just updates the internal gradient scales
        apply_update = (data_iter_step + 1) % accum_iter == 0
        loss_scaler(
            loss,
            optimizer,
            parameters=model.parameters(),
            update_grad=apply_update,
        )
        if apply_update and isinstance(model, EMA):
            model.update_ema()
        elif (
            apply_update
            and isinstance(model, DistributedDataParallel)
            and isinstance(model.module, EMA)
        ):
            model.module.update_ema()

        lr = optimizer.param_groups[0]["lr"]
        if data_iter_step % PRINT_FREQUENCY == 0:
            logger.info(
                f"Epoch {epoch} [{data_iter_step}/{len(data_loader)}]: loss = {batch_loss.compute()}, lr = {lr}"
            )

    lr_schedule.step()
    return {"loss": float(epoch_loss.compute().detach().cpu())}
A test run executes one step of training followed by one step of evaluation. 31 | 32 | ``` 33 | python train.py --data_path=${IMAGENET_DIR}train_blurred_$IMAGENET_RES/box/ --test_run 34 | ``` 35 | 36 | 5. Launch training on a SLURM cluster 37 | 38 | ``` 39 | python submitit_train.py --data_path=${IMAGENET_DIR}train_blurred_$IMAGENET_RES/box/ 40 | ``` 41 | 42 | 6. Evaluate the model using the `--eval_only` flag. The evaluation script will generate snapshots under the `/snapshots` folder. Specify the `--compute_fid` flag to also compute the FID with respect to the training set. Make sure to specify your most recent checkpoint to resume from. The results are printed to `log.txt`. 43 | 44 | ``` 45 | python submitit_train.py --data_path=${IMAGENET_DIR}train_blurred_$IMAGENET_RES/box/ --resume=./output_dir/checkpoint-899.pth --compute_fid --eval_only 46 | ``` 47 | 48 | 49 | ## Results 50 | | Data | Model type | Epochs | FID | Command | 51 | |-----------------------|----------------------------------|-------|------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| 52 | | Cifar10 | Unconditional UNet | 1800 | 2.07 | `python submitit_train.py \`
`--dataset=cifar10 \`
`--batch_size=64 \`
`--nodes=1 \`
`--accum_iter=1 \`
`--eval_frequency=100 \`
`--epochs=3000 \`
`--class_drop_prob=1.0 \`
`--cfg_scale=0.0 \`
`--compute_fid \`
`--ode_method heun2 \`
`--ode_options '{"nfe": 50}' \`
`--use_ema \`
`--edm_schedule \`
`--skewed_timesteps` | 53 | | ImageNet32 (Blurred) | Class conditional UNet | 900 | 1.14 | `export IMAGENET_RES=32 && \`
`python submitit_train.py \`
`--data_path=${IMAGENET_DIR}train_blurred_$IMAGENET_RES/box/ \`
`--batch_size=32 \`
`--nodes=8 \`
`--accum_iter=1 \`
`--eval_frequency=100 \`
`--decay_lr \`
`--compute_fid \`
`--ode_method dopri5 \`
`--ode_options '{"atol": 1e-5, "rtol":1e-5}'` | 54 | | ImageNet64 (Blurred) | Class conditional UNet | 900 | 1.64 | `export IMAGENET_RES=64 && \`
`python submitit_train.py \`
`--data_path=${IMAGENET_DIR}train_blurred_$IMAGENET_RES/box/ \`
`--batch_size=32 \`
`--nodes=8 \`
`--accum_iter=1 \`
`--eval_frequency=100 \`
`--decay_lr \`
`--compute_fid \`
`--ode_method dopri5 \`
`--ode_options '{"atol": 1e-5, "rtol":1e-5}'` | 55 | | Cifar10 (Discrete Flow) | Unconditional UNet | 2500 | 3.58 | `python submitit_train.py \`
`--dataset=cifar10 \`
`--nodes=1 \`
`--discrete_flow_matching \`
`--batch_size=32 \`
`--accum_iter=1 \`
`--cfg_scale=0.0 \`
`--use_ema \`
`--epochs=3000 \`
`--class_drop_prob=1.0 \`
`--compute_fid \`
`--sym_func` | 56 | 57 | 58 | 59 | ## Acknowledgements 60 | 61 | This example partially uses code from: 62 | - [Guided diffusion](https://github.com/openai/guided-diffusion/) 63 | - [ConvNeXt](https://github.com/facebookresearch/ConvNeXt) 64 | 65 | ## License 66 | 67 | The majority of the code in this example is licensed under CC-BY-NC, however portions of the project are available under separate license terms: 68 | - The UNet model is under MIT license. 69 | - The distributed computing and the grad scaler code is under MIT license. 70 | 71 | ## Citations 72 | 73 | Deng, Jia, et al. "Imagenet: A large-scale hierarchical image database." 2009 IEEE conference on computer vision and pattern recognition. Ieee, 2009. 74 | 75 | Karras, Tero, et al. "Elucidating the design space of diffusion-based generative models." Advances in neural information processing systems 35 (2022): 26565-26577. 76 | 77 | Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. "U-net: Convolutional networks for biomedical image segmentation." Medical image computing and computer-assisted intervention–MICCAI 2015: 18th international conference, Munich, Germany, October 5-9, 2015, proceedings, part III 18. Springer International Publishing, 2015. 78 | -------------------------------------------------------------------------------- /flow_matching/path/scheduler/schedule_transform.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from torch import Tensor 8 | 9 | from flow_matching.path.scheduler.scheduler import Scheduler 10 | from flow_matching.utils import ModelWrapper 11 | 12 | 13 | class ScheduleTransformedModel(ModelWrapper): 14 | """ 15 | Change of scheduler for a velocity model.
def __init__(
    self,
    velocity_model: ModelWrapper,
    original_scheduler: Scheduler,
    new_scheduler: Scheduler,
):
    """Wrap a velocity model together with its old and new schedulers.

    Args:
        velocity_model (ModelWrapper): The velocity model to be transformed.
        original_scheduler (Scheduler): The scheduler the model was built for.
            Must expose a callable ``snr_inverse``.
        new_scheduler (Scheduler): The scheduler to expose instead.

    Raises:
        AssertionError: If ``original_scheduler`` has no callable
            ``snr_inverse`` attribute.
    """
    super().__init__(model=velocity_model)
    self.original_scheduler = original_scheduler
    self.new_scheduler = new_scheduler

    # forward() maps the new scheduler's SNR back to a time of the original
    # scheduler through snr_inverse, so reject schedulers without it up front.
    inverse = getattr(self.original_scheduler, "snr_inverse", None)
    assert callable(
        inverse
    ), "The original scheduler must have a callable 'snr_inverse' method."
def forward(self, x: Tensor, t: Tensor, **extras) -> Tensor:
    r"""
    Evaluate the marginal velocity field under the new scheduler.

    This is a post-training scheduler change for affine conditional flows:
    a marginal velocity :math:`u_t(x)` generated with the original scheduler
    is converted into the marginal velocity :math:`\bar{u}_r(x)` of the new
    scheduler, keeping the same data coupling.

    The conversion uses the scale-time (ST) transformation between the two
    conditional flows,

    .. math::

        \bar{X}_r = s_r X_{t_r},

    with

    .. math::

        t_r = \rho^{-1}(\bar{\rho}(r)) \quad \text{and} \quad s_r = \frac{\bar{\sigma}_r}{\sigma_{t_r}},

    where :math:`\rho(t) = \frac{\alpha_t}{\sigma_t}` is the signal-to-noise
    ratio of the original scheduler and :math:`\bar{\rho}(r)` that of the new
    one. The resulting velocity is

    .. math::

        \bar{u}_r(x) = \left(\frac{\dot{s}_r}{s_r}\right) x + s_r \dot{t}_r u_{t_r}\left(\frac{x}{s_r}\right).

    Args:
        x (Tensor): :math:`x_t`, the input tensor.
        t (Tensor): The time tensor (denoted as :math:`r` above).
        **extras: Additional arguments forwarded to the wrapped model.

    Returns:
        Tensor: The transformed velocity.
    """
    # Quantities of the new scheduler at time r (barred symbols in the docstring).
    new = self.new_scheduler(t=t)
    bar_alpha, bar_sigma = new.alpha_t, new.sigma_t
    d_bar_alpha, d_bar_sigma = new.d_alpha_t, new.d_sigma_t

    # Time of the original scheduler that has the same SNR as the new one at r.
    t_r = self.original_scheduler.snr_inverse(bar_alpha / bar_sigma)

    old = self.original_scheduler(t=t_r)
    alpha, sigma = old.alpha_t, old.sigma_t
    d_alpha, d_sigma = old.d_alpha_t, old.d_sigma_t

    # s_r
    scale = bar_sigma / sigma

    # d t_r / d r
    numerator = sigma * sigma * (bar_sigma * d_bar_alpha - bar_alpha * d_bar_sigma)
    denominator = bar_sigma * bar_sigma * (sigma * d_alpha - alpha * d_sigma)
    d_time = numerator / denominator

    # d s_r / d r
    d_scale = (sigma * d_bar_sigma - bar_sigma * d_sigma * d_time) / (sigma * sigma)

    velocity = self.model(x=x / scale, t=t_r, **extras)
    return d_scale * x / scale + d_time * scale * velocity
class SiLU(nn.Module):
    """Sigmoid-weighted linear unit: ``x * sigmoid(x)``."""

    def forward(self, x):
        gate = th.sigmoid(x)
        return x * gate


class GroupNorm32(nn.GroupNorm):
    """Group norm that computes in float32 and casts back to the input dtype."""

    def forward(self, x):
        in_dtype = x.dtype
        normed = super().forward(x.float())
        return normed.type(in_dtype)


def conv_nd(dims, *args, **kwargs):
    """
    Build a convolution layer for 1, 2 or 3 spatial dimensions.
    :param dims: number of spatial dimensions (1, 2 or 3).
    :return: the matching nn.Conv{1,2,3}d built from the remaining arguments.
    """
    for rank, layer in ((1, nn.Conv1d), (2, nn.Conv2d), (3, nn.Conv3d)):
        if dims == rank:
            return layer(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def linear(*args, **kwargs):
    """
    Build an nn.Linear layer from the given arguments.
    """
    layer = nn.Linear(*args, **kwargs)
    return layer


def avg_pool_nd(dims, *args, **kwargs):
    """
    Build an average pooling layer for 1, 2 or 3 spatial dimensions.
    :param dims: number of spatial dimensions (1, 2 or 3).
    :return: the matching nn.AvgPool{1,2,3}d built from the remaining arguments.
    """
    for rank, layer in ((1, nn.AvgPool1d), (2, nn.AvgPool2d), (3, nn.AvgPool3d)):
        if dims == rank:
            return layer(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def update_ema(target_params, source_params, rate=0.99):
    """
    Move target parameters towards source parameters with an exponential
    moving average, in place: ``target = rate * target + (1 - rate) * source``.
    :param target_params: the target parameter sequence.
    :param source_params: the source parameter sequence.
    :param rate: the EMA rate (closer to 1 means slower).
    """
    blend = 1 - rate
    for target, source in zip(target_params, source_params):
        # detach() shares storage, so the in-place ops update the parameter
        # itself without being recorded by autograd.
        shadow = target.detach()
        shadow.mul_(rate)
        shadow.add_(source, alpha=blend)


def zero_module(module):
    """
    Set every parameter of a module to zero in place and return the module.
    """
    for param in module.parameters():
        param.detach().zero_()
    return module
def scale_module(module, scale):
    """
    Multiply every parameter of a module by ``scale`` in place and return
    the module.
    """
    for param in module.parameters():
        param.detach().mul_(scale)
    return module


def mean_flat(tensor):
    """
    Average a tensor over every dimension except the first (batch) one.
    """
    non_batch_dims = list(range(1, len(tensor.shape)))
    return tensor.mean(dim=non_batch_dims)


def normalization(channels):
    """
    Build the standard normalization layer: a 32-group GroupNorm32.
    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    return GroupNorm32(32, channels)


def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Build sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings, laid out as
             cosines, then sines, then one zero column when dim is odd.
    """
    half = dim // 2
    # Geometric frequency ladder from 1 down towards 1 / max_period.
    exponent = (
        -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
    )
    freqs = th.exp(exponent).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    pieces = [th.cos(args), th.sin(args)]
    if dim % 2:
        # Odd dim: pad with a zero column so the output width is exactly dim.
        pieces.append(th.zeros_like(args[:, :1]))
    return th.cat(pieces, dim=-1)
def checkpoint(func, inputs, params, flag):
    """
    Run ``func`` on ``inputs``, optionally with activation checkpointing so
    that intermediate activations are not cached, trading extra compute in
    the backward pass for reduced memory.
    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments. Not used by this
                   implementation.
    :param flag: if False, disable gradient checkpointing.
    """
    if not flag:
        return func(*inputs)
    # PyTorch's own activation checkpointing; it has support for fp16 autocast.
    return th.utils.checkpoint.checkpoint(func, *inputs)


class CheckpointFunction(th.autograd.Function):
    """Autograd function that recomputes ``run_function`` in the backward pass.

    Called as ``CheckpointFunction.apply(run_function, length, *args)`` where
    the first ``length`` entries of ``args`` are the inputs passed to
    ``run_function`` and the remaining entries are extra tensors that should
    also receive gradients.
    """

    @staticmethod
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])
        # No graph is recorded here; backward() rebuilds it on demand.
        with th.no_grad():
            return ctx.run_function(*ctx.input_tensors)

    @staticmethod
    def backward(ctx, *output_grads):
        detached = [inp.detach().requires_grad_(True) for inp in ctx.input_tensors]
        ctx.input_tensors = detached
        with th.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            views = [inp.view_as(inp) for inp in detached]
            recomputed = ctx.run_function(*views)
        grads = th.autograd.grad(
            recomputed,
            detached + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors
        del ctx.input_params
        del recomputed
        # The two leading Nones are for run_function and length.
        return (None, None) + grads
def run_eval(
    rank: int,
    seed: int,
    work_dir: str,
    batch_size: int,
    perplexity_n_samples: int,
    sampling_steps: int,
    eval_perplexity: bool,
    eval_elbo: bool,
    elbo_data: str,
    world_size: int,
    n_discretization: float = 1024,
) -> None:
    """Evaluate a checkpointed model on one rank of a distributed job.

    Loads the config and model from the checkpoint directory under
    ``work_dir`` and then, depending on the flags, (a) generates samples and
    reports their perplexity and entropy and/or (b) estimates the likelihood
    on the validation split of ``elbo_data``. Results are reduced across
    ranks with ``torch.distributed`` and printed by rank 0 only, so the
    process group must already be initialized.

    Args:
        rank: Rank of this process; also offsets the seed and selects the
            CUDA device.
        seed: Base random seed.
        work_dir: Working directory that holds the checkpoint and samples.
        batch_size: Batch size for sample generation and for the ELBO loader.
        perplexity_n_samples: Number of samples to generate per rank; only
            ``perplexity_n_samples // batch_size`` full batches are produced.
        sampling_steps: Number of sampling steps passed to the generator.
        eval_perplexity: Whether to run the perplexity/entropy evaluation.
        eval_elbo: Whether to run the ELBO evaluation.
        elbo_data: Dataset name used for the ELBO evaluation.
        world_size: Number of processes, passed to the dataset as ``ngpus``.
        n_discretization: Forwarded to ``evaluate.estimate_likelihood``.
            NOTE(review): annotated ``float`` although the default is the
            integer 1024 -- confirm the intended type.
    """
    # Offset by rank so each process draws different random numbers.
    torch.manual_seed(seed + rank)

    # Logging and configuration
    work_dirs = checkpointing.get_work_dirs(work_dir=work_dir, rank=rank)

    device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")

    cfg = checkpointing.load_cfg_from_path(work_dir=work_dirs.checkpoint)

    # Data
    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
    vocab_size = tokenizer.vocab_size

    # Flow matching
    path = flow.get_path(
        scheduler_type=cfg.flow.scheduler_type, exponent=cfg.flow.exponent
    )
    loss_fn = flow.get_loss_function(loss_function=cfg.flow.loss_function, path=path)
    # Elbo may have singularity at 1
    time_epsilon = 1e-3 if isinstance(loss_fn, MixturePathGeneralizedKL) else 0.0

    source_distribution = flow.get_source_distribution(
        source_distribution=cfg.flow.source_distribution, vocab_size=vocab_size
    )

    model = checkpointing.load_model_from_path(
        work_dir=work_dirs.checkpoint,
        device=device,
        source_distribution=source_distribution,
        cfg=cfg.model,
        vocab_size=vocab_size,
    )
    model.eval()

    if cfg.model.compile:
        model = torch.compile(model)
        torch.set_float32_matmul_precision("high")

    if eval_perplexity:
        # At least one full batch is required; any remainder is dropped.
        assert perplexity_n_samples // batch_size > 0

        samples = []

        for _ in range(perplexity_n_samples // batch_size):
            samples.append(
                generate.generate_samples(
                    model=model,
                    step=0,
                    sample_dir=work_dirs.samples,
                    vocab_size=vocab_size,
                    tokenizer=tokenizer,
                    rank=rank,
                    device=device,
                    path=path,
                    source_distribution=source_distribution,
                    sample_batch_size=batch_size,
                    sequence_length=cfg.model.length,
                    sampling_steps=sampling_steps,
                    time_epsilon=time_epsilon,
                )
            )

            # NOTE(review): indentation is lost in this dump; the barrier is
            # assumed to sit inside the loop (one sync per generated batch) --
            # confirm against the original file.
            dist.barrier()

        samples = torch.cat(samples, dim=0)

        perplexity = evaluate.compute_perplexity(
            samples=samples,
            perplexity_batch_size=cfg.eval.perplexity_batch_size,
        )
        # Average the per-rank values across all ranks.
        dist.all_reduce(perplexity, dist.ReduceOp.AVG)

        entropy = evaluate.compute_entropy(samples=samples)
        dist.all_reduce(entropy, dist.ReduceOp.AVG)

        if rank == 0:
            print(f"Perplexity: {perplexity:.2f}, Entropy: {entropy:.2f}")

    if eval_elbo:
        data_state = data._get_dataset(
            name=elbo_data,
            mode="validation",
            cache_dir=cfg.data.cache_dir,
            block_size=cfg.model.length,
            num_proc=cfg.data.num_workers,
            batch_size=batch_size,
            ngpus=world_size,
        )

        dataloader = DataLoader(
            data_state.dataset,
            batch_size=batch_size,
            sampler=data_state.sampler,
            num_workers=cfg.data.num_workers,
            pin_memory=True,
            # DataLoader does not allow shuffle together with a sampler.
            shuffle=(data_state.sampler is None),
        )

        elbo, num_elements = evaluate.estimate_likelihood(
            model=model,
            dataloader=dataloader,
            source_distribution=source_distribution,
            n_discretization=n_discretization,
            device=device,
            batch_size=batch_size,
            path=path,
        )
        dist.barrier()

        # Sum numerator and denominator separately so the ratio below is a
        # global average rather than an average of per-rank averages.
        dist.all_reduce(elbo, dist.ReduceOp.SUM)
        dist.all_reduce(num_elements, dist.ReduceOp.SUM)

        if rank == 0:
            print(f"ELBO: {torch.exp(elbo / num_elements).item():.2f}")
{torch.exp(elbo / num_elements).item():.2f}") 150 | 151 | 152 | def setup(rank: int, world_size: int, port: int) -> None: 153 | os.environ["MASTER_ADDR"] = "localhost" 154 | os.environ["MASTER_PORT"] = str(port) 155 | 156 | torch.cuda.set_device(rank) 157 | 158 | timeout = datetime.timedelta(minutes=30) 159 | dist.init_process_group("nccl", rank=rank, world_size=world_size, timeout=timeout) 160 | 161 | 162 | def cleanup() -> None: 163 | dist.destroy_process_group() 164 | 165 | 166 | def run_mp_eval( 167 | rank: int, 168 | world_size: int, 169 | seed: int, 170 | work_dir: str, 171 | batch_size: int, 172 | sampling_steps: int, 173 | eval_elbo: bool, 174 | eval_perplexity: bool, 175 | elbo_data: str, 176 | perplexity_n_samples: int, 177 | port: int, 178 | ) -> None: 179 | try: 180 | setup(rank=rank, world_size=world_size, port=port) 181 | run_eval( 182 | rank=rank, 183 | seed=seed, 184 | work_dir=work_dir, 185 | batch_size=batch_size, 186 | sampling_steps=sampling_steps, 187 | eval_elbo=eval_elbo, 188 | eval_perplexity=eval_perplexity, 189 | elbo_data=elbo_data, 190 | world_size=world_size, 191 | perplexity_n_samples=perplexity_n_samples, 192 | ) 193 | finally: 194 | cleanup() 195 | -------------------------------------------------------------------------------- /flow_matching/path/scheduler/scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from abc import ABC, abstractmethod 8 | from dataclasses import dataclass, field 9 | 10 | from typing import Union 11 | 12 | import torch 13 | 14 | from torch import Tensor 15 | 16 | 17 | @dataclass 18 | class SchedulerOutput: 19 | r"""Represents a sample of a conditional-flow generated probability path. 
@dataclass
class SchedulerOutput:
    r"""Scheduler coefficients and their time derivatives at given times.

    Attributes:
        alpha_t (Tensor): :math:`\alpha_t`, shape (...).
        sigma_t (Tensor): :math:`\sigma_t`, shape (...).
        d_alpha_t (Tensor): :math:`\frac{\partial}{\partial t}\alpha_t`, shape (...).
        d_sigma_t (Tensor): :math:`\frac{\partial}{\partial t}\sigma_t`, shape (...).

    """

    alpha_t: Tensor = field(metadata={"help": "alpha_t"})
    sigma_t: Tensor = field(metadata={"help": "sigma_t"})
    d_alpha_t: Tensor = field(metadata={"help": "Derivative of alpha_t."})
    d_sigma_t: Tensor = field(metadata={"help": "Derivative of sigma_t."})


class Scheduler(ABC):
    """Abstract interface shared by all schedulers."""

    @abstractmethod
    def __call__(self, t: Tensor) -> SchedulerOutput:
        r"""
        Evaluate the scheduler at times ``t``.

        Args:
            t (Tensor): times in [0,1], shape (...).

        Returns:
            SchedulerOutput: :math:`\alpha_t,\sigma_t,\frac{\partial}{\partial t}\alpha_t,\frac{\partial}{\partial t}\sigma_t`
        """
        ...

    @abstractmethod
    def snr_inverse(self, snr: Tensor) -> Tensor:
        r"""
        Recover :math:`t` from the signal-to-noise ratio :math:`\frac{\alpha_t}{\sigma_t}`.

        Args:
            snr (Tensor): The signal-to-noise, shape (...)

        Returns:
            Tensor: t, shape (...)
        """
        ...


class ConvexScheduler(Scheduler):
    """Scheduler for convex paths, inverted through :math:`\\kappa_t`."""

    @abstractmethod
    def __call__(self, t: Tensor) -> SchedulerOutput:
        r"""Evaluate the convex-path scheduler at times ``t``.

        Args:
            t (Tensor): times in [0,1], shape (...).

        Returns:
            SchedulerOutput: :math:`\alpha_t,\sigma_t,\frac{\partial}{\partial t}\alpha_t,\frac{\partial}{\partial t}\sigma_t`
        """
        ...

    @abstractmethod
    def kappa_inverse(self, kappa: Tensor) -> Tensor:
        r"""
        Recover :math:`t` from :math:`\kappa_t`.

        Args:
            kappa (Tensor): :math:`\kappa`, shape (...)

        Returns:
            Tensor: t, shape (...)
        """
        ...

    def snr_inverse(self, snr: Tensor) -> Tensor:
        r"""
        Recover :math:`t` from the signal-to-noise ratio :math:`\frac{\alpha_t}{\sigma_t}`.

        The ratio is first converted to :math:`\kappa_t = \frac{snr}{1 + snr}`
        and then inverted with :meth:`kappa_inverse`.

        Args:
            snr (Tensor): The signal-to-noise, shape (...)

        Returns:
            Tensor: t, shape (...)
        """
        return self.kappa_inverse(kappa=snr / (1.0 + snr))


class CondOTScheduler(ConvexScheduler):
    """CondOT Scheduler: :math:`\\alpha_t = t`, :math:`\\sigma_t = 1 - t`."""

    def __call__(self, t: Tensor) -> SchedulerOutput:
        ones = torch.ones_like(t)
        return SchedulerOutput(
            alpha_t=t,
            sigma_t=1 - t,
            d_alpha_t=ones,
            d_sigma_t=-ones,
        )

    def kappa_inverse(self, kappa: Tensor) -> Tensor:
        # kappa_t = t for this scheduler, so the inverse is the identity.
        return kappa


class PolynomialConvexScheduler(ConvexScheduler):
    """Polynomial Scheduler: :math:`\\alpha_t = t^n`, :math:`\\sigma_t = 1 - t^n`."""

    def __init__(self, n: Union[float, int]) -> None:
        assert isinstance(
            n, (float, int)
        ), f"`n` must be a float or int. Got {type(n)=}."
        assert n > 0, f"`n` must be positive. Got {n=}."

        self.n = n

    def __call__(self, t: Tensor) -> SchedulerOutput:
        t_pow_n = t**self.n
        t_pow_n_minus_1 = t ** (self.n - 1)
        return SchedulerOutput(
            alpha_t=t_pow_n,
            sigma_t=1 - t_pow_n,
            d_alpha_t=self.n * t_pow_n_minus_1,
            d_sigma_t=-self.n * t_pow_n_minus_1,
        )

    def kappa_inverse(self, kappa: Tensor) -> Tensor:
        return torch.pow(kappa, 1.0 / self.n)


class VPScheduler(Scheduler):
    """Variance Preserving Scheduler."""

    def __init__(self, beta_min: float = 0.1, beta_max: float = 20.0) -> None:
        self.beta_min = beta_min
        self.beta_max = beta_max
        super().__init__()

    def __call__(self, t: Tensor) -> SchedulerOutput:
        b, B = self.beta_min, self.beta_max
        T = 0.5 * (1 - t) ** 2 * (B - b) + (1 - t) * b
        dT = -(1 - t) * (B - b) - b

        exp_neg_T = torch.exp(-T)
        alpha = torch.exp(-0.5 * T)
        sigma = torch.sqrt(1 - exp_neg_T)

        return SchedulerOutput(
            alpha_t=alpha,
            sigma_t=sigma,
            d_alpha_t=-0.5 * dT * alpha,
            d_sigma_t=0.5 * dT * exp_neg_T / sigma,
        )

    def snr_inverse(self, snr: Tensor) -> Tensor:
        b, B = self.beta_min, self.beta_max
        # alpha_t^2 = snr^2 / (snr^2 + 1) = exp(-T), then solve the quadratic
        # T(1 - t) for 1 - t.
        T = -torch.log(snr**2 / (snr**2 + 1))
        one_minus_t = (-b + torch.sqrt(b**2 + 2 * (B - b) * T)) / (B - b)
        return 1 - one_minus_t


class LinearVPScheduler(Scheduler):
    """Linear Variance Preserving Scheduler: :math:`\\alpha_t = t`, :math:`\\sigma_t = \\sqrt{1 - t^2}`."""

    def __call__(self, t: Tensor) -> SchedulerOutput:
        sigma = (1 - t**2) ** 0.5
        return SchedulerOutput(
            alpha_t=t,
            sigma_t=sigma,
            d_alpha_t=torch.ones_like(t),
            d_sigma_t=-t / sigma,
        )

    def snr_inverse(self, snr: Tensor) -> Tensor:
        snr_sq = snr**2
        return torch.sqrt(snr_sq / (1 + snr_sq))


class CosineScheduler(Scheduler):
    """Cosine Scheduler: :math:`\\alpha_t = \\sin(\\pi t / 2)`, :math:`\\sigma_t = \\cos(\\pi t / 2)`."""

    def __call__(self, t: Tensor) -> SchedulerOutput:
        half_pi = torch.pi / 2
        angle = half_pi * t
        sin_angle = torch.sin(angle)
        cos_angle = torch.cos(angle)
        return SchedulerOutput(
            alpha_t=sin_angle,
            sigma_t=cos_angle,
            d_alpha_t=half_pi * cos_angle,
            d_sigma_t=-half_pi * sin_angle,
        )

    def snr_inverse(self, snr: Tensor) -> Tensor:
        return 2.0 * torch.atan(snr) / torch.pi
def parse_args():
    """Parse the command line: all training arguments plus submitit/SLURM ones.

    The training parser from ``train.get_args_parser()`` is used as a parent,
    so every training flag is accepted alongside the cluster options below.
    """
    trainer_parser = train.get_args_parser()
    parser = argparse.ArgumentParser(
        "Submitit for flow_matching training", parents=[trainer_parser]
    )
    parser.add_argument(
        "--ngpus", default=8, type=int, help="Number of gpus to request on each node"
    )
    parser.add_argument(
        "--nodes", default=8, type=int, help="Number of nodes to request"
    )
    parser.add_argument("--timeout", default=4320, type=int, help="Duration of the job")
    parser.add_argument(
        "--job_dir", default="", type=str, help="Job dir. Leave empty for automatic."
    )
    parser.add_argument(
        "--shared_dir",
        default="/checkpoint",
        type=str,
        help="Directory shared among the nodes. A directory named USER/experiments is created under shared_dir that is used to coordinate in distributed mode.",
    )

    parser.add_argument(
        "--partition", default="learnlab", type=str, help="Partition where to submit"
    )
    parser.add_argument(
        "--constraint",
        default="",
        type=str,
        help="Slurm constraint eg.: ampere80gb For using A100s or volta32gb for using V100s.",
    )
    parser.add_argument(
        "--comment", default="", type=str, help="Comment to pass to scheduler"
    )
    parser.add_argument("--qos", default="", type=str, help="Slurm QOS")
    parser.add_argument("--account", default="", type=str, help="Slurm account")
    parser.add_argument(
        "--exclude",
        default="",
        type=str,
        help="Exclude certain nodes from the slurm job.",
    )
    return parser.parse_args()


def get_shared_folder(shared_dir: str) -> Path:
    """Return ``<shared_dir>/<user>/experiments``, creating it if needed.

    Args:
        shared_dir: An existing directory shared among the nodes.

    Returns:
        The per-user experiments directory.

    Raises:
        RuntimeError: If ``shared_dir`` is not an existing directory.
    """
    import getpass

    if not Path(shared_dir).is_dir():
        raise RuntimeError("No shared folder available")

    # $USER can be unset (e.g. in containers); fall back to the login name
    # instead of failing on ``Path / None``.
    user = os.getenv("USER") or getpass.getuser()
    folder = Path(shared_dir) / user / "experiments"
    # parents=True: the intermediate <user> directory may not exist yet.
    folder.mkdir(parents=True, exist_ok=True)
    return folder


def get_init_file(shared_dir: str):
    """Return a fresh, not-yet-existing init-file path in the shared folder.

    The file itself must not exist, but its parent directory must;
    ``get_shared_folder`` takes care of the latter.
    """
    folder = get_shared_folder(shared_dir)
    init_file = folder / f"{uuid.uuid4().hex}_init"
    if init_file.exists():
        os.remove(str(init_file))
    return init_file
class Trainer(object):
    """Callable job object: runs ``train.main`` with the stored arguments."""

    def __init__(self, args):
        self.args = args

    def __call__(self):
        import train

        self._setup_gpu_args()
        train.main(self.args)

    def checkpoint(self):
        """Build the resubmission used when the job is requeued.

        Picks a fresh init file for the distributed rendezvous and, unless
        running in eval-only mode, resumes from ``checkpoint.pth`` in the
        output directory when that file exists.

        Returns:
            submitit.helpers.DelayedSubmission: a new ``Trainer`` carrying the
            updated arguments.
        """
        import os

        import submitit

        self.args.dist_url = get_init_file(self.args.shared_dir).as_uri()
        checkpoint_file = os.path.join(self.args.output_dir, "checkpoint.pth")
        if os.path.exists(checkpoint_file) and not self.args.eval_only:
            self.args.resume = checkpoint_file
        # logging applies ``msg % args``: without a placeholder the extra
        # argument triggers a formatting error and the message is lost.
        logger.info("Requeuing %s", self.args)
        empty_trainer = type(self)(self.args)
        return submitit.helpers.DelayedSubmission(empty_trainer)

    def _setup_gpu_args(self):
        """Fill in the per-process arguments from the submitit job environment."""
        import submitit

        job_env = submitit.JobEnvironment()
        # "%j" in the output dir is a placeholder for the job id.
        self.args.output_dir = str(self.args.output_dir).replace(
            "%j", str(job_env.job_id)
        )
        self.args.log_dir = self.args.output_dir
        self.args.gpu = job_env.local_rank
        self.args.rank = job_env.global_rank
        self.args.world_size = job_env.num_tasks
        logger.info(
            f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}"
        )
def main():
    """Parse the arguments, configure the submitit executor and submit one job."""
    args = parse_args()
    if args.job_dir == "":
        args.job_dir = get_shared_folder(args.shared_dir) / "%j"

    # Note that the folder will depend on the job_id, to easily track experiments
    executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30)

    gpus_per_node = args.ngpus

    # Optional SLURM settings are only forwarded when they were given.
    optional_settings = {
        "slurm_constraint": args.constraint,
        "slurm_comment": args.comment,
        "slurm_qos": args.qos,
        "slurm_account": args.account,
    }
    slurm_extras = {key: value for key, value in optional_settings.items() if value}

    executor.update_parameters(
        mem_gb=40 * gpus_per_node,
        gpus_per_node=gpus_per_node,
        tasks_per_node=gpus_per_node,  # one task per GPU
        cpus_per_task=10,
        nodes=args.nodes,
        timeout_min=args.timeout,  # max is 60 * 72
        # Below are cluster dependent parameters
        slurm_partition=args.partition,
        slurm_signal_delay_s=120,
        slurm_exclude=args.exclude,
        **slurm_extras,
    )

    executor.update_parameters(name="flow_matching")

    args.dist_url = get_init_file(args.shared_dir).as_uri()
    args.output_dir = args.job_dir

    job = executor.submit(Trainer(args))

    logger.info(f"Submitted job {job.job_id}")


if __name__ == "__main__":
    logging.basicConfig(
        level=logging.INFO,
        stream=sys.stdout,
        format="%(asctime)s %(levelname)-8s %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    main()
# Part of this implementation is adapted from https://github.com/louaaron/Score-Entropy-Discrete-Diffusion
# which is released under MIT license

from dataclasses import dataclass, field
from itertools import chain
from typing import Dict, Iterable, Optional, Tuple

from datasets import DatasetDict, load_dataset
from omegaconf import OmegaConf

from torch.utils.data import DataLoader
from transformers import GPT2TokenizerFast

from data.tokenizer import wt_detokenizer
from data.utils import cycle_loader, StatefulDistributedSampler


def _get_hf_dataset(
    name: str,
    mode: str,
    cache_dir: Optional[str] = None,
    block_size: int = 1024,
    num_proc: int = 8,
) -> DatasetDict:
    """Load one split of a Hugging Face dataset and turn it into token blocks.

    The text column is tokenized with the GPT-2 tokenizer, an EOS token is
    appended to every document, all documents of a batch are concatenated and
    the result is cut into chunks of exactly ``block_size`` tokens.

    Args:
        name: ``"wikitext103"``, ``"fineweb-edu"`` or any name accepted by
            ``load_dataset``.
        mode: Split to select from the loaded dataset (e.g. ``"train"``).
        cache_dir: Cache directory forwarded to ``load_dataset``.
        block_size: Number of tokens per returned sequence.
        num_proc: Number of processes used by both ``map`` passes.

    Returns:
        The chunked split, formatted to yield torch tensors.
        NOTE(review): the value is a single split selected with ``[mode]``,
        so the ``DatasetDict`` annotation looks too broad; confirm.
    """
    detokenizer = None

    if name == "wikitext103":
        data = load_dataset(
            "wikitext", name="wikitext-103-raw-v1", cache_dir=cache_dir
        )[mode]
        detokenizer = wt_detokenizer
    elif name == "fineweb-edu":
        data = load_dataset(
            "HuggingFaceFW/fineweb-edu", name="CC-MAIN-2024-10", cache_dir=cache_dir
        )[mode]
    else:
        data = load_dataset(name, cache_dir=cache_dir)[mode]

    # NOTE(review): the functions passed to `map` below run with
    # load_from_cache_file=True; the cache is presumably keyed on their code,
    # so editing them likely forces a full re-tokenization. Confirm against
    # the datasets caching documentation before changing them.

    # Wraps a per-string detokenizer into one that rewrites a list in place.
    def _apply_detokenizer(detokenizer):
        def detok(text):
            for i, t in enumerate(text, 0):
                text[i] = detokenizer(t)
            return text

        return detok

    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
    EOS = tokenizer.encode(tokenizer.eos_token)[0]

    # Batched map function: detokenize (if configured), tokenize, append EOS.
    def preprocess_and_tokenize(example: Dict):
        text = example["text"]

        if detokenizer is not None:
            text = _apply_detokenizer(detokenizer)(text)

        tokens = tokenizer(text, return_attention_mask=False)
        # add in EOS token following
        # https://github.com/jcpeterson/openwebtext/blob/master/tokenize_text.py#L67
        for token in tokens["input_ids"]:
            token.append(EOS)

        return tokens

    tokenized_dataset = data.map(
        preprocess_and_tokenize,
        batched=True,
        num_proc=num_proc,
        load_from_cache_file=True,
    )

    # Only the token ids are kept; fineweb-edu carries several extra columns,
    # the other datasets are expected to carry just "text".
    if name == "fineweb-edu":
        features = tokenized_dataset.features.keys()
        for k in features:
            if k != "input_ids":
                tokenized_dataset = tokenized_dataset.remove_columns(k)
    else:
        tokenized_dataset = tokenized_dataset.remove_columns("text")

    # Batched map function: concatenate a batch and re-split it into blocks.
    def group_texts(examples: Dict):
        # Concatenate all texts.
        concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        # We drop the small remainder; if total_length < block_size the batch
        # contributes no rows at all (every column becomes an empty list).
        # We could add padding if the model supported it instead of this drop, you can customize this part to your needs.
        total_length = (total_length // block_size) * block_size
        # Split by chunks of max_len.
        result = {
            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }

        return result

    chunked_dataset = tokenized_dataset.map(
        group_texts, batched=True, num_proc=num_proc, load_from_cache_file=True
    )
    chunked_dataset = chunked_dataset.with_format("torch")

    return chunked_dataset


@dataclass
class Dataset:
    """A processed dataset together with the sampler that iterates over it."""

    dataset: DatasetDict = field(metadata={"help": "Huggingface dataset"})
    sampler: StatefulDistributedSampler = field(
        metadata={"help": "Stateful sampler for `dataset`"}
    )


@dataclass
class DataState:
    """The train and test `Dataset` pair used by a run."""

    train: Dataset = field(metadata={"help": "Train dataset"})
    test: Dataset = field(metadata={"help": "Test dataset"})


def _get_dataset(
    name: str,
    mode: str,
    cache_dir: str,
    block_size: int,
    num_proc: int,
    batch_size: int,
    ngpus: int,
) -> Dataset:
    """Build the chunked dataset for one split and attach a stateful sampler.

    Raises:
        AssertionError: If ``batch_size`` is not a multiple of ``ngpus``.
    """
    assert (
        batch_size % ngpus == 0
    ), f"{mode} batch size must be divisible by number of gpus."

    dataset = _get_hf_dataset(
        name=name,
        mode=mode,
        cache_dir=cache_dir,
        block_size=block_size,
        num_proc=num_proc,
    )

    sampler = StatefulDistributedSampler(dataset=dataset)

    return Dataset(dataset=dataset, sampler=sampler)


def get_data_state(config: OmegaConf) -> DataState:
    """Create the train and validation datasets described by ``config``.

    Reads ``config.data`` (``train``, ``valid``, ``cache_dir``,
    ``num_workers``), ``config.model.length``, the training and evaluation
    batch sizes and ``config.compute.ngpus``.
    """
    train = _get_dataset(
        name=config.data.train,
        mode="train",
        cache_dir=config.data.cache_dir,
        block_size=config.model.length,
        num_proc=config.data.num_workers,
        batch_size=config.training.batch_size,
        ngpus=config.compute.ngpus,
    )
    test = _get_dataset(
        name=config.data.valid,
        mode="validation",
        cache_dir=config.data.cache_dir,
        block_size=config.model.length,
        num_proc=config.data.num_workers,
        batch_size=config.eval.batch_size,
        ngpus=config.compute.ngpus,
    )

    return DataState(train=train, test=test)


def get_data_loaders(
    config: OmegaConf,
    data_state: DataState,
) -> Tuple[Iterable, Iterable]:
    """Return iterators over the training and validation loaders.

    Each loader uses the per-GPU batch size (global batch size divided by
    ``config.compute.ngpus``) and is wrapped in ``cycle_loader``.

    Args:
        config: Run configuration; see `get_data_state` for the keys used.
        data_state: Datasets and samplers to iterate over.

    Returns:
        ``(train_iterator, validation_iterator)``.
    """
    num_workers = config.data.num_workers

    train_loader = cycle_loader(
        DataLoader(
            data_state.train.dataset,
            batch_size=config.training.batch_size // config.compute.ngpus,
            sampler=data_state.train.sampler,
            num_workers=num_workers,
            pin_memory=True,
            shuffle=(data_state.train.sampler is None),
            # DataLoader rejects persistent workers when there are no worker
            # processes, so only request them when workers are configured.
            persistent_workers=num_workers > 0,
        )
    )

    valid_loader = cycle_loader(
        DataLoader(
            data_state.test.dataset,
            batch_size=config.eval.batch_size // config.compute.ngpus,
            sampler=data_state.test.sampler,
            num_workers=num_workers,
            pin_memory=True,
            shuffle=(data_state.test.sampler is None),
        )
    )

    return iter(train_loader), iter(valid_loader)

# --------------------------------------------------------------------------------
# /tests/path/test_path.py:
# --------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.
import math
import unittest

import torch
from flow_matching.path import (
    AffineProbPath,
    CondOTProbPath,
    GeodesicProbPath,
    MixtureDiscreteProbPath,
)
from flow_matching.path.scheduler import CondOTScheduler
from flow_matching.utils.manifolds import FlatTorus, Sphere


class TestAffineProbPath(unittest.TestCase):
    """Shape, pass-through and round-trip checks for the affine paths."""

    # NOTE(review): several tests below draw `t` with torch.randn, so the
    # times are not confined to [0, 1]; confirm that this is intentional.

    def test_affine_prob_path_sample(self):
        scheduler = CondOTScheduler()
        affine_prob_path = AffineProbPath(scheduler)
        x_0 = torch.randn(10, 5)
        x_1 = torch.randn(10, 5)
        t = torch.randn(10)
        sample = affine_prob_path.sample(x_0, x_1, t)
        # The sample keeps the data shape and echoes its inputs back.
        self.assertEqual(sample.x_t.shape, x_0.shape)
        self.assertEqual(sample.dx_t.shape, x_0.shape)
        self.assertTrue((sample.t == t).all())
        self.assertTrue((sample.x_0 == x_0).all())
        self.assertTrue((sample.x_1 == x_1).all())

    def test_assert_sample_shape(self):
        scheduler = CondOTScheduler()
        path = AffineProbPath(scheduler)
        # Matching batch sizes must be accepted.
        x_0 = torch.randn(10, 5)
        x_1 = torch.randn(10, 5)
        t = torch.randn(10)
        path.assert_sample_shape(x_0, x_1, t)

        # A time tensor with a different batch size must be rejected.
        x_0 = torch.randn(10, 5)
        x_1 = torch.randn(10, 5)
        t = torch.randn(5)
        with self.assertRaises(AssertionError):
            path.assert_sample_shape(x_0, x_1, t)

    def test_cond_ot_prob_path_sample(self):
        # CondOTProbPath must agree with AffineProbPath + CondOTScheduler.
        cond_ot_prob_path = CondOTProbPath()
        scheduler = CondOTScheduler()
        affine_path = AffineProbPath(scheduler)
        x_0 = torch.randn(10, 5)
        x_1 = torch.randn(10, 5)
        t = torch.randn(10)
        sample1 = cond_ot_prob_path.sample(x_0, x_1, t)
        sample2 = affine_path.sample(x_0, x_1, t)
        self.assertTrue(torch.allclose(sample1.x_t, sample2.x_t))

    def test_to_velocity(self):
        # target -> velocity -> target must be the identity.
        path = CondOTProbPath()
        x_1 = torch.randn(10, 5, dtype=torch.float64)
        x_t = torch.randn(10, 5, dtype=torch.float64)
        t = torch.randn(10, 5, dtype=torch.float64)
        velocity = path.target_to_velocity(x_1, x_t, t)
        target = path.velocity_to_target(velocity, x_t, t)
        self.assertTrue(torch.allclose(target, x_1))

    def test_to_epsilon(self):
        # target -> epsilon -> target must be the identity.
        path = CondOTProbPath()
        x_1 = torch.randn(10, 5, dtype=torch.float64)
        x_t = torch.randn(10, 5, dtype=torch.float64)
        t = torch.randn(10, 5, dtype=torch.float64)
        epsilon = path.target_to_epsilon(x_1, x_t, t)
        target = path.epsilon_to_target(epsilon, x_t, t)
        self.assertTrue(torch.allclose(target, x_1))

    def test_epsilson_velocity(self):
        # velocity -> epsilon -> velocity must be the identity.
        path = CondOTProbPath()
        velocity = torch.randn(10, 5, dtype=torch.float64)
        x_t = torch.randn(10, 5, dtype=torch.float64)
        t = torch.randn(10, 5, dtype=torch.float64)

        epsilon = path.velocity_to_epsilon(velocity, x_t, t)
        v = path.epsilon_to_velocity(epsilon, x_t, t)
        self.assertTrue(torch.allclose(v, velocity))


class TestGeodesicProbPath(unittest.TestCase):
    """Checks that geodesic path samples stay on the manifold."""

    def test_sphere(self):
        manifold = Sphere()
        path = GeodesicProbPath(manifold=manifold, scheduler=CondOTScheduler())

        # Maps samples onto the manifold through the exponential map taken
        # at the point (0, ..., 0, 1).
        def wrap(samples):
            center = torch.cat(
                [torch.zeros_like(samples), torch.ones_like(samples[..., 0:1])], dim=-1
            )
            samples = (
                torch.cat([samples, torch.zeros_like(samples[..., 0:1])], dim=-1) / 2
            )
            return manifold.expmap(center, samples)

        x1 = manifold.projx(torch.rand(5, 5, dtype=torch.float64))
        x0 = torch.randn_like(x1)
        x0 = wrap(x0)
        x1 = wrap(x1)
        t = torch.rand(x0.size(0), dtype=torch.float64)

        sample = path.sample(t=t, x_0=x0, x_1=x1)

        # Check that x_t is on the sphere
        self.assertTrue(
            torch.allclose(
                sample.x_t.norm(2, -1), torch.ones(x0.size(0), dtype=torch.float64)
            )
        )

    def test_torus(self):
        manifold = FlatTorus()
        path = GeodesicProbPath(manifold=manifold, scheduler=CondOTScheduler())

        # Maps samples onto the manifold through the exponential map taken
        # at the origin.
        def wrap(samples):
            center = torch.zeros_like(samples)
            return manifold.expmap(center, samples)

        batch_size = 5
        coord1 = torch.rand(batch_size, dtype=torch.float64) * 4 - 2
        coord2_ = (
            torch.rand(batch_size, dtype=torch.float64)
            - torch.randint(high=2, size=(batch_size,), dtype=torch.float64) * 2
        )
        coord2 = coord2_ + (torch.floor(coord1) % 2)

        x1 = torch.stack([coord1, coord2], dim=1)
        x0 = torch.randn_like(x1)
        x0 = wrap(x0)
        x1 = wrap(x1)
        t = torch.rand(x0.size(0), dtype=torch.float64)

        sample = path.sample(t=t, x_0=x0, x_1=x1)

        # Every coordinate of the sample must stay below 2 * pi.
        self.assertTrue((sample.x_t < 2 * math.pi).all())


class TestMixtureDiscreteProbPath(unittest.TestCase):
    """Checks for sampling and posterior conversion of the mixture path."""

    def test_mixture_discrete_prob_path_sample(self):
        scheduler = CondOTScheduler()
        discrete_prob_path = MixtureDiscreteProbPath(scheduler)
        x_0 = torch.randn(10, 5)
        x_1 = torch.randn(10, 5)
        t = torch.randn(10)
        sample = discrete_prob_path.sample(x_0, x_1, t)
        self.assertEqual(sample.x_t.shape, x_0.shape)
        self.assertTrue((sample.t == t).all())
        self.assertTrue((sample.x_0 == x_0).all())
        self.assertTrue((sample.x_1 == x_1).all())

        # Test at t=0
        t = torch.zeros(10)
        sample = discrete_prob_path.sample(x_0, x_1, t)
        self.assertTrue(torch.allclose(sample.x_t, x_0))
        # Test at t=1
        t = torch.ones(10)
        sample = discrete_prob_path.sample(x_0, x_1, t)
        self.assertTrue(torch.allclose(sample.x_t, x_1))

    def test_posterior_to_velocity(self):
        scheduler = CondOTScheduler()
        discrete_prob_path = MixtureDiscreteProbPath(scheduler)
        posterior_logits = torch.randn(10, 5)
        x_t = torch.randint(0, 5, size=[10])
        t = torch.randn(10)
        x_t_one_hot = torch.nn.functional.one_hot(x_t, num_classes=5)
        velocity = discrete_prob_path.posterior_to_velocity(posterior_logits, x_t, t)
        # Reference value: (softmax(logits) - one_hot(x_t)) / (1 - t).
        expected_velocity = (torch.softmax(posterior_logits, dim=-1) - x_t_one_hot) / (
            1 - t
        ).unsqueeze(-1)
        self.assertTrue(torch.allclose(velocity, expected_velocity))


if __name__ == "__main__":
    unittest.main()

# --------------------------------------------------------------------------------
# /examples/image/train_arg_parser.py:
# --------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import json
import logging

from models.model_configs import MODEL_CONFIGS
from torchdiffeq._impl.odeint import SOLVERS

logger = logging.getLogger(__name__)


def get_args_parser():
    """Build the argument parser for the image training example.

    The parser is created with ``add_help=False``, so it does not register
    its own ``-h/--help`` option.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    # The dataset names are the keys of MODEL_CONFIGS; the first one is the
    # default.
    dataset_names = list(MODEL_CONFIGS.keys())

    parser = argparse.ArgumentParser("Image dataset training", add_help=False)
    parser.add_argument(
        "--batch_size",
        default=32,
        type=int,
        help="Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)",
    )
    parser.add_argument("--epochs", default=921, type=int)
    parser.add_argument(
        "--accum_iter",
        default=1,
        type=int,
        help="Accumulate gradient iterations (for increasing the effective batch size under memory constraints)",
    )

    # Optimizer parameters
    parser.add_argument(
        "--lr",
        type=float,
        default=0.0001,
        help="learning rate (absolute lr)",
    )
    parser.add_argument(
        "--optimizer_betas",
        nargs="+",
        type=float,
        default=[0.9, 0.95],
        help="Beta coefficients passed to the optimizer.",
    )
    parser.add_argument(
        "--decay_lr",
        action="store_true",
        help="Adds a linear decay to the lr during training.",
    )
    parser.add_argument(
        "--class_drop_prob",
        type=float,
        default=0.2,
        help="Probability to drop conditioning during training",
    )
    parser.add_argument(
        "--skewed_timesteps",
        action="store_true",
        help="Use skewed timestep sampling proposed in the EDM paper: https://arxiv.org/abs/2206.00364.",
    )
    parser.add_argument(
        "--edm_schedule",
        action="store_true",
        help="Use the alternative time discretization during sampling proposed in the EDM paper: https://arxiv.org/abs/2206.00364.",
    )
    parser.add_argument(
        "--use_ema",
        action="store_true",
        help="When evaluating, use the model Exponential Moving Average weights.",
    )

    # Dataset parameters
    parser.add_argument(
        "--dataset",
        default=dataset_names[0],
        type=str,
        choices=dataset_names,
        help="Dataset to use.",
    )
    parser.add_argument(
        "--data_path",
        default="./data/image_generation",
        type=str,
        help="imagenet root folder with train, val and test subfolders",
    )

    parser.add_argument(
        "--output_dir",
        default="./output_dir",
        help="path where to save, empty for no saving",
    )
    parser.add_argument(
        "--ode_method",
        default="midpoint",
        choices=list(SOLVERS.keys()) + ["edm_heun"],
        help="ODE solver used to generate samples.",
    )
    # The value (including the string default) is parsed as JSON.
    parser.add_argument(
        "--ode_options",
        default='{"step_size": 0.01}',
        type=json.loads,
        help="ODE solver options. Eg. the midpoint solver requires step-size, dopri5 has no options to set.",
    )
    parser.add_argument(
        "--sym",
        default=0.0,
        type=float,
        help="Symmetric term for sampling the discrete flow.",
    )
    parser.add_argument(
        "--temp",
        default=1.0,
        type=float,
        help="Temperature for sampling the discrete flow.",
    )
    parser.add_argument(
        "--sym_func",
        action="store_true",
        help="Use a fixed function for the symmetric term in the discrete flow.",
    )
    parser.add_argument(
        "--sampling_dtype",
        default="float32",
        choices=["float32", "float64"],
        help="Solver dtype for sampling the discrete flow.",
    )
    parser.add_argument(
        "--cfg_scale",
        default=0.2,
        type=float,
        help="Classifier-free guidance scale for generating samples.",
    )
    parser.add_argument(
        "--fid_samples",
        default=50000,
        type=int,
        help="number of synthetic samples for FID evaluations",
    )
    parser.add_argument(
        "--device", default="cuda", help="device to use for training / testing"
    )
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--resume", default="", help="resume from checkpoint")

    parser.add_argument(
        "--start_epoch",
        default=0,
        type=int,
        metavar="N",
        help="start epoch (used when resumed from checkpoint)",
    )
    parser.add_argument(
        "--eval_only", action="store_true", help="No training, only run evaluation"
    )
    parser.add_argument(
        "--eval_frequency",
        default=50,
        type=int,
        help="Frequency (in number of epochs) for running FID evaluation. -1 to never run evaluation.",
    )
    parser.add_argument(
        "--compute_fid",
        action="store_true",
        help="Whether to compute FID in the evaluation loop. When disabled, the evaluation loop still runs and saves snapshots, but skips the FID computation.",
    )
    parser.add_argument(
        "--save_fid_samples",
        action="store_true",
        help="Save all samples generated for FID computation.",
    )
    parser.add_argument("--num_workers", default=10, type=int)
    # --pin_mem / --no_pin_mem share one destination that defaults to True.
    parser.add_argument(
        "--pin_mem",
        action="store_true",
        help="Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.",
    )
    parser.add_argument("--no_pin_mem", action="store_false", dest="pin_mem")
    parser.set_defaults(pin_mem=True)
    # distributed training parameters
    parser.add_argument(
        "--world_size", default=1, type=int, help="number of distributed processes"
    )
    parser.add_argument("--local_rank", default=-1, type=int)
    parser.add_argument("--dist_on_itp", action="store_true")
    parser.add_argument(
        "--dist_url", default="env://", help="url used to set up distributed training"
    )
    parser.add_argument(
        "--test_run",
        action="store_true",
        help="Only run one batch of training and evaluation.",
    )
    parser.add_argument(
        "--discrete_flow_matching",
        action="store_true",
        help="Train discrete flow matching model.",
    )
    parser.add_argument(
        "--discrete_fm_steps",
        default=1024,
        type=int,
        help="Number of sampling steps for discrete FM.",
    )

    return parser

# --------------------------------------------------------------------------------