├── .github ├── PULL_REQUEST_TEMPLATE.md ├── dependabot.yml └── workflows │ ├── code-quality-main.yaml │ ├── code-quality-pr.yaml │ ├── python-publish.yml │ ├── test.yaml │ └── test_runner.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── assets ├── 169_generated_samples_otcfm.gif ├── 169_generated_samples_otcfm.png ├── 8gaussians-to-moons.gif ├── DF_logo.png ├── DF_logo_dark.png └── gaussian-to-moons.gif ├── examples ├── 2D_tutorials │ ├── Flow_matching_tutorial.ipynb │ ├── Maximum_likelihood_CNF_tutorial.ipynb │ ├── SF2M_tutorial.ipynb │ ├── The_unreasonable_performance_of_minibatch_OT.ipynb │ ├── model-comparison-plotting.ipynb │ ├── preprocessing │ │ ├── 1.0-embryoid_body_data_to_h5ad.ipynb │ │ ├── 1.1-eb-data-1d-phate.ipynb │ │ ├── 2.1-split-cite-data.ipynb │ │ └── README.md │ └── tutorial_training_8_gaussians_to_moons.ipynb ├── images │ ├── cifar10 │ │ ├── README.md │ │ ├── compute_fid.py │ │ ├── train_cifar10.py │ │ ├── train_cifar10_ddp.py │ │ └── utils_cifar.py │ ├── conditional_mnist.ipynb │ └── mnist_example.ipynb ├── single_cell │ └── single-cell_example.ipynb └── tabular │ ├── README.md │ └── Tabular_Data_Generation_with_XGBoost_Conditional_Flow_Matching.ipynb ├── pyproject.toml ├── requirements.txt ├── runner-requirements.txt ├── runner ├── README.md ├── configs │ ├── callbacks │ │ ├── default.yaml │ │ ├── early_stopping.yaml │ │ ├── model_checkpoint.yaml │ │ ├── model_summary.yaml │ │ ├── no_stopping.yaml │ │ ├── none.yaml │ │ └── rich_progress_bar.yaml │ ├── datamodule │ │ ├── cifar.yaml │ │ ├── custom_dist.yaml │ │ ├── eb_full.yaml │ │ ├── funnel.yaml │ │ ├── gaussians.yaml │ │ ├── moons.yaml │ │ ├── scurve.yaml │ │ ├── sklearn.yaml │ │ ├── time_dist.yaml │ │ ├── torchdyn.yaml │ │ ├── tree.yaml │ │ └── twodim.yaml │ ├── debug │ │ ├── default.yaml │ │ ├── fdr.yaml │ │ ├── limit.yaml │ │ ├── overfit.yaml │ │ └── profiler.yaml │ ├── eval.yaml │ ├── experiment │ │ ├── cfm.yaml │ │ ├── cnf.yaml │ │ ├── icnn.yaml │ │ ├── image_cfm.yaml │ │ ├── image_fm.yaml │ │ ├── image_otcfm.yaml │ │ └── trajectorynet.yaml │ ├── extras │ │ └── default.yaml │ ├── hparams_search │ │ └── optuna.yaml │ ├── hydra │ │ └── default.yaml │ ├── launcher │ │ ├── mila_cluster.yaml │ │ └── mila_cpu_cluster.yaml │ ├── local │ │ ├── .gitkeep │ │ └── default.yaml │ ├── logger │ │ ├── comet.yaml │ │ ├── csv.yaml │ │ ├── many_loggers.yaml │ │ ├── mlflow.yaml │ │ ├── neptune.yaml │ │ ├── tensorboard.yaml │ │ └── wandb.yaml │ ├── model │ │ ├── cfm.yaml │ │ ├── cfm_v2.yaml │ │ ├── cnf.yaml │ │ ├── fm.yaml │ │ ├── icnn.yaml │ │ ├── image_cfm.yaml │ │ ├── otcfm.yaml │ │ ├── sbcfm.yaml │ │ └── trajectorynet.yaml │ ├── paths │ │ └── default.yaml │ ├── train.yaml │ └── trainer │ │ ├── cpu.yaml │ │ ├── ddp.yaml │ │ ├── ddp_sim.yaml │ │ ├── default.yaml │ │ ├── gpu.yaml │ │ └── mps.yaml ├── data │ └── .gitkeep ├── logs │ └── .gitkeep ├── scripts │ ├── schedule.sh │ └── two-dim-cfm.sh ├── src │ ├── __init__.py │ ├── datamodules │ │ ├── __init__.py │ │ ├── cifar10_datamodule.py │ │ ├── components │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── generators2d.py │ │ │ ├── time_dataset.py │ │ │ ├── tnet_dataset.py │ │ │ └── two_dim.py │ │ └── distribution_datamodule.py │ ├── eval.py │ ├── models │ │ ├── __init__.py │ │ ├── cfm_module.py │ │ ├── components │ │ │ ├── __init__.py │ │ │ ├── augmentation.py │ │ │ ├── base.py │ │ │ ├── distribution_distances.py │ │ │ ├── emd.py │ │ │ ├── evaluation.py │ │ │ ├── fp16_util.py │ │ │ ├── hyper_nets.py │ │ │ ├── icnn_model.py │ │ │ ├── layers │ │ │ │ ├── 
__init__.py │ │ │ │ ├── diffeq_layers │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── basic.py │ │ │ │ │ ├── container.py │ │ │ │ │ ├── resnet.py │ │ │ │ │ └── wrappers.py │ │ │ │ ├── odefunc.py │ │ │ │ └── squeeze.py │ │ │ ├── logger.py │ │ │ ├── mlpode.py │ │ │ ├── mmd.py │ │ │ ├── nn.py │ │ │ ├── optimal_transport.py │ │ │ ├── plotting.py │ │ │ ├── regularizers.py │ │ │ ├── schedule.py │ │ │ ├── simple_dense_net.py │ │ │ ├── simple_mlp.py │ │ │ ├── sinkhorn_knopp_unbalanced.py │ │ │ ├── solver.py │ │ │ ├── unet.py │ │ │ └── utils.py │ │ ├── icnn_module.py │ │ ├── runner.py │ │ └── utils.py │ ├── train.py │ └── utils │ │ ├── __init__.py │ │ ├── pylogger.py │ │ ├── rich_utils.py │ │ └── utils.py └── tests │ ├── __init__.py │ ├── conftest.py │ ├── helpers │ ├── __init__.py │ ├── package_available.py │ ├── run_if.py │ └── run_sh_command.py │ ├── test_configs.py │ ├── test_datamodule.py │ ├── test_eval.py │ ├── test_sweeps.py │ └── test_train.py ├── setup.py ├── tests ├── test_conditional_flow_matcher.py ├── test_models.py ├── test_optimal_transport.py └── test_time_t.py └── torchcfm ├── __init__.py ├── conditional_flow_matching.py ├── models ├── __init__.py ├── models.py └── unet │ ├── __init__.py │ ├── fp16_util.py │ ├── logger.py │ ├── nn.py │ └── unet.py ├── optimal_transport.py ├── utils.py └── version.py /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## What does this PR do? 2 | 3 | 9 | 10 | Fixes #\ 11 | 12 | ## Before submitting 13 | 14 | - [ ] Did you make sure **title is self-explanatory** and **the description concisely explains the PR**? 15 | - [ ] Did you make sure your **PR does only one thing**, instead of bundling different changes together? 16 | - [ ] Did you list all the **breaking changes** introduced by this pull request? 17 | - [ ] Did you **test your PR locally** with `pytest` command? 18 | - [ ] Did you **run pre-commit hooks** with `pre-commit run -a` command? 19 | 20 | ## Did you have fun? 21 | 22 | Make sure you had fun coding 🙃 23 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # To get started with Dependabot version updates, you'll need to specify which 2 | # package ecosystems to update and where the package manifests are located. 
3 | # Please see the documentation for all configuration options: 4 | # https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates 5 | 6 | version: 2 7 | updates: 8 | - package-ecosystem: "pip" # See documentation for possible values 9 | directory: "/" # Location of package manifests 10 | schedule: 11 | interval: "daily" 12 | ignore: 13 | - dependency-name: "pytorch-lightning" 14 | update-types: ["version-update:semver-patch"] 15 | - dependency-name: "torchmetrics" 16 | update-types: ["version-update:semver-patch"] 17 | -------------------------------------------------------------------------------- /.github/workflows/code-quality-main.yaml: -------------------------------------------------------------------------------- 1 | # Same as `code-quality-pr.yaml` but triggered on commit to main branch 2 | # and runs on all files (instead of only the changed ones) 3 | 4 | name: Code Quality Main 5 | 6 | on: 7 | push: 8 | branches: [main] 9 | 10 | jobs: 11 | code-quality: 12 | runs-on: ubuntu-latest 13 | 14 | steps: 15 | - name: Checkout 16 | uses: actions/checkout@v3 17 | 18 | - name: Set up Python 19 | uses: actions/setup-python@v4 20 | with: 21 | python-version: "3.10" 22 | 23 | - name: Run pre-commits 24 | uses: pre-commit/action@v2.0.3 25 | -------------------------------------------------------------------------------- /.github/workflows/code-quality-pr.yaml: -------------------------------------------------------------------------------- 1 | # This workflow finds which files were changed, prints them, 2 | # and runs `pre-commit` on those files. 3 | 4 | # Inspired by the sktime library: 5 | # https://github.com/alan-turing-institute/sktime/blob/main/.github/workflows/test.yml 6 | 7 | name: Code Quality PR 8 | 9 | env: 10 | SKLEARN_ALLOW_DEPRECATED_SKLEARN_PACKAGE_INSTALL: True 11 | 12 | on: 13 | pull_request: 14 | branches: [main, "release/*"] 15 | 16 | jobs: 17 | code-quality: 18 | runs-on: ubuntu-latest 19 | 20 | steps: 21 | - name: Checkout 22 | uses: actions/checkout@v3 23 | 24 | - name: Set up Python 25 | uses: actions/setup-python@v3 26 | with: 27 | python-version: "3.10" 28 | 29 | - name: Find modified files 30 | id: file_changes 31 | uses: trilom/file-changes-action@v1.2.4 32 | with: 33 | output: " " 34 | 35 | - name: List modified files 36 | run: echo '${{ steps.file_changes.outputs.files}}' 37 | 38 | - name: Run pre-commits 39 | uses: pre-commit/action@v2.0.3 40 | with: 41 | extra_args: --files ${{ steps.file_changes.outputs.files}} 42 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 
8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | permissions: 16 | contents: read 17 | 18 | jobs: 19 | deploy: 20 | runs-on: ubuntu-latest 21 | 22 | steps: 23 | - uses: actions/checkout@v3 24 | - name: Set up Python 25 | uses: actions/setup-python@v3 26 | with: 27 | python-version: "3.x" 28 | - name: Install dependencies 29 | run: | 30 | python -m pip install --upgrade pip 31 | pip install build 32 | - name: Build package 33 | run: python -m build 34 | - name: Publish package 35 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 36 | with: 37 | user: __token__ 38 | password: ${{ secrets.PYPI_API_TOKEN }} 39 | -------------------------------------------------------------------------------- /.github/workflows/test.yaml: -------------------------------------------------------------------------------- 1 | name: TorchCFM Tests 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main, "release/*"] 8 | 9 | jobs: 10 | run_tests_ubuntu: 11 | runs-on: ${{ matrix.os }} 12 | 13 | strategy: 14 | fail-fast: false 15 | matrix: 16 | os: [ubuntu-latest, ubuntu-20.04, macos-latest, windows-latest] 17 | python-version: ["3.9", "3.10", "3.11", "3.12"] 18 | 19 | steps: 20 | - name: Checkout 21 | uses: actions/checkout@v3 22 | 23 | - name: Set up Python ${{ matrix.python-version }} 24 | uses: actions/setup-python@v4 25 | with: 26 | python-version: ${{ matrix.python-version }} 27 | 28 | - name: Install dependencies 29 | run: | 30 | python -m pip install --upgrade pip 31 | pip install pytest 32 | pip install sh 33 | pip install -e . 34 | 35 | - name: List dependencies 36 | run: | 37 | python -m pip list 38 | 39 | - name: Run pytest 40 | run: | 41 | pytest -v --ignore=examples --ignore=runner 42 | 43 | # upload code coverage report 44 | code-coverage-torchcfm: 45 | runs-on: ubuntu-latest 46 | 47 | steps: 48 | - name: Checkout 49 | uses: actions/checkout@v3 50 | 51 | - name: Set up Python 3.10 52 | uses: actions/setup-python@v4 53 | with: 54 | python-version: "3.10" 55 | 56 | - name: Install dependencies 57 | run: | 58 | python -m pip install --upgrade pip 59 | pip install pytest 60 | pip install pytest-cov[toml] 61 | pip install sh 62 | pip install -e . 63 | 64 | - name: Run tests and collect coverage 65 | run: pytest . 
--cov torchcfm --ignore=runner --ignore=examples --ignore=torchcfm/models/ --cov-fail-under=30 66 | 67 | - name: Upload coverage to Codecov 68 | uses: codecov/codecov-action@v3 69 | with: 70 | name: codecov-torchcfm 71 | verbose: true 72 | -------------------------------------------------------------------------------- /.github/workflows/test_runner.yaml: -------------------------------------------------------------------------------- 1 | name: Runner Tests 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main, "release/*"] 8 | 9 | jobs: 10 | run_tests_ubuntu: 11 | runs-on: ${{ matrix.os }} 12 | 13 | strategy: 14 | fail-fast: false 15 | matrix: 16 | os: [ubuntu-latest, ubuntu-20.04, macos-latest, windows-latest] 17 | python-version: ["3.9", "3.10"] 18 | 19 | steps: 20 | - name: Checkout 21 | uses: actions/checkout@v3 22 | 23 | - name: Set up Python ${{ matrix.python-version }} 24 | uses: actions/setup-python@v4 25 | with: 26 | python-version: ${{ matrix.python-version }} 27 | 28 | - name: Install dependencies 29 | run: | 30 | # Fix pip version < 24.1 due to lightning incomaptibility 31 | python -m pip install pip==23.2.1 32 | pip install -r runner-requirements.txt 33 | pip install pytest 34 | pip install sh 35 | pip install -e . 36 | 37 | - name: List dependencies 38 | run: | 39 | python -m pip list 40 | 41 | - name: Run pytest 42 | run: | 43 | pytest -v runner 44 | 45 | # upload code coverage report 46 | code-coverage-runner: 47 | runs-on: ubuntu-latest 48 | 49 | steps: 50 | - name: Checkout 51 | uses: actions/checkout@v3 52 | 53 | - name: Set up Python 3.10 54 | uses: actions/setup-python@v4 55 | with: 56 | python-version: "3.10" 57 | 58 | - name: Install dependencies 59 | run: | 60 | # Fix pip version < 24.1 due to lightning incomaptibility 61 | python -m pip install pip==23.2.1 62 | pip install -r runner-requirements.txt 63 | pip install pytest 64 | pip install pytest-cov[toml] 65 | pip install sh 66 | pip install -e . 67 | 68 | - name: Run tests and collect coverage 69 | run: pytest runner --cov runner --cov-fail-under=30 # NEEDS TO BE UPDATED WHEN CHANGING THE NAME OF "src" FOLDER 70 | 71 | - name: Upload coverage to Codecov 72 | uses: codecov/codecov-action@v3 73 | with: 74 | name: codecov-runner 75 | verbose: true 76 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .venv 106 | env/ 107 | venv/ 108 | ENV/ 109 | env.bak/ 110 | venv.bak/ 111 | 112 | # Spyder project settings 113 | .spyderproject 114 | .spyproject 115 | 116 | # Rope project settings 117 | .ropeproject 118 | 119 | # mkdocs documentation 120 | /site 121 | 122 | # mypy 123 | .mypy_cache/ 124 | .dmypy.json 125 | dmypy.json 126 | 127 | # Pyre type checker 128 | .pyre/ 129 | 130 | ### VisualStudioCode 131 | .vscode/* 132 | !.vscode/settings.json 133 | !.vscode/tasks.json 134 | !.vscode/launch.json 135 | !.vscode/extensions.json 136 | *.code-workspace 137 | **/.vscode 138 | 139 | # JetBrains 140 | .idea/ 141 | 142 | # Lightning-Hydra-Template 143 | configs/local/default.yaml 144 | data/ 145 | logs/ 146 | wandb/ 147 | .env 148 | .autoenv 149 | 150 | #Vim 151 | *.sw? 
152 | 153 | # Slurm 154 | slurm*.out 155 | 156 | # Data and models 157 | *.pt 158 | *.h5 159 | *.h5ad 160 | *.tar 161 | *.tar.gz 162 | *.pkl 163 | *.npy 164 | *.npz 165 | *.csv 166 | 167 | # Images 168 | *.png 169 | *.svg 170 | *.gif 171 | *.jpg 172 | 173 | notebooks/figures/ 174 | 175 | .DS_Store 176 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3 3 | node: 16.14.2 4 | 5 | repos: 6 | - repo: https://github.com/pre-commit/pre-commit-hooks 7 | rev: v4.4.0 8 | hooks: 9 | # list of supported hooks: https://pre-commit.com/hooks.html 10 | - id: trailing-whitespace 11 | exclude: .svg$ 12 | require_serial: true 13 | - id: end-of-file-fixer 14 | require_serial: true 15 | - id: check-docstring-first 16 | require_serial: true 17 | - id: check-yaml 18 | require_serial: true 19 | - id: debug-statements 20 | require_serial: true 21 | - id: detect-private-key 22 | require_serial: true 23 | - id: check-executables-have-shebangs 24 | require_serial: true 25 | - id: check-toml 26 | require_serial: true 27 | - id: check-case-conflict 28 | require_serial: true 29 | - id: check-added-large-files 30 | require_serial: true 31 | 32 | # python code formatting 33 | - repo: https://github.com/psf/black 34 | rev: 23.7.0 35 | hooks: 36 | - id: black 37 | require_serial: true 38 | args: [--line-length, "99"] 39 | 40 | # python import sorting 41 | - repo: https://github.com/PyCQA/isort 42 | rev: 5.12.0 43 | hooks: 44 | - id: isort 45 | require_serial: true 46 | args: ["--profile", "black", "--filter-files"] 47 | 48 | # python upgrading syntax to newer version 49 | - repo: https://github.com/asottile/pyupgrade 50 | rev: v3.9.0 51 | hooks: 52 | - id: pyupgrade 53 | require_serial: true 54 | args: [--py38-plus] 55 | 56 | # python docstring formatting 57 | - repo: https://github.com/myint/docformatter 58 | rev: master 59 | hooks: 60 | - id: docformatter 61 | require_serial: true 62 | args: [--in-place, --wrap-summaries=99, --wrap-descriptions=99] 63 | 64 | # python check (PEP8), programming errors and code complexity 65 | - repo: https://github.com/PyCQA/flake8 66 | rev: 6.0.0 67 | hooks: 68 | - id: flake8 69 | require_serial: true 70 | entry: pflake8 71 | additional_dependencies: ["pyproject-flake8"] 72 | 73 | # python security linter 74 | - repo: https://github.com/PyCQA/bandit 75 | rev: "1.7.5" 76 | hooks: 77 | - id: bandit 78 | require_serial: true 79 | args: ["-c", "pyproject.toml"] 80 | additional_dependencies: ["bandit[toml]"] 81 | 82 | # yaml formatting 83 | - repo: https://github.com/pre-commit/mirrors-prettier 84 | rev: v3.0.0 85 | hooks: 86 | - id: prettier 87 | require_serial: true 88 | types: [yaml] 89 | 90 | # shell scripts linter 91 | - repo: https://github.com/shellcheck-py/shellcheck-py 92 | rev: v0.9.0.5 93 | hooks: 94 | - id: shellcheck 95 | require_serial: true 96 | args: ["-e", "SC2102"] 97 | 98 | # md formatting 99 | - repo: https://github.com/executablebooks/mdformat 100 | rev: 0.7.16 101 | hooks: 102 | - id: mdformat 103 | require_serial: true 104 | args: ["--number"] 105 | additional_dependencies: 106 | - mdformat-gfm 107 | - mdformat-tables 108 | - mdformat_frontmatter 109 | # - mdformat-toc 110 | # - mdformat-black 111 | 112 | # word spelling linter 113 | - repo: https://github.com/codespell-project/codespell 114 | rev: v2.2.5 115 | hooks: 116 | - id: codespell 117 | require_serial: true 118 | args: 119 | - 
--skip=logs/**,data/**,*.ipynb 120 | - --ignore-words-list=ot,hist 121 | 122 | # jupyter notebook linting 123 | - repo: https://github.com/nbQA-dev/nbQA 124 | rev: 1.7.0 125 | hooks: 126 | - id: nbqa-black 127 | args: ["--line-length=99"] 128 | require_serial: true 129 | - id: nbqa-isort 130 | args: ["--profile=black"] 131 | require_serial: true 132 | - id: nbqa-flake8 133 | args: 134 | [ 135 | "--extend-ignore=E203,E402,E501,F401,F841,F821,F403,F405,F811", 136 | "--exclude=logs/*,data/*,notebooks/*", 137 | ] 138 | require_serial: true 139 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Alexander Tong 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /assets/169_generated_samples_otcfm.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/atong01/conditional-flow-matching/3fd278f9ef2f02e17e107e5769130b6cb44803e2/assets/169_generated_samples_otcfm.gif -------------------------------------------------------------------------------- /assets/169_generated_samples_otcfm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/atong01/conditional-flow-matching/3fd278f9ef2f02e17e107e5769130b6cb44803e2/assets/169_generated_samples_otcfm.png -------------------------------------------------------------------------------- /assets/8gaussians-to-moons.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/atong01/conditional-flow-matching/3fd278f9ef2f02e17e107e5769130b6cb44803e2/assets/8gaussians-to-moons.gif -------------------------------------------------------------------------------- /assets/DF_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/atong01/conditional-flow-matching/3fd278f9ef2f02e17e107e5769130b6cb44803e2/assets/DF_logo.png -------------------------------------------------------------------------------- /assets/DF_logo_dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/atong01/conditional-flow-matching/3fd278f9ef2f02e17e107e5769130b6cb44803e2/assets/DF_logo_dark.png -------------------------------------------------------------------------------- /assets/gaussian-to-moons.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/atong01/conditional-flow-matching/3fd278f9ef2f02e17e107e5769130b6cb44803e2/assets/gaussian-to-moons.gif -------------------------------------------------------------------------------- /examples/2D_tutorials/preprocessing/README.md: -------------------------------------------------------------------------------- 1 | ### Data Preprocessing 2 | 3 | For the 4 | -------------------------------------------------------------------------------- /examples/images/cifar10/README.md: -------------------------------------------------------------------------------- 1 | # CIFAR-10 experiments using TorchCFM 2 | 3 | This repository is used to reproduce the CIFAR-10 experiments from [1](https://arxiv.org/abs/2302.00482). We have designed a novel experimental procedure that helps us to reach an __FID of 3.5__ on the Cifar10 dataset. 4 | 5 |

[image placeholder: generated CIFAR-10 samples, likely assets/169_generated_samples_otcfm.png] 6 | 7 | [image placeholder: sample-generation animation, likely assets/169_generated_samples_otcfm.gif]

8 | 9 | To reproduce the experiments and save the weights, install the requirements from the main repository and then run (runs on a single RTX 2080 GPU): 10 | 11 | - For the OT-Conditional Flow Matching method: 12 | 13 | ```bash 14 | python3 train_cifar10.py --model "otcfm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000 15 | ``` 16 | 17 | - For the Independent Conditional Flow Matching (I-CFM) method: 18 | 19 | ```bash 20 | python3 train_cifar10.py --model "icfm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000 21 | ``` 22 | 23 | - For the original Flow Matching method: 24 | 25 | ```bash 26 | python3 train_cifar10.py --model "fm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000 27 | ``` 28 | 29 | Note that you can train all our methods in parallel using multiple GPUs and DistributedDataParallel. You can do this by providing the number of GPUs, setting the parallel flag to True and providing the master address and port in the command line. Please refer to [the official document for the usage](https://pytorch.org/docs/stable/elastic/run.html#usage). As an example: 30 | 31 | ```bash 32 | torchrun --standalone --nnodes=1 --nproc_per_node=NUM_GPUS_YOU_HAVE train_cifar10_ddp.py --model "otcfm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000 --parallel True --master_addr "MASTER_ADDR" --master_port "MASTER_PORT" 33 | ``` 34 | 35 | To compute the FID from the OT-CFM model at end of training, run: 36 | 37 | ```bash 38 | python3 compute_fid.py --model "otcfm" --step 400000 --integration_method dopri5 39 | ``` 40 | 41 | For the other models, change the "otcfm" argument by "icfm" or "fm". For easy reproducibility of our results, you can download the model weights at 400000 iterations here: 42 | 43 | - [icfm weights](https://github.com/atong01/conditional-flow-matching/releases/download/1.0.4/cfm_cifar10_weights_step_400000.pt) 44 | 45 | - [otcfm weights](https://github.com/atong01/conditional-flow-matching/releases/download/1.0.4/otcfm_cifar10_weights_step_400000.pt) 46 | 47 | - [fm weights](https://github.com/atong01/conditional-flow-matching/releases/download/1.0.4/fm_cifar10_weights_step_400000.pt) 48 | 49 | To recompute the FID, change the PATH variable with where you have saved the downloaded weights. 50 | 51 | If you find this code useful in your research, please cite the following papers (expand for BibTeX): 52 | 53 |
54 | 55 | A. Tong, N. Malkin, G. Huguet, Y. Zhang, J. Rector-Brooks, K. Fatras, G. Wolf, Y. Bengio. Improving and Generalizing Flow-Based Generative Models with Minibatch Optimal Transport, 2023. 56 | 57 | 58 | ```bibtex 59 | @article{tong2023improving, 60 | title={Improving and Generalizing Flow-Based Generative Models with Minibatch Optimal Transport}, 61 | author={Tong, Alexander and Malkin, Nikolay and Huguet, Guillaume and Zhang, Yanlei and {Rector-Brooks}, Jarrid and Fatras, Kilian and Wolf, Guy and Bengio, Yoshua}, 62 | year={2023}, 63 | journal={arXiv preprint 2302.00482} 64 | } 65 | ``` 66 | 67 |
68 | 69 |
70 | 71 | A. Tong, N. Malkin, K. Fatras, L. Atanackovic, Y. Zhang, G. Huguet, G. Wolf, Y. Bengio. Simulation-Free Schrödinger Bridges via Score and Flow Matching, 2023. 72 | 73 | 74 | ```bibtex 75 | @article{tong2023simulation, 76 | title={Simulation-Free Schr{\"o}dinger Bridges via Score and Flow Matching}, 77 | author={Tong, Alexander and Malkin, Nikolay and Fatras, Kilian and Atanackovic, Lazar and Zhang, Yanlei and Huguet, Guillaume and Wolf, Guy and Bengio, Yoshua}, 78 | year={2023}, 79 | journal={arXiv preprint 2307.03672} 80 | } 81 | ``` 82 | -------------------------------------------------------------------------------- /examples/images/cifar10/compute_fid.py: -------------------------------------------------------------------------------- 1 | # Inspired from https://github.com/w86763777/pytorch-ddpm/tree/master. 2 | 3 | # Authors: Kilian Fatras 4 | # Alexander Tong 5 | 6 | import os 7 | import sys 8 | 9 | import matplotlib.pyplot as plt 10 | import torch 11 | from absl import app, flags 12 | from cleanfid import fid 13 | from torchdiffeq import odeint 14 | from torchdyn.core import NeuralODE 15 | 16 | from torchcfm.models.unet.unet import UNetModelWrapper 17 | 18 | FLAGS = flags.FLAGS 19 | # UNet 20 | flags.DEFINE_integer("num_channel", 128, help="base channel of UNet") 21 | 22 | # Training 23 | flags.DEFINE_string("input_dir", "./results", help="output_directory") 24 | flags.DEFINE_string("model", "otcfm", help="flow matching model type") 25 | flags.DEFINE_integer("integration_steps", 100, help="number of inference steps") 26 | flags.DEFINE_string("integration_method", "dopri5", help="integration method to use") 27 | flags.DEFINE_integer("step", 400000, help="training steps") 28 | flags.DEFINE_integer("num_gen", 50000, help="number of samples to generate") 29 | flags.DEFINE_float("tol", 1e-5, help="Integrator tolerance (absolute and relative)") 30 | flags.DEFINE_integer("batch_size_fid", 1024, help="Batch size to compute FID") 31 | 32 | FLAGS(sys.argv) 33 | 34 | 35 | # Define the model 36 | use_cuda = torch.cuda.is_available() 37 | device = torch.device("cuda:0" if use_cuda else "cpu") 38 | 39 | new_net = UNetModelWrapper( 40 | dim=(3, 32, 32), 41 | num_res_blocks=2, 42 | num_channels=FLAGS.num_channel, 43 | channel_mult=[1, 2, 2, 2], 44 | num_heads=4, 45 | num_head_channels=64, 46 | attention_resolutions="16", 47 | dropout=0.1, 48 | ).to(device) 49 | 50 | 51 | # Load the model 52 | PATH = f"{FLAGS.input_dir}/{FLAGS.model}/{FLAGS.model}_cifar10_weights_step_{FLAGS.step}.pt" 53 | print("path: ", PATH) 54 | checkpoint = torch.load(PATH, map_location=device) 55 | state_dict = checkpoint["ema_model"] 56 | try: 57 | new_net.load_state_dict(state_dict) 58 | except RuntimeError: 59 | from collections import OrderedDict 60 | 61 | new_state_dict = OrderedDict() 62 | for k, v in state_dict.items(): 63 | new_state_dict[k[7:]] = v 64 | new_net.load_state_dict(new_state_dict) 65 | new_net.eval() 66 | 67 | 68 | # Define the integration method if euler is used 69 | if FLAGS.integration_method == "euler": 70 | node = NeuralODE(new_net, solver=FLAGS.integration_method) 71 | 72 | 73 | def gen_1_img(unused_latent): 74 | with torch.no_grad(): 75 | x = torch.randn(FLAGS.batch_size_fid, 3, 32, 32, device=device) 76 | if FLAGS.integration_method == "euler": 77 | print("Use method: ", FLAGS.integration_method) 78 | t_span = torch.linspace(0, 1, FLAGS.integration_steps + 1, device=device) 79 | traj = node.trajectory(x, t_span=t_span) 80 | else: 81 | print("Use method: ", FLAGS.integration_method) 82 | 
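        # With an adaptive solver such as dopri5, only the two endpoints t=0 and t=1 need
        # to be requested: odeint chooses its own internal step sizes and returns the
        # solution at each requested time, and the final state is kept as the sample below.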
t_span = torch.linspace(0, 1, 2, device=device) 83 | traj = odeint( 84 | new_net, x, t_span, rtol=FLAGS.tol, atol=FLAGS.tol, method=FLAGS.integration_method 85 | ) 86 | traj = traj[-1, :] # .view([-1, 3, 32, 32]).clip(-1, 1) 87 | img = (traj * 127.5 + 128).clip(0, 255).to(torch.uint8) # .permute(1, 2, 0) 88 | return img 89 | 90 | 91 | print("Start computing FID") 92 | score = fid.compute_fid( 93 | gen=gen_1_img, 94 | dataset_name="cifar10", 95 | batch_size=FLAGS.batch_size_fid, 96 | dataset_res=32, 97 | num_gen=FLAGS.num_gen, 98 | dataset_split="train", 99 | mode="legacy_tensorflow", 100 | ) 101 | print() 102 | print("FID has been computed") 103 | # print() 104 | # print("Total NFE: ", new_net.nfe) 105 | print() 106 | print("FID: ", score) 107 | -------------------------------------------------------------------------------- /examples/images/cifar10/train_cifar10.py: -------------------------------------------------------------------------------- 1 | # Inspired from https://github.com/w86763777/pytorch-ddpm/tree/master. 2 | 3 | # Authors: Kilian Fatras 4 | # Alexander Tong 5 | 6 | import copy 7 | import os 8 | 9 | import torch 10 | from absl import app, flags 11 | from torchdyn.core import NeuralODE 12 | from torchvision import datasets, transforms 13 | from tqdm import trange 14 | from utils_cifar import ema, generate_samples, infiniteloop 15 | 16 | from torchcfm.conditional_flow_matching import ( 17 | ConditionalFlowMatcher, 18 | ExactOptimalTransportConditionalFlowMatcher, 19 | TargetConditionalFlowMatcher, 20 | VariancePreservingConditionalFlowMatcher, 21 | ) 22 | from torchcfm.models.unet.unet import UNetModelWrapper 23 | 24 | FLAGS = flags.FLAGS 25 | 26 | flags.DEFINE_string("model", "otcfm", help="flow matching model type") 27 | flags.DEFINE_string("output_dir", "./results/", help="output_directory") 28 | # UNet 29 | flags.DEFINE_integer("num_channel", 128, help="base channel of UNet") 30 | 31 | # Training 32 | flags.DEFINE_float("lr", 2e-4, help="target learning rate") # TRY 2e-4 33 | flags.DEFINE_float("grad_clip", 1.0, help="gradient norm clipping") 34 | flags.DEFINE_integer( 35 | "total_steps", 400001, help="total training steps" 36 | ) # Lipman et al uses 400k but double batch size 37 | flags.DEFINE_integer("warmup", 5000, help="learning rate warmup") 38 | flags.DEFINE_integer("batch_size", 128, help="batch size") # Lipman et al uses 128 39 | flags.DEFINE_integer("num_workers", 4, help="workers of Dataloader") 40 | flags.DEFINE_float("ema_decay", 0.9999, help="ema decay rate") 41 | flags.DEFINE_bool("parallel", False, help="multi gpu training") 42 | 43 | # Evaluation 44 | flags.DEFINE_integer( 45 | "save_step", 46 | 20000, 47 | help="frequency of saving checkpoints, 0 to disable during training", 48 | ) 49 | 50 | 51 | use_cuda = torch.cuda.is_available() 52 | device = torch.device("cuda" if use_cuda else "cpu") 53 | 54 | 55 | def warmup_lr(step): 56 | return min(step, FLAGS.warmup) / FLAGS.warmup 57 | 58 | 59 | def train(argv): 60 | print( 61 | "lr, total_steps, ema decay, save_step:", 62 | FLAGS.lr, 63 | FLAGS.total_steps, 64 | FLAGS.ema_decay, 65 | FLAGS.save_step, 66 | ) 67 | 68 | # DATASETS/DATALOADER 69 | dataset = datasets.CIFAR10( 70 | root="./data", 71 | train=True, 72 | download=True, 73 | transform=transforms.Compose( 74 | [ 75 | transforms.RandomHorizontalFlip(), 76 | transforms.ToTensor(), 77 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 78 | ] 79 | ), 80 | ) 81 | dataloader = torch.utils.data.DataLoader( 82 | dataset, 83 | batch_size=FLAGS.batch_size, 
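        # reshuffle the training data every epoch and (with drop_last below) discard the
        # final incomplete batch, so every optimization step uses a full FLAGS.batch_size batch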
84 | shuffle=True, 85 | num_workers=FLAGS.num_workers, 86 | drop_last=True, 87 | ) 88 | 89 | datalooper = infiniteloop(dataloader) 90 | 91 | # MODELS 92 | net_model = UNetModelWrapper( 93 | dim=(3, 32, 32), 94 | num_res_blocks=2, 95 | num_channels=FLAGS.num_channel, 96 | channel_mult=[1, 2, 2, 2], 97 | num_heads=4, 98 | num_head_channels=64, 99 | attention_resolutions="16", 100 | dropout=0.1, 101 | ).to( 102 | device 103 | ) # new dropout + bs of 128 104 | 105 | ema_model = copy.deepcopy(net_model) 106 | optim = torch.optim.Adam(net_model.parameters(), lr=FLAGS.lr) 107 | sched = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=warmup_lr) 108 | if FLAGS.parallel: 109 | print( 110 | "Warning: parallel training is performing slightly worse than single GPU training due to statistics computation in dataparallel. We recommend to train over a single GPU, which requires around 8 Gb of GPU memory." 111 | ) 112 | net_model = torch.nn.DataParallel(net_model) 113 | ema_model = torch.nn.DataParallel(ema_model) 114 | 115 | # show model size 116 | model_size = 0 117 | for param in net_model.parameters(): 118 | model_size += param.data.nelement() 119 | print("Model params: %.2f M" % (model_size / 1024 / 1024)) 120 | 121 | ################################# 122 | # OT-CFM 123 | ################################# 124 | 125 | sigma = 0.0 126 | if FLAGS.model == "otcfm": 127 | FM = ExactOptimalTransportConditionalFlowMatcher(sigma=sigma) 128 | elif FLAGS.model == "icfm": 129 | FM = ConditionalFlowMatcher(sigma=sigma) 130 | elif FLAGS.model == "fm": 131 | FM = TargetConditionalFlowMatcher(sigma=sigma) 132 | elif FLAGS.model == "si": 133 | FM = VariancePreservingConditionalFlowMatcher(sigma=sigma) 134 | else: 135 | raise NotImplementedError( 136 | f"Unknown model {FLAGS.model}, must be one of ['otcfm', 'icfm', 'fm', 'si']" 137 | ) 138 | 139 | savedir = FLAGS.output_dir + FLAGS.model + "/" 140 | os.makedirs(savedir, exist_ok=True) 141 | 142 | with trange(FLAGS.total_steps, dynamic_ncols=True) as pbar: 143 | for step in pbar: 144 | optim.zero_grad() 145 | x1 = next(datalooper).to(device) 146 | x0 = torch.randn_like(x1) 147 | t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1) 148 | vt = net_model(t, xt) 149 | loss = torch.mean((vt - ut) ** 2) 150 | loss.backward() 151 | torch.nn.utils.clip_grad_norm_(net_model.parameters(), FLAGS.grad_clip) # new 152 | optim.step() 153 | sched.step() 154 | ema(net_model, ema_model, FLAGS.ema_decay) # new 155 | 156 | # sample and Saving the weights 157 | if FLAGS.save_step > 0 and step % FLAGS.save_step == 0: 158 | generate_samples(net_model, FLAGS.parallel, savedir, step, net_="normal") 159 | generate_samples(ema_model, FLAGS.parallel, savedir, step, net_="ema") 160 | torch.save( 161 | { 162 | "net_model": net_model.state_dict(), 163 | "ema_model": ema_model.state_dict(), 164 | "sched": sched.state_dict(), 165 | "optim": optim.state_dict(), 166 | "step": step, 167 | }, 168 | savedir + f"{FLAGS.model}_cifar10_weights_step_{step}.pt", 169 | ) 170 | 171 | 172 | if __name__ == "__main__": 173 | app.run(train) 174 | -------------------------------------------------------------------------------- /examples/images/cifar10/utils_cifar.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | 4 | import torch 5 | from torch import distributed as dist 6 | from torchdyn.core import NeuralODE 7 | 8 | # from torchvision.transforms import ToPILImage 9 | from torchvision.utils import make_grid, save_image 10 | 11 | use_cuda = 
torch.cuda.is_available() 12 | device = torch.device("cuda" if use_cuda else "cpu") 13 | 14 | 15 | def setup( 16 | rank: int, 17 | total_num_gpus: int, 18 | master_addr: str = "localhost", 19 | master_port: str = "12355", 20 | backend: str = "nccl", 21 | ): 22 | """Initialize the distributed environment. 23 | 24 | Args: 25 | rank: Rank of the current process. 26 | total_num_gpus: Number of GPUs used in the job. 27 | master_addr: IP address of the master node. 28 | master_port: Port number of the master node. 29 | backend: Backend to use. 30 | """ 31 | 32 | os.environ["MASTER_ADDR"] = master_addr 33 | os.environ["MASTER_PORT"] = master_port 34 | 35 | # initialize the process group 36 | dist.init_process_group( 37 | backend=backend, 38 | rank=rank, 39 | world_size=total_num_gpus, 40 | ) 41 | 42 | 43 | def generate_samples(model, parallel, savedir, step, net_="normal"): 44 | """Save 64 generated images (8 x 8) for sanity check along training. 45 | 46 | Parameters 47 | ---------- 48 | model: 49 | represents the neural network that we want to generate samples from 50 | parallel: bool 51 | represents the parallel training flag. Torchdyn only runs on 1 GPU, we need to send the models from several GPUs to 1 GPU. 52 | savedir: str 53 | represents the path where we want to save the generated images 54 | step: int 55 | represents the current step of training 56 | """ 57 | model.eval() 58 | 59 | model_ = copy.deepcopy(model) 60 | if parallel: 61 | # Send the models from GPU to CPU for inference with NeuralODE from Torchdyn 62 | model_ = model_.module.to(device) 63 | 64 | node_ = NeuralODE(model_, solver="euler", sensitivity="adjoint") 65 | with torch.no_grad(): 66 | traj = node_.trajectory( 67 | torch.randn(64, 3, 32, 32, device=device), 68 | t_span=torch.linspace(0, 1, 100, device=device), 69 | ) 70 | traj = traj[-1, :].view([-1, 3, 32, 32]).clip(-1, 1) 71 | traj = traj / 2 + 0.5 72 | save_image(traj, savedir + f"{net_}_generated_FM_images_step_{step}.png", nrow=8) 73 | 74 | model.train() 75 | 76 | 77 | def ema(source, target, decay): 78 | source_dict = source.state_dict() 79 | target_dict = target.state_dict() 80 | for key in source_dict.keys(): 81 | target_dict[key].data.copy_( 82 | target_dict[key].data * decay + source_dict[key].data * (1 - decay) 83 | ) 84 | 85 | 86 | def infiniteloop(dataloader): 87 | while True: 88 | for x, y in iter(dataloader): 89 | yield x 90 | -------------------------------------------------------------------------------- /examples/tabular/README.md: -------------------------------------------------------------------------------- 1 | # Forest-Flow experiment on the Iris dataset using TorchCFM 2 | 3 | This notebook is a self-contained example showing how to train the novel Forest-Flow method to generate tabular data [(Jolicoeur-Martineau et al. 2023)](https://arxiv.org/abs/2309.09968). The idea behind Forest-Flow is to **learn Independent Conditional Flow-Matching's vector field with XGBoost models** instead of neural networks. The motivation is that it is known that Forests work currently better than neural networks on Tabular data tasks. This idea comes with some difficulties, for instance how to approximate Flow Matching's loss, and this notebook shows how to do it on a minimal example. The method, its training procedure and the experiments are described in [(Jolicoeur-Martineau et al. 2023)](https://arxiv.org/abs/2309.09968). The full code can be found [here](https://github.com/SamsungSAILMontreal/ForestDiffusion). 
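Below is a minimal, hypothetical sketch of this training recipe on the Iris dataset. It is a simplified illustration of the idea rather than the notebook's actual code: sigma is set to 0, the data is not duplicated across noise draws, and the names and hyperparameters (`n_t`, `n_estimators`) are illustrative only.

```python
# Hypothetical minimal sketch (not the notebook's actual code): learn the I-CFM
# vector field with XGBoost regressors, one per time level and output dimension.
# All hyperparameters here (n_t, n_estimators) are illustrative.
import numpy as np
from sklearn.datasets import load_iris
from xgboost import XGBRegressor

X1 = load_iris().data  # (n, d) real data we want to learn to generate
n, d = X1.shape
rng = np.random.default_rng(0)
n_t = 10  # number of discretized time levels
t_levels = np.linspace(0.0, 1.0, n_t)

# At each time level t, regress the conditional target u_t = x1 - x0 from the
# interpolated point x_t = (1 - t) * x0 + t * x1 (I-CFM with sigma = 0).
models = []
for t in t_levels:
    X0 = rng.standard_normal((n, d))  # noise samples paired with the data
    Xt = (1.0 - t) * X0 + t * X1
    Ut = X1 - X0
    models.append([XGBRegressor(n_estimators=100).fit(Xt, Ut[:, k]) for k in range(d)])

# Generate new samples by integrating the learned field from noise with Euler steps.
x = rng.standard_normal((5, d))
for i in range(n_t - 1):
    v = np.column_stack([m.predict(x) for m in models[i]])
    x = x + (t_levels[i + 1] - t_levels[i]) * v
print(x)  # five generated Iris-like rows
```

Fitting one regressor per (time level, output dimension) keeps every XGBoost fit a standard scalar regression; the full method described in the paper additionally duplicates each data point across several noise draws to better approximate the Flow Matching loss.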
4 | 5 | To run our jupyter notebooks, installing our package: 6 | 7 | ```bash 8 | cd ../../ 9 | 10 | # install torchcfm 11 | pip install -e '.[forest-flow]' 12 | 13 | # install ipykernel 14 | conda install -c anaconda ipykernel 15 | 16 | # install conda env in jupyter notebook 17 | python -m ipykernel install --user --name=torchcfm 18 | 19 | # launch our notebooks with the torchcfm kernel 20 | ``` 21 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.pytest.ini_options] 2 | addopts = [ 3 | "--color=yes", 4 | "--durations=0", 5 | "--strict-markers", 6 | "--doctest-modules", 7 | ] 8 | filterwarnings = [ 9 | "ignore::DeprecationWarning", 10 | "ignore::UserWarning", 11 | ] 12 | log_cli = "True" 13 | markers = [ 14 | "slow: slow tests", 15 | ] 16 | minversion = "6.0" 17 | testpaths = "tests/" 18 | 19 | [tool.coverage.report] 20 | exclude_lines = [ 21 | "pragma: nocover", 22 | "raise NotImplementedError", 23 | "raise NotImplementedError()", 24 | "if __name__ == .__main__.:", 25 | ] 26 | 27 | [tool.flake8] 28 | extend-ignore = ["E203", "E402", "E501", "F401", "F841", "E741", "F403"] 29 | exclude = ["logs/*","data/*"] 30 | per-file-ignores = [ 31 | '__init__.py:F401', 32 | ] 33 | max-line-length = 99 34 | count = true 35 | 36 | [tool.bandit] 37 | skips = ["B101", "B311"] 38 | 39 | [tool.isort] 40 | known_first_party = ["tests", "src"] 41 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.11.0 2 | torchvision>=0.11.0 3 | 4 | lightning-bolts 5 | matplotlib 6 | numpy 7 | scipy 8 | scikit-learn 9 | scprep 10 | scanpy 11 | torchdyn>=1.0.6 # 1.0.4 is broken on pypi 12 | pot 13 | torchdiffeq 14 | absl-py 15 | clean-fid 16 | -------------------------------------------------------------------------------- /runner-requirements.txt: -------------------------------------------------------------------------------- 1 | # Note if using Conda it is recommended to install torch separately. 
2 | # For most of testing the following commands were run to set up the environment 3 | # This was tested with torch==1.12.1 4 | # conda create -n ti-env python=3.10 5 | # conda activate ti-env 6 | # pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113 7 | # pip install -r requirements.txt 8 | # --------- pytorch --------- # 9 | torch>=1.11.0,<2.0.0 10 | torchvision>=0.11.0 11 | pytorch-lightning==1.8.3.post2 12 | torchmetrics==0.11.0 13 | 14 | # --------- hydra --------- # 15 | hydra-core==1.2.0 16 | hydra-colorlog==1.2.0 17 | hydra-optuna-sweeper==1.2.0 18 | # hydra-submitit-launcher 19 | 20 | # --------- loggers --------- # 21 | wandb 22 | # neptune-client 23 | # mlflow 24 | # comet-ml 25 | 26 | # --------- others --------- # 27 | black 28 | isort 29 | flake8 30 | Flake8-pyproject # for configuration via pyproject 31 | pyrootutils # standardizing the project root setup 32 | pre-commit # hooks for applying linters on commit 33 | rich # beautiful text formatting in terminal 34 | pytest # tests 35 | # sh # for running bash commands in some tests (linux/macos only) 36 | 37 | 38 | # --------- pkg reqs -------- # 39 | lightning-bolts 40 | matplotlib 41 | numpy 42 | scipy 43 | scikit-learn 44 | scprep 45 | scanpy 46 | timm 47 | torchdyn>=1.0.5 # 1.0.4 is broken on pypi 48 | pot 49 | 50 | # --------- notebook reqs -------- # 51 | seaborn>=0.12.2 52 | pandas>=2.2.2 53 | -------------------------------------------------------------------------------- /runner/configs/callbacks/default.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - model_checkpoint.yaml 3 | - early_stopping.yaml 4 | - model_summary.yaml 5 | - rich_progress_bar.yaml 6 | - _self_ 7 | 8 | model_checkpoint: 9 | dirpath: ${paths.output_dir}/checkpoints 10 | filename: "epoch_{epoch:04d}" 11 | monitor: "val/loss" 12 | mode: "min" 13 | save_last: True 14 | auto_insert_metric_name: False 15 | 16 | early_stopping: 17 | monitor: "val/loss" 18 | patience: 100 19 | mode: "min" 20 | 21 | model_summary: 22 | max_depth: -1 23 | -------------------------------------------------------------------------------- /runner/configs/callbacks/early_stopping.yaml: -------------------------------------------------------------------------------- 1 | # https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.EarlyStopping.html 2 | 3 | # Monitor a metric and stop training when it stops improving. 4 | # Look at the above link for more detailed information. 5 | early_stopping: 6 | _target_: pytorch_lightning.callbacks.EarlyStopping 7 | monitor: ??? # quantity to be monitored, must be specified !!! 8 | min_delta: 0. 
# minimum change in the monitored quantity to qualify as an improvement 9 | patience: 3 # number of checks with no improvement after which training will be stopped 10 | verbose: False # verbosity mode 11 | mode: "min" # "max" means higher metric value is better, can be also "min" 12 | strict: True # whether to crash the training if monitor is not found in the validation metrics 13 | check_finite: True # when set True, stops training when the monitor becomes NaN or infinite 14 | stopping_threshold: null # stop training immediately once the monitored quantity reaches this threshold 15 | divergence_threshold: null # stop training as soon as the monitored quantity becomes worse than this threshold 16 | check_on_train_epoch_end: null # whether to run early stopping at the end of the training epoch 17 | # log_rank_zero_only: False # this keyword argument isn't available in stable version 18 | -------------------------------------------------------------------------------- /runner/configs/callbacks/model_checkpoint.yaml: -------------------------------------------------------------------------------- 1 | # https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.ModelCheckpoint.html 2 | 3 | # Save the model periodically by monitoring a quantity. 4 | # Look at the above link for more detailed information. 5 | model_checkpoint: 6 | _target_: pytorch_lightning.callbacks.ModelCheckpoint 7 | dirpath: null # directory to save the model file 8 | filename: null # checkpoint filename 9 | monitor: null # name of the logged metric which determines when model is improving 10 | verbose: False # verbosity mode 11 | save_last: null # additionally always save an exact copy of the last checkpoint to a file last.ckpt 12 | save_top_k: 1 # save k best models (determined by above metric) 13 | mode: "min" # "max" means higher metric value is better, can be also "min" 14 | auto_insert_metric_name: True # when True, the checkpoints filenames will contain the metric name 15 | save_weights_only: False # if True, then only the model’s weights will be saved 16 | every_n_train_steps: null # number of training steps between checkpoints 17 | train_time_interval: null # checkpoints are monitored at the specified time interval 18 | every_n_epochs: null # number of epochs between checkpoints 19 | save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation 20 | -------------------------------------------------------------------------------- /runner/configs/callbacks/model_summary.yaml: -------------------------------------------------------------------------------- 1 | # https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.RichModelSummary.html 2 | 3 | # Generates a summary of all layers in a LightningModule with rich text formatting. 4 | # Look at the above link for more detailed information. 
5 | model_summary: 6 | _target_: pytorch_lightning.callbacks.RichModelSummary 7 | max_depth: 1 # the maximum depth of layer nesting that the summary will include 8 | -------------------------------------------------------------------------------- /runner/configs/callbacks/no_stopping.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - model_checkpoint.yaml 3 | - model_summary.yaml 4 | - rich_progress_bar.yaml 5 | - _self_ 6 | 7 | model_checkpoint: 8 | dirpath: ${paths.output_dir}/checkpoints 9 | filename: "epoch_{epoch:04d}" 10 | save_last: True 11 | every_n_epochs: 100 # number of epochs between checkpoints 12 | auto_insert_metric_name: False 13 | 14 | model_summary: 15 | max_depth: 3 16 | -------------------------------------------------------------------------------- /runner/configs/callbacks/none.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/atong01/conditional-flow-matching/3fd278f9ef2f02e17e107e5769130b6cb44803e2/runner/configs/callbacks/none.yaml -------------------------------------------------------------------------------- /runner/configs/callbacks/rich_progress_bar.yaml: -------------------------------------------------------------------------------- 1 | # https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.RichProgressBar.html 2 | 3 | # Create a progress bar with rich text formatting. 4 | # Look at the above link for more detailed information. 5 | rich_progress_bar: 6 | _target_: pytorch_lightning.callbacks.RichProgressBar 7 | -------------------------------------------------------------------------------- /runner/configs/datamodule/cifar.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.datamodules.cifar10_datamodule.CIFAR10DataModule 2 | #_target_: pl_bolts.datamodules.cifar10_datamodule.CIFAR10DataModule 3 | data_dir: ${paths.data_dir} 4 | batch_size: 128 5 | val_split: 0.0 6 | num_workers: 0 7 | normalize: True 8 | seed: 42 9 | shuffle: True 10 | pin_memory: True 11 | drop_last: False 12 | -------------------------------------------------------------------------------- /runner/configs/datamodule/custom_dist.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.datamodules.distribution_datamodule.TrajectoryNetDistributionTrajectoryDataModule 2 | #_target_: src.datamodules.distribution_datamodule.DistributionDataModule 3 | 4 | data_dir: ${paths.data_dir} # data_dir is specified in config.yaml 5 | train_val_test_split: 1000 6 | batch_size: 100 7 | num_workers: 0 8 | pin_memory: False 9 | 10 | system: ${paths.data_dir}/embryoid_anndata_small_v2.h5ad 11 | 12 | system_kwargs: 13 | max_dim: 1e10 14 | embedding_name: "phate" 15 | #embedding_name: "highly_variable" 16 | whiten: True 17 | #whiten: False 18 | -------------------------------------------------------------------------------- /runner/configs/datamodule/eb_full.yaml: -------------------------------------------------------------------------------- 1 | # https://github.com/DiffEqML/torchdyn/blob/master/torchdyn/datasets/static_datasets.py 2 | _target_: src.datamodules.distribution_datamodule.TorchDynDataModule 3 | #_target_: src.datamodules.distribution_datamodule.DistributionDataModule 4 | 5 | train_val_test_split: [0.8, 0.1, 0.1] 6 | batch_size: 128 7 | num_workers: 0 8 | pin_memory: False 9 | 10 | system: ${paths.data_dir}/eb_velocity_v5.npz 11 | 12 | 
system_kwargs: 13 | max_dim: 100 14 | whiten: False 15 | -------------------------------------------------------------------------------- /runner/configs/datamodule/funnel.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.datamodules.distribution_datamodule.TorchDynDataModule 2 | #_target_: src.datamodules.distribution_datamodule.DistributionDataModule 3 | 4 | train_val_test_split: [10000, 1000, 1000] 5 | batch_size: 128 6 | num_workers: 0 7 | pin_memory: False 8 | 9 | system: "funnel" 10 | 11 | system_kwargs: 12 | dim: 10 13 | -------------------------------------------------------------------------------- /runner/configs/datamodule/gaussians.yaml: -------------------------------------------------------------------------------- 1 | # https://github.com/DiffEqML/torchdyn/blob/master/torchdyn/datasets/static_datasets.py 2 | _target_: src.datamodules.distribution_datamodule.TorchDynDataModule 3 | #_target_: src.datamodules.distribution_datamodule.DistributionDataModule 4 | 5 | train_val_test_split: [10000, 1000, 1000] 6 | batch_size: 128 7 | num_workers: 0 8 | pin_memory: False 9 | 10 | system: "gaussians" 11 | 12 | system_kwargs: 13 | noise: 1e-4 14 | -------------------------------------------------------------------------------- /runner/configs/datamodule/moons.yaml: -------------------------------------------------------------------------------- 1 | # https://github.com/DiffEqML/torchdyn/blob/master/torchdyn/datasets/static_datasets.py 2 | _target_: src.datamodules.distribution_datamodule.SKLearnDataModule 3 | #_target_: src.datamodules.distribution_datamodule.DistributionDataModule 4 | 5 | train_val_test_split: [10000, 1000, 1000] 6 | batch_size: 128 7 | num_workers: 0 8 | pin_memory: False 9 | 10 | system: "moons" 11 | -------------------------------------------------------------------------------- /runner/configs/datamodule/scurve.yaml: -------------------------------------------------------------------------------- 1 | # https://github.com/DiffEqML/torchdyn/blob/master/torchdyn/datasets/static_datasets.py 2 | _target_: src.datamodules.distribution_datamodule.SKLearnDataModule 3 | #_target_: src.datamodules.distribution_datamodule.DistributionDataModule 4 | 5 | train_val_test_split: [10000, 1000, 1000] 6 | batch_size: 128 7 | num_workers: 0 8 | pin_memory: False 9 | 10 | system: "scurve" 11 | -------------------------------------------------------------------------------- /runner/configs/datamodule/sklearn.yaml: -------------------------------------------------------------------------------- 1 | # https://github.com/DiffEqML/torchdyn/blob/master/torchdyn/datasets/static_datasets.py 2 | _target_: src.datamodules.distribution_datamodule.SKLearnDataModule 3 | #_target_: src.datamodules.distribution_datamodule.DistributionDataModule 4 | 5 | train_val_test_split: [10000, 1000, 1000] 6 | batch_size: 128 7 | num_workers: 0 8 | pin_memory: False 9 | 10 | system: "scurve" 11 | -------------------------------------------------------------------------------- /runner/configs/datamodule/time_dist.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.datamodules.distribution_datamodule.CustomTrajectoryDataModule 2 | #_target_: src.datamodules.distribution_datamodule.DistributionDataModule 3 | 4 | data_dir: ${paths.data_dir} # data_dir is specified in config.yaml 5 | train_val_test_split: [0.8, 0.1, 0.1] 6 | batch_size: 128 7 | num_workers: 0 8 | pin_memory: False 9 | max_dim: 5 10 | whiten: 
True 11 | 12 | system: ${paths.data_dir}/eb_velocity_v5.npz 13 | -------------------------------------------------------------------------------- /runner/configs/datamodule/torchdyn.yaml: -------------------------------------------------------------------------------- 1 | # https://github.com/DiffEqML/torchdyn/blob/master/torchdyn/datasets/static_datasets.py 2 | _target_: src.datamodules.distribution_datamodule.TorchDynDataModule 3 | #_target_: src.datamodules.distribution_datamodule.DistributionDataModule 4 | 5 | train_val_test_split: [10000, 1000, 1000] 6 | batch_size: 128 7 | num_workers: 0 8 | pin_memory: False 9 | 10 | system: "moons" 11 | 12 | system_kwargs: 13 | noise: 1e-4 14 | -------------------------------------------------------------------------------- /runner/configs/datamodule/tree.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.datamodules.distribution_datamodule.DistributionDataModule 2 | 3 | data_dir: ${data_dir} # data_dir is specified in config.yaml 4 | train_val_test_split: 1000 5 | batch_size: 100 6 | num_workers: 0 7 | pin_memory: False 8 | p: 2 9 | -------------------------------------------------------------------------------- /runner/configs/datamodule/twodim.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.datamodules.distribution_datamodule.TwoDimDataModule 2 | 3 | train_val_test_split: [10000, 1000, 1000] 4 | batch_size: 128 5 | num_workers: 0 6 | pin_memory: False 7 | 8 | system: "moon-8gaussians" 9 | -------------------------------------------------------------------------------- /runner/configs/debug/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # default debugging setup, runs 1 full epoch 4 | # other debugging configs can inherit from this one 5 | 6 | # overwrite task name so debugging logs are stored in separate folder 7 | task_name: "debug" 8 | 9 | # disable callbacks and loggers during debugging 10 | callbacks: null 11 | logger: null 12 | 13 | extras: 14 | ignore_warnings: False 15 | enforce_tags: False 16 | 17 | # sets level of all command line loggers to 'DEBUG' 18 | # https://hydra.cc/docs/tutorials/basic/running_your_app/logging/ 19 | hydra: 20 | job_logging: 21 | root: 22 | level: DEBUG 23 | 24 | # use this to also set hydra loggers to 'DEBUG' 25 | # verbose: True 26 | 27 | trainer: 28 | max_epochs: 1 29 | accelerator: cpu # debuggers don't like gpus 30 | devices: 1 # debuggers don't like multiprocessing 31 | detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor 32 | 33 | datamodule: 34 | num_workers: 0 # debuggers don't like multiprocessing 35 | pin_memory: False # disable gpu memory pin 36 | -------------------------------------------------------------------------------- /runner/configs/debug/fdr.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # runs 1 train, 1 validation and 1 test step 4 | 5 | defaults: 6 | - default.yaml 7 | 8 | trainer: 9 | fast_dev_run: true 10 | -------------------------------------------------------------------------------- /runner/configs/debug/limit.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # uses only 1% of the training data and 5% of validation/test data 4 | 5 | defaults: 6 | - default.yaml 7 | 8 | trainer: 9 | max_epochs: 3 10 | limit_train_batches: 0.01 
11 | limit_val_batches: 0.05 12 | limit_test_batches: 0.05 13 | -------------------------------------------------------------------------------- /runner/configs/debug/overfit.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # overfits to 3 batches 4 | 5 | defaults: 6 | - default.yaml 7 | 8 | trainer: 9 | max_epochs: 20 10 | overfit_batches: 3 11 | 12 | # model ckpt and early stopping need to be disabled during overfitting 13 | callbacks: null 14 | -------------------------------------------------------------------------------- /runner/configs/debug/profiler.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # runs with execution time profiling 4 | 5 | defaults: 6 | - default.yaml 7 | 8 | trainer: 9 | max_epochs: 1 10 | profiler: "simple" 11 | # profiler: "advanced" 12 | # profiler: "pytorch" 13 | -------------------------------------------------------------------------------- /runner/configs/eval.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - _self_ 5 | - datamodule: sklearn # choose datamodule with `test_dataloader()` for evaluation 6 | - model: cfm 7 | - logger: null 8 | - trainer: default.yaml 9 | - paths: default.yaml 10 | - extras: default.yaml 11 | - hydra: default.yaml 12 | 13 | task_name: "eval" 14 | 15 | tags: ["dev"] 16 | 17 | # passing checkpoint path is necessary for evaluation 18 | ckpt_path: ??? 19 | -------------------------------------------------------------------------------- /runner/configs/experiment/cfm.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /model: cfm.yaml 5 | - override /logger: 6 | - csv.yaml 7 | - wandb.yaml 8 | - override /datamodule: sklearn.yaml 9 | 10 | name: "cfm" 11 | seed: 42 12 | 13 | datamodule: 14 | batch_size: 512 15 | 16 | model: 17 | optimizer: 18 | weight_decay: 1e-5 19 | 20 | trainer: 21 | max_epochs: 1000 22 | check_val_every_n_epoch: 10 23 | -------------------------------------------------------------------------------- /runner/configs/experiment/cnf.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /model: cnf.yaml 5 | - override /logger: 6 | - csv.yaml 7 | - wandb.yaml 8 | - override /datamodule: sklearn.yaml 9 | 10 | name: "cnf" 11 | seed: 42 12 | 13 | datamodule: 14 | batch_size: 1024 15 | 16 | model: 17 | optimizer: 18 | weight_decay: 1e-5 19 | 20 | trainer: 21 | max_epochs: 1000 22 | check_val_every_n_epoch: 10 23 | -------------------------------------------------------------------------------- /runner/configs/experiment/icnn.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /model: icnn 5 | - override /logger: 6 | - csv 7 | - wandb 8 | - override /datamodule: sklearn 9 | 10 | name: "icnn" 11 | seed: 42 12 | 13 | datamodule: 14 | batch_size: 256 15 | 16 | trainer: 17 | max_epochs: 10000 18 | check_val_every_n_epoch: 100 19 | -------------------------------------------------------------------------------- /runner/configs/experiment/image_cfm.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /model: image_cfm.yaml 5 | - override /callbacks: 
no_stopping 6 | - override /logger: 7 | - csv.yaml 8 | - wandb.yaml 9 | - override /datamodule: cifar.yaml 10 | - override /trainer: ddp.yaml 11 | 12 | name: "cfm" 13 | seed: 42 14 | 15 | datamodule: 16 | batch_size: 128 17 | 18 | model: 19 | _target_: src.models.cfm_module.CFMLitModule 20 | sigma_min: 1e-4 21 | 22 | scheduler: 23 | _target_: timm.scheduler.PolyLRScheduler 24 | _partial_: True 25 | warmup_t: 200 26 | warmup_lr_init: 1e-8 27 | t_initial: 2000 28 | 29 | trainer: 30 | devices: 2 31 | max_epochs: 2000 32 | check_val_every_n_epoch: 10 33 | limit_val_batches: 0.01 34 | -------------------------------------------------------------------------------- /runner/configs/experiment/image_fm.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /model: image_cfm.yaml 5 | - override /callbacks: no_stopping 6 | - override /logger: 7 | - csv.yaml 8 | - wandb.yaml 9 | - override /datamodule: cifar.yaml 10 | - override /trainer: ddp.yaml 11 | 12 | name: "cfm" 13 | seed: 42 14 | 15 | datamodule: 16 | batch_size: 128 17 | 18 | model: 19 | _target_: src.models.cfm_module.FMLitModule 20 | sigma_min: 1e-4 21 | 22 | scheduler: 23 | _target_: timm.scheduler.PolyLRScheduler 24 | _partial_: True 25 | warmup_t: 200 26 | warmup_lr_init: 1e-8 27 | t_initial: 2000 28 | 29 | trainer: 30 | devices: 2 31 | max_epochs: 2000 32 | check_val_every_n_epoch: 10 33 | limit_val_batches: 0.01 34 | -------------------------------------------------------------------------------- /runner/configs/experiment/image_otcfm.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /model: image_cfm.yaml 5 | - override /callbacks: no_stopping 6 | - override /logger: 7 | - csv.yaml 8 | - wandb.yaml 9 | - override /datamodule: cifar.yaml 10 | - override /trainer: ddp.yaml 11 | 12 | name: "cfm" 13 | seed: 42 14 | 15 | datamodule: 16 | batch_size: 128 17 | 18 | model: 19 | _target_: src.models.cfm_module.CFMLitModule 20 | sigma_min: 1e-4 21 | 22 | scheduler: 23 | _target_: timm.scheduler.PolyLRScheduler 24 | _partial_: True 25 | warmup_t: 200 26 | warmup_lr_init: 1e-8 27 | t_initial: 2000 28 | ot_sampler: "exact" 29 | 30 | trainer: 31 | devices: 2 32 | max_epochs: 2000 33 | check_val_every_n_epoch: 10 34 | limit_val_batches: 0.01 35 | -------------------------------------------------------------------------------- /runner/configs/experiment/trajectorynet.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /model: trajectorynet.yaml 5 | - override /logger: 6 | - csv.yaml 7 | - wandb.yaml 8 | - override /datamodule: twodim.yaml 9 | 10 | name: "cnf" 11 | seed: 42 12 | 13 | datamodule: 14 | batch_size: 1024 15 | 16 | model: 17 | optimizer: 18 | weight_decay: 1e-5 19 | 20 | trainer: 21 | max_epochs: 1000 22 | check_val_every_n_epoch: 10 23 | -------------------------------------------------------------------------------- /runner/configs/extras/default.yaml: -------------------------------------------------------------------------------- 1 | # disable python warnings if they annoy you 2 | ignore_warnings: False 3 | 4 | # ask user for tags if none are provided in the config 5 | enforce_tags: True 6 | 7 | # pretty print config tree at the start of the run using Rich library 8 | print_config: True 9 | 
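The three image experiments above (image_cfm, image_fm, image_otcfm) share the same UNet, scheduler, and CIFAR-10 datamodule; they differ mainly in which Lightning module variant is targeted (CFMLitModule vs. FMLitModule) and in whether minibatch optimal-transport pairing is switched on (ot_sampler: "exact" in image_otcfm). As a rough, standalone sketch of the objective these configs train, the snippet below expresses the same idea with the torchcfm package from this repository rather than the runner's Lightning modules; the tiny MLP and the 2D toy batches are stand-ins for illustration only, not the UNet or CIFAR-10 setup from the configs, and sigma here plays the role of the sigma_min hyperparameter.

import torch
from torchcfm.conditional_flow_matching import ExactOptimalTransportConditionalFlowMatcher

# OT-paired CFM (roughly what image_otcfm trains); swapping in the plain
# ConditionalFlowMatcher class would give the independent coupling of image_cfm.
flow_matcher = ExactOptimalTransportConditionalFlowMatcher(sigma=1e-4)

x0 = torch.randn(128, 2)        # source batch (noise)
x1 = torch.randn(128, 2) + 3.0  # target batch (stand-in for data samples)
net = torch.nn.Sequential(torch.nn.Linear(3, 64), torch.nn.SELU(), torch.nn.Linear(64, 2))

t, xt, ut = flow_matcher.sample_location_and_conditional_flow(x0, x1)
vt = net(torch.cat([xt, t[:, None]], dim=-1))  # v_theta(t, x_t) with t appended as a feature
loss = torch.mean((vt - ut) ** 2)              # regress onto the conditional target velocity
loss.backward()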
-------------------------------------------------------------------------------- /runner/configs/hparams_search/optuna.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # example hyperparameter optimization of some experiment with Optuna: 4 | # python train.py -m hparams_search=mnist_optuna experiment=example 5 | 6 | defaults: 7 | - override /hydra/sweeper: optuna 8 | 9 | # choose metric which will be optimized by Optuna 10 | # make sure this is the correct name of some metric logged in lightning module! 11 | optimized_metric: "val/2-Wasserstein" 12 | 13 | # here we define Optuna hyperparameter search 14 | # it optimizes for value returned from function with @hydra.main decorator 15 | # docs: https://hydra.cc/docs/next/plugins/optuna_sweeper 16 | hydra: 17 | mode: "MULTIRUN" # set hydra to multirun by default if this config is attached 18 | 19 | sweeper: 20 | _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper 21 | 22 | # storage URL to persist optimization results 23 | # for example, you can use SQLite if you set 'sqlite:///example.db' 24 | storage: null 25 | 26 | # name of the study to persist optimization results 27 | study_name: null 28 | 29 | # number of parallel workers 30 | n_jobs: 1 31 | 32 | # 'minimize' or 'maximize' the objective 33 | direction: maximize 34 | 35 | # total number of runs that will be executed 36 | n_trials: 20 37 | 38 | # choose Optuna hyperparameter sampler 39 | # you can choose bayesian sampler (tpe), random search (without optimization), grid sampler, and others 40 | # docs: https://optuna.readthedocs.io/en/stable/reference/samplers.html 41 | sampler: 42 | _target_: optuna.samplers.TPESampler 43 | seed: 1234 44 | n_startup_trials: 10 # number of random sampling runs before optimization starts 45 | 46 | # define hyperparameter search space 47 | params: 48 | model.optimizer.lr: interval(0.0001, 0.1) 49 | datamodule.batch_size: choice(32, 64, 128, 256) 50 | -------------------------------------------------------------------------------- /runner/configs/hydra/default.yaml: -------------------------------------------------------------------------------- 1 | # https://hydra.cc/docs/configure_hydra/intro/ 2 | 3 | # enable color logging 4 | defaults: 5 | - override hydra_logging: colorlog 6 | - override job_logging: colorlog 7 | 8 | # output directory, generated dynamically on each run 9 | run: 10 | dir: ${paths.log_dir}/${task_name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S} 11 | sweep: 12 | dir: ${paths.log_dir}/${task_name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S} 13 | subdir: ${hydra.job.num} 14 | job: 15 | chdir: true 16 | -------------------------------------------------------------------------------- /runner/configs/launcher/mila_cluster.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | # 3 | defaults: 4 | - override /hydra/launcher: submitit_slurm 5 | 6 | hydra: 7 | launcher: 8 | partition: long 9 | cpus_per_task: 2 10 | mem_gb: 20 11 | gres: gpu:1 12 | timeout_min: 1440 13 | array_parallelism: 10 # max num of tasks to run in parallel (via job array) 14 | setup: 15 | - "module purge" 16 | - "module load miniconda/3" 17 | - "conda activate myenv" 18 | - "unset CUDA_VISIBLE_DEVICES" 19 | -------------------------------------------------------------------------------- /runner/configs/launcher/mila_cpu_cluster.yaml: -------------------------------------------------------------------------------- 1 | # @package 
_global_ 2 | # 3 | defaults: 4 | - override /hydra/launcher: submitit_slurm 5 | 6 | hydra: 7 | launcher: 8 | partition: long-cpu 9 | cpus_per_task: 1 10 | mem_gb: 5 11 | timeout_min: 100 12 | array_parallelism: 64 13 | setup: 14 | - "module purge" 15 | - "module load miniconda/3" 16 | - "conda activate myenv" 17 | -------------------------------------------------------------------------------- /runner/configs/local/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/atong01/conditional-flow-matching/3fd278f9ef2f02e17e107e5769130b6cb44803e2/runner/configs/local/.gitkeep -------------------------------------------------------------------------------- /runner/configs/local/default.yaml: -------------------------------------------------------------------------------- 1 | # path to root directory 2 | # this requires PROJECT_ROOT environment variable to exist 3 | # PROJECT_ROOT is inferred and set by pyrootutils package in `train.py` and `eval.py` 4 | root_dir: ${oc.env:PROJECT_ROOT} 5 | 6 | scratch_dir: ${oc.env:PROJECT_ROOT} 7 | 8 | # path to data directory 9 | data_dir: ${local.scratch_dir}/data/ 10 | 11 | # path to logging directory 12 | log_dir: ${local.scratch_dir}/logs/ 13 | -------------------------------------------------------------------------------- /runner/configs/logger/comet.yaml: -------------------------------------------------------------------------------- 1 | # https://www.comet.ml 2 | 3 | comet: 4 | _target_: pytorch_lightning.loggers.comet.CometLogger 5 | api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable 6 | save_dir: "${paths.output_dir}" 7 | project_name: "lightning-hydra-template" 8 | rest_api_key: null 9 | # experiment_name: "" 10 | experiment_key: null # set to resume experiment 11 | offline: False 12 | prefix: "" 13 | -------------------------------------------------------------------------------- /runner/configs/logger/csv.yaml: -------------------------------------------------------------------------------- 1 | # csv logger built in lightning 2 | 3 | csv: 4 | _target_: pytorch_lightning.loggers.csv_logs.CSVLogger 5 | save_dir: "${paths.output_dir}" 6 | name: "csv/" 7 | prefix: "" 8 | -------------------------------------------------------------------------------- /runner/configs/logger/many_loggers.yaml: -------------------------------------------------------------------------------- 1 | # train with many loggers at once 2 | 3 | defaults: 4 | # - comet.yaml 5 | - csv.yaml 6 | # - mlflow.yaml 7 | # - neptune.yaml 8 | - tensorboard.yaml 9 | - wandb.yaml 10 | -------------------------------------------------------------------------------- /runner/configs/logger/mlflow.yaml: -------------------------------------------------------------------------------- 1 | # https://mlflow.org 2 | 3 | mlflow: 4 | _target_: pytorch_lightning.loggers.mlflow.MLFlowLogger 5 | # experiment_name: "" 6 | # run_name: "" 7 | tracking_uri: ${paths.log_dir}/mlflow/mlruns # run `mlflow ui` command inside the `logs/mlflow/` dir to open the UI 8 | tags: null 9 | # save_dir: "./mlruns" 10 | prefix: "" 11 | artifact_location: null 12 | # run_id: "" 13 | -------------------------------------------------------------------------------- /runner/configs/logger/neptune.yaml: -------------------------------------------------------------------------------- 1 | # https://neptune.ai 2 | 3 | neptune: 4 | _target_: pytorch_lightning.loggers.neptune.NeptuneLogger 5 | api_key: ${oc.env:NEPTUNE_API_TOKEN} # api 
key is loaded from environment variable 6 | project: username/lightning-hydra-template 7 | # name: "" 8 | log_model_checkpoints: True 9 | prefix: "" 10 | -------------------------------------------------------------------------------- /runner/configs/logger/tensorboard.yaml: -------------------------------------------------------------------------------- 1 | # https://www.tensorflow.org/tensorboard/ 2 | 3 | tensorboard: 4 | _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger 5 | save_dir: "${paths.output_dir}/tensorboard/" 6 | name: null 7 | log_graph: False 8 | default_hp_metric: True 9 | prefix: "" 10 | # version: "" 11 | -------------------------------------------------------------------------------- /runner/configs/logger/wandb.yaml: -------------------------------------------------------------------------------- 1 | # https://wandb.ai 2 | 3 | wandb: 4 | _target_: pytorch_lightning.loggers.wandb.WandbLogger 5 | # name: "" # name of the run (normally generated by wandb) 6 | save_dir: "${paths.output_dir}" 7 | offline: False 8 | id: null # pass correct id to resume experiment! 9 | anonymous: null # enable anonymous logging 10 | project: "conditional-flow-model" 11 | log_model: False # upload lightning ckpts 12 | prefix: "" # a string to put at the beginning of metric keys 13 | # entity: "" # set to name of your wandb team 14 | group: "" 15 | tags: [] 16 | job_type: "" 17 | -------------------------------------------------------------------------------- /runner/configs/model/cfm.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.models.cfm_module.CFMLitModule 2 | _partial_: true 3 | 4 | optimizer: 5 | _target_: torch.optim.AdamW 6 | _partial_: true 7 | lr: 0.001 8 | weight_decay: 1e-5 9 | 10 | net: 11 | _target_: src.models.components.simple_mlp.VelocityNet 12 | _partial_: true 13 | hidden_dims: [64, 64, 64] 14 | batch_norm: False 15 | activation: "selu" 16 | 17 | augmentations: 18 | _target_: src.models.components.augmentation.AugmentationModule 19 | cnf_estimator: null 20 | l1_reg: 0. 21 | l2_reg: 0. 22 | squared_l2_reg: 0. 23 | jacobian_frobenius_reg: 0. 24 | jacobian_diag_frobenius_reg: 0. 25 | jacobian_off_diag_frobenius_reg: 0. 
26 | 27 | partial_solver: 28 | _target_: src.models.components.solver.FlowSolver 29 | _partial_: true 30 | ode_solver: "euler" 31 | atol: 1e-5 32 | rtol: 1e-5 33 | 34 | ot_sampler: null 35 | 36 | sigma_min: 0.1 37 | 38 | # Set to integer if want to train with left out timepoint 39 | leaveout_timepoint: -1 40 | 41 | plot: False 42 | -------------------------------------------------------------------------------- /runner/configs/model/cfm_v2.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.models.runner.CFMLitModule 2 | _partial_: true 3 | 4 | optimizer: 5 | _target_: torch.optim.AdamW 6 | _partial_: true 7 | lr: 0.001 8 | weight_decay: 1e-5 9 | 10 | net: 11 | _target_: src.models.components.simple_mlp.VelocityNet 12 | _partial_: true 13 | hidden_dims: [64, 64, 64] 14 | batch_norm: False 15 | activation: "selu" 16 | 17 | flow_matcher: 18 | _target_: torchcfm.ConditionalFlowMatcher 19 | sigma: 0.0 20 | 21 | solver: 22 | _target_: src.models.components.solver.FlowSolver 23 | _partial_: true 24 | ode_solver: "euler" 25 | atol: 1e-5 26 | rtol: 1e-5 27 | 28 | plot: True 29 | -------------------------------------------------------------------------------- /runner/configs/model/cnf.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.models.cfm_module.CNFLitModule 2 | _partial_: true 3 | 4 | optimizer: 5 | _target_: torch.optim.AdamW 6 | _partial_: true 7 | lr: 0.001 8 | weight_decay: 0.01 9 | 10 | net: 11 | _target_: src.models.components.simple_mlp.VelocityNet 12 | _partial_: true 13 | hidden_dims: [64, 64, 64] 14 | batch_norm: False 15 | activation: "selu" 16 | 17 | augmentations: 18 | _target_: src.models.components.augmentation.AugmentationModule 19 | cnf_estimator: "exact" 20 | l1_reg: 0. 21 | l2_reg: 0. 22 | squared_l2_reg: 0. 23 | jacobian_frobenius_reg: 0. 24 | jacobian_diag_frobenius_reg: 0. 25 | jacobian_off_diag_frobenius_reg: 0. 26 | 27 | # Set to integer if want to train with left out timepoint 28 | leaveout_timepoint: -1 29 | -------------------------------------------------------------------------------- /runner/configs/model/fm.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.models.cfm_module.FMLitModule 2 | _partial_: true 3 | 4 | optimizer: 5 | _target_: torch.optim.AdamW 6 | _partial_: true 7 | lr: 0.001 8 | weight_decay: 1e-5 9 | 10 | net: 11 | _target_: src.models.components.simple_mlp.VelocityNet 12 | _partial_: true 13 | hidden_dims: [64, 64, 64] 14 | batch_norm: False 15 | activation: "selu" 16 | 17 | augmentations: 18 | _target_: src.models.components.augmentation.AugmentationModule 19 | cnf_estimator: null 20 | l1_reg: 0. 21 | l2_reg: 0. 22 | squared_l2_reg: 0. 23 | jacobian_frobenius_reg: 0. 24 | jacobian_diag_frobenius_reg: 0. 25 | jacobian_off_diag_frobenius_reg: 0. 
26 | 27 | partial_solver: 28 | _target_: src.models.components.solver.FlowSolver 29 | _partial_: true 30 | ode_solver: "euler" 31 | atol: 1e-5 32 | rtol: 1e-5 33 | 34 | sigma_min: 0.1 35 | 36 | # Set to integer if want to train with left out timepoint 37 | leaveout_timepoint: -1 38 | -------------------------------------------------------------------------------- /runner/configs/model/icnn.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.models.icnn_module.ICNNLitModule 2 | _partial_: true 3 | 4 | f_net: 5 | _target_: src.models.components.icnn_model.ICNN 6 | _partial_: true 7 | dimh: 64 8 | num_hidden_layers: 4 9 | 10 | g_net: 11 | _target_: src.models.components.icnn_model.ICNN 12 | _partial_: true 13 | dimh: 64 14 | num_hidden_layers: 4 15 | 16 | optimizer: 17 | _target_: torch.optim.Adam 18 | _partial_: true 19 | lr: 0.0001 20 | betas: [0.5, 0.9] 21 | 22 | reg: 0.1 23 | 24 | # Set to integer if want to train with left out timepoint 25 | leaveout_timepoint: -1 26 | -------------------------------------------------------------------------------- /runner/configs/model/image_cfm.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.models.cfm_module.CFMLitModule 2 | _partial_: true 3 | 4 | optimizer: 5 | _target_: torch.optim.Adam 6 | _partial_: true 7 | lr: 0.0005 8 | 9 | net: 10 | _target_: src.models.components.unet.UNetModelWrapper 11 | _partial_: true 12 | num_res_blocks: 2 13 | num_channels: 256 14 | channel_mult: [1, 2, 2, 2] 15 | num_heads: 4 16 | num_head_channels: 64 17 | attention_resolutions: "16" 18 | dropout: 0 19 | 20 | augmentations: 21 | _target_: src.models.components.augmentation.AugmentationModule 22 | cnf_estimator: null 23 | l1_reg: 0. 24 | l2_reg: 0. 25 | squared_l2_reg: 0. 26 | jacobian_frobenius_reg: 0. 27 | jacobian_diag_frobenius_reg: 0. 28 | jacobian_off_diag_frobenius_reg: 0. 29 | 30 | partial_solver: 31 | _target_: src.models.components.solver.FlowSolver 32 | _partial_: true 33 | ode_solver: "euler" 34 | atol: 1e-5 35 | rtol: 1e-5 36 | 37 | test_nfe: 100 38 | 39 | ot_sampler: null 40 | 41 | sigma_min: 0.1 42 | 43 | # Set to integer if want to train with left out timepoint 44 | leaveout_timepoint: -1 45 | 46 | plot: True 47 | -------------------------------------------------------------------------------- /runner/configs/model/otcfm.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.models.cfm_module.CFMLitModule 2 | _partial_: true 3 | 4 | optimizer: 5 | _target_: torch.optim.AdamW 6 | _partial_: true 7 | lr: 0.001 8 | weight_decay: 1e-5 9 | 10 | net: 11 | _target_: src.models.components.simple_mlp.VelocityNet 12 | _partial_: true 13 | hidden_dims: [64, 64, 64] 14 | batch_norm: False 15 | activation: "selu" 16 | 17 | augmentations: 18 | _target_: src.models.components.augmentation.AugmentationModule 19 | cnf_estimator: null 20 | l1_reg: 0. 21 | l2_reg: 0. 22 | squared_l2_reg: 0. 23 | jacobian_frobenius_reg: 0. 24 | jacobian_diag_frobenius_reg: 0. 25 | jacobian_off_diag_frobenius_reg: 0. 
26 | 27 | partial_solver: 28 | _target_: src.models.components.solver.FlowSolver 29 | _partial_: true 30 | ode_solver: "euler" 31 | atol: 1e-5 32 | rtol: 1e-5 33 | 34 | ot_sampler: "exact" 35 | 36 | sigma_min: 0.1 37 | 38 | # Set to integer if want to train with left out timepoint 39 | leaveout_timepoint: -1 40 | -------------------------------------------------------------------------------- /runner/configs/model/sbcfm.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.models.cfm_module.SBCFMLitModule 2 | _partial_: true 3 | 4 | optimizer: 5 | _target_: torch.optim.AdamW 6 | _partial_: true 7 | lr: 0.001 8 | weight_decay: 1e-5 9 | 10 | net: 11 | _target_: src.models.components.simple_mlp.VelocityNet 12 | _partial_: true 13 | hidden_dims: [64, 64, 64] 14 | batch_norm: False 15 | activation: "selu" 16 | 17 | augmentations: 18 | _target_: src.models.components.augmentation.AugmentationModule 19 | cnf_estimator: null 20 | l1_reg: 0. 21 | l2_reg: 0. 22 | squared_l2_reg: 0. 23 | jacobian_frobenius_reg: 0. 24 | jacobian_diag_frobenius_reg: 0. 25 | jacobian_off_diag_frobenius_reg: 0. 26 | 27 | partial_solver: 28 | _target_: src.models.components.solver.FlowSolver 29 | _partial_: true 30 | ode_solver: "euler" 31 | atol: 1e-5 32 | rtol: 1e-5 33 | 34 | ot_sampler: "sinkhorn" 35 | 36 | sigma_min: 1.0 37 | 38 | # Set to integer if want to train with left out timepoint 39 | leaveout_timepoint: -1 40 | -------------------------------------------------------------------------------- /runner/configs/model/trajectorynet.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.models.cfm_module.CNFLitModule 2 | _partial_: true 3 | 4 | optimizer: 5 | _target_: torch.optim.AdamW 6 | _partial_: true 7 | lr: 0.001 8 | weight_decay: 0.01 9 | 10 | net: 11 | _target_: src.models.components.simple_mlp.VelocityNet 12 | _partial_: true 13 | hidden_dims: [64, 64, 64] 14 | batch_norm: False 15 | activation: "tanh" 16 | 17 | augmentations: 18 | _target_: src.models.components.augmentation.AugmentationModule 19 | cnf_estimator: "exact" 20 | l1_reg: 0. 21 | l2_reg: 0. 22 | squared_l2_reg: 1e-4 23 | jacobian_frobenius_reg: 1e-4 24 | jacobian_diag_frobenius_reg: 0. 25 | jacobian_off_diag_frobenius_reg: 0. 
26 | 27 | # Set to integer if want to train with left out timepoint 28 | leaveout_timepoint: -1 29 | -------------------------------------------------------------------------------- /runner/configs/paths/default.yaml: -------------------------------------------------------------------------------- 1 | # path to root directory 2 | # this requires PROJECT_ROOT environment variable to exist 3 | # PROJECT_ROOT is inferred and set by pyrootutils package in `train.py` and `eval.py` 4 | root_dir: ${oc.env:PROJECT_ROOT} 5 | 6 | # path to data directory 7 | data_dir: ${local.data_dir} 8 | 9 | # path to logging directory 10 | log_dir: ${local.log_dir} 11 | 12 | # path to output directory, created dynamically by hydra 13 | # path generation pattern is specified in `configs/hydra/default.yaml` 14 | # use it to store all files generated during the run, like ckpts and metrics 15 | output_dir: ${hydra:runtime.output_dir} 16 | 17 | # path to working directory 18 | work_dir: ${hydra:runtime.cwd} 19 | -------------------------------------------------------------------------------- /runner/configs/train.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # specify here default configuration 4 | # order of defaults determines the order in which configs override each other 5 | defaults: 6 | - _self_ 7 | - datamodule: sklearn 8 | - model: cfm 9 | - callbacks: default 10 | - logger: csv # set logger here or use command line (e.g. `python train.py logger=tensorboard`) 11 | - trainer: default 12 | - paths: default 13 | - extras: default 14 | - hydra: default 15 | - launcher: null 16 | 17 | # experiment configs allow for version control of specific hyperparameters 18 | # e.g. best hyperparameters for given model and datamodule 19 | - experiment: null 20 | 21 | # config for hyperparameter optimization 22 | - hparams_search: null 23 | 24 | # optional local config for machine/user specific settings 25 | # it's optional since it doesn't need to exist and is excluded from version control 26 | - optional local: default 27 | 28 | # debugging config (enable through command line, e.g. 
`python train.py debug=default`) 29 | - debug: null 30 | 31 | # task name, determines output directory path 32 | task_name: "train" 33 | 34 | # tags to help you identify your experiments 35 | # you can overwrite this in experiment configs 36 | # overwrite from command line with `python train.py tags="[first_tag, second_tag]"` 37 | # appending lists from command line is currently not supported :( 38 | # https://github.com/facebookresearch/hydra/issues/1547 39 | tags: ["dev"] 40 | 41 | # set False to skip model training 42 | train: True 43 | 44 | # evaluate on test set, using best model weights achieved during training 45 | # lightning chooses best weights based on the metric specified in checkpoint callback 46 | test: True 47 | 48 | # simply provide checkpoint path to resume training 49 | ckpt_path: null 50 | 51 | # seed for random number generators in pytorch, numpy and python.random 52 | seed: null 53 | -------------------------------------------------------------------------------- /runner/configs/trainer/cpu.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | accelerator: cpu 5 | devices: 1 6 | -------------------------------------------------------------------------------- /runner/configs/trainer/ddp.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | # use "ddp_spawn" instead of "ddp", 5 | # it's slower but normal "ddp" currently doesn't work ideally with hydra 6 | # https://github.com/facebookresearch/hydra/issues/2070 7 | # https://pytorch-lightning.readthedocs.io/en/latest/accelerators/gpu_intermediate.html#distributed-data-parallel-spawn 8 | strategy: ddp_spawn 9 | 10 | accelerator: gpu 11 | devices: 4 12 | num_nodes: 1 13 | sync_batchnorm: True 14 | -------------------------------------------------------------------------------- /runner/configs/trainer/ddp_sim.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | # simulate DDP on CPU, useful for debugging 5 | accelerator: cpu 6 | devices: 2 7 | strategy: ddp_spawn 8 | -------------------------------------------------------------------------------- /runner/configs/trainer/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.Trainer 2 | 3 | default_root_dir: ${paths.output_dir} 4 | 5 | min_epochs: 1 # prevents early stopping 6 | max_epochs: 10 7 | 8 | accelerator: cpu 9 | devices: 1 10 | 11 | # mixed precision for extra speed-up 12 | # precision: 16 13 | 14 | # perform a validation loop every N training epochs 15 | check_val_every_n_epoch: 1 16 | 17 | # set True to ensure deterministic results 18 | # makes training slower but gives more reproducibility than just setting seeds 19 | deterministic: False 20 | -------------------------------------------------------------------------------- /runner/configs/trainer/gpu.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | accelerator: gpu 5 | devices: 1 6 | -------------------------------------------------------------------------------- /runner/configs/trainer/mps.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | accelerator: mps 5 | devices: 1 6 | --------------------------------------------------------------------------------
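The trainer group above is plain Lightning Trainer configuration addressed through Hydra's _target_ mechanism, the same pattern src/eval.py uses via hydra.utils.instantiate(cfg.trainer, ...). A minimal standalone sketch of that mechanism (not a repo file; default_root_dir is replaced by a literal "./outputs" because ${paths.output_dir} only resolves inside a full Hydra run):

import hydra
from omegaconf import OmegaConf

# Build a config equivalent to trainer/default.yaml and instantiate it the same
# way the runner's entry points do.
trainer_cfg = OmegaConf.create(
    {
        "_target_": "pytorch_lightning.Trainer",
        "default_root_dir": "./outputs",
        "min_epochs": 1,
        "max_epochs": 10,
        "accelerator": "cpu",
        "devices": 1,
        "check_val_every_n_epoch": 1,
        "deterministic": False,
    }
)
trainer = hydra.utils.instantiate(trainer_cfg)
print(type(trainer))  # expected: <class 'pytorch_lightning.trainer.trainer.Trainer'>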
/runner/data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/atong01/conditional-flow-matching/3fd278f9ef2f02e17e107e5769130b6cb44803e2/runner/data/.gitkeep -------------------------------------------------------------------------------- /runner/logs/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/atong01/conditional-flow-matching/3fd278f9ef2f02e17e107e5769130b6cb44803e2/runner/logs/.gitkeep -------------------------------------------------------------------------------- /runner/scripts/schedule.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Schedule execution of many runs 3 | # Run from root folder with: bash scripts/schedule.sh 4 | 5 | python src/train.py trainer.max_epochs=5 logger=csv 6 | 7 | python src/train.py trainer.max_epochs=10 logger=csv 8 | -------------------------------------------------------------------------------- /runner/scripts/two-dim-cfm.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Compares flow matching (FM), conditional flow matching (CFM), and optimal 4 | # transport conditional flow matching (OT-CFM) on four datasets. twodim is not possible 5 | # for the flow matching algorithm as it has a non-Gaussian source distribution. 6 | # FM is therefore only run on three datasets. 7 | python src/train.py -m experiment=cfm \ 8 | model=cfm,otcfm \ 9 | launcher=mila_cpu_cluster \ 10 | model.sigma_min=0.1 \ 11 | datamodule=scurve,moons,twodim,gaussians \ 12 | seed=42,43,44,45,46 & 13 | 14 | # Sleep to avoid launching jobs at the same time 15 | sleep 1 16 | python src/train.py -m experiment=cfm \ 17 | model=fm \ 18 | launcher=mila_cpu_cluster \ 19 | model.sigma_min=0.1 \ 20 | datamodule=scurve,moons,gaussians \ 21 | seed=42,43,44,45,46 & 22 | -------------------------------------------------------------------------------- /runner/src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/atong01/conditional-flow-matching/3fd278f9ef2f02e17e107e5769130b6cb44803e2/runner/src/__init__.py -------------------------------------------------------------------------------- /runner/src/datamodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/atong01/conditional-flow-matching/3fd278f9ef2f02e17e107e5769130b6cb44803e2/runner/src/datamodules/__init__.py -------------------------------------------------------------------------------- /runner/src/datamodules/cifar10_datamodule.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List, Union 2 | 3 | import pl_bolts 4 | import torch 5 | from torch.utils.data import DataLoader 6 | from torchvision import transforms as transform_lib 7 | 8 | 9 | class CIFAR10DataModule(pl_bolts.datamodules.cifar10_datamodule.CIFAR10DataModule): 10 | def __init__(self, *args, **kwargs): 11 | test_transforms = transform_lib.ToTensor() 12 | super().__init__(*args, test_transforms=test_transforms, **kwargs) 13 | 14 | def val_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]: 15 | """The val dataloader.""" 16 | return self._data_loader(self.dataset_train) 17 | --------------------------------------------------------------------------------
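The CIFAR10DataModule above thinly wraps the pl_bolts datamodule: it forces a plain ToTensor() test transform and makes val_dataloader() iterate the training set, which matches val_split: 0.0 in datamodule/cifar.yaml (no held-out validation split). A hypothetical usage sketch mirroring the values in that config (assumes a pl_bolts version compatible with this pinned commit, that it is run from runner/ with src importable as pyrootutils arranges, and that CIFAR-10 may be downloaded into ./data on first use):

from src.datamodules.cifar10_datamodule import CIFAR10DataModule

# Keyword values copied from runner/configs/datamodule/cifar.yaml.
dm = CIFAR10DataModule(
    data_dir="./data",
    batch_size=128,
    val_split=0.0,
    num_workers=0,
    normalize=True,
    seed=42,
    shuffle=True,
    pin_memory=True,
    drop_last=False,
)
dm.prepare_data()  # downloads the dataset if it is not already present
dm.setup()
x, y = next(iter(dm.train_dataloader()))
print(x.shape)  # expected: torch.Size([128, 3, 32, 32])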
/runner/src/datamodules/components/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/atong01/conditional-flow-matching/3fd278f9ef2f02e17e107e5769130b6cb44803e2/runner/src/datamodules/components/__init__.py -------------------------------------------------------------------------------- /runner/src/datamodules/components/base.py: -------------------------------------------------------------------------------- 1 | from pytorch_lightning import LightningDataModule 2 | from torch.utils.data import DataLoader 3 | 4 | 5 | class BaseLightningDataModule(LightningDataModule): 6 | """Adds base train, val, test dataloaders from data_train, data_val, and data_test.""" 7 | 8 | def train_dataloader(self): 9 | return DataLoader( 10 | dataset=self.data_train, 11 | batch_size=self.hparams.batch_size, 12 | num_workers=self.hparams.num_workers, 13 | pin_memory=self.hparams.pin_memory, 14 | shuffle=True, 15 | ) 16 | 17 | def val_dataloader(self): 18 | return DataLoader( 19 | dataset=self.data_val, 20 | batch_size=self.hparams.batch_size, 21 | num_workers=self.hparams.num_workers, 22 | pin_memory=self.hparams.pin_memory, 23 | shuffle=False, 24 | ) 25 | 26 | def test_dataloader(self): 27 | return DataLoader( 28 | dataset=self.data_test, 29 | batch_size=self.hparams.batch_size, 30 | num_workers=self.hparams.num_workers, 31 | pin_memory=self.hparams.pin_memory, 32 | shuffle=False, 33 | ) 34 | -------------------------------------------------------------------------------- /runner/src/datamodules/components/time_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scanpy as sc 3 | 4 | 5 | def adata_dataset(path, embed_name="X_pca", label_name="day", max_dim=100): 6 | adata = sc.read_h5ad(path) 7 | labels = adata.obs[label_name].astype("category") 8 | ulabels = labels.cat.categories 9 | return adata.obsm[embed_name][:, :max_dim], labels, ulabels 10 | 11 | 12 | def tnet_dataset(path, embed_name="pcs", label_name="sample_labels", max_dim=100): 13 | a = np.load(path, allow_pickle=True) 14 | return a[embed_name][:, :max_dim], a[label_name], np.unique(a[label_name]) 15 | 16 | 17 | def load_dataset(path, max_dim=100): 18 | if path.endswith("h5ad"): 19 | return adata_dataset(path, max_dim=max_dim) 20 | if path.endswith("npz"): 21 | return tnet_dataset(path, max_dim=max_dim) 22 | raise NotImplementedError() 23 | -------------------------------------------------------------------------------- /runner/src/datamodules/components/two_dim.py: -------------------------------------------------------------------------------- 1 | # Adapted from DSB 2 | # https://github.com/JTT94/diffusion_schrodinger_bridge/blob/main/bridge/data/two_dim.py 3 | import numpy as np 4 | import torch 5 | from sklearn import datasets 6 | from torch.utils.data import TensorDataset 7 | 8 | # checker/pinwheel/8gaussians can be found at 9 | # https://github.com/rtqichen/ffjord/blob/994864ad0517db3549717c25170f9b71e96788b1/lib/toy_data.py#L8 10 | 11 | 12 | def data_distrib(npar, data, random_state=42): 13 | np.random.seed(random_state) 14 | 15 | if data == "mixture": 16 | init_sample = torch.randn(npar, 2) 17 | p = init_sample.shape[0] // 2 18 | init_sample[:p, 0] = init_sample[:p, 0] - 7.0 19 | init_sample[p:, 0] = init_sample[p:, 0] + 7.0 20 | 21 | if data == "scurve": 22 | X, y = datasets.make_s_curve(n_samples=npar, noise=0.1, random_state=None) 23 | init_sample = torch.tensor(X)[:, [0, 2]] 24 |
scaling_factor = 7 25 | init_sample = (init_sample - init_sample.mean()) / init_sample.std() * scaling_factor 26 | 27 | if data == "swiss": 28 | X, y = datasets.make_swiss_roll(n_samples=npar, noise=0.1, random_state=None) 29 | init_sample = torch.tensor(X)[:, [0, 2]] 30 | scaling_factor = 7 31 | init_sample = (init_sample - init_sample.mean()) / init_sample.std() * scaling_factor 32 | 33 | if data == "moon": 34 | X, y = datasets.make_moons(n_samples=npar, noise=0.1, random_state=None) 35 | scaling_factor = 7.0 36 | init_sample = torch.tensor(X) 37 | init_sample = (init_sample - init_sample.mean()) / init_sample.std() * scaling_factor 38 | 39 | if data == "circle": 40 | X, y = datasets.make_circles(n_samples=npar, noise=0.0, random_state=None, factor=0.5) 41 | init_sample = torch.tensor(X) * 10 42 | 43 | if data == "checker": 44 | x1 = np.random.rand(npar) * 4 - 2 45 | x2_ = np.random.rand(npar) - np.random.randint(0, 2, npar) * 2 46 | x2 = x2_ + (np.floor(x1) % 2) 47 | x = np.concatenate([x1[:, None], x2[:, None]], 1) * 7.5 48 | init_sample = torch.from_numpy(x) 49 | 50 | if data == "pinwheel": 51 | radial_std = 0.3 52 | tangential_std = 0.1 53 | num_classes = 5 54 | num_per_class = npar // 5 55 | rate = 0.25 56 | rads = np.linspace(0, 2 * np.pi, num_classes, endpoint=False) 57 | 58 | features = np.random.randn(num_classes * num_per_class, 2) * np.array( 59 | [radial_std, tangential_std] 60 | ) 61 | features[:, 0] += 1.0 62 | labels = np.repeat(np.arange(num_classes), num_per_class) 63 | 64 | angles = rads[labels] + rate * np.exp(features[:, 0]) 65 | rotations = np.stack([np.cos(angles), -np.sin(angles), np.sin(angles), np.cos(angles)]) 66 | rotations = np.reshape(rotations.T, (-1, 2, 2)) 67 | x = 7.5 * np.random.permutation(np.einsum("ti,tij->tj", features, rotations)) 68 | init_sample = torch.from_numpy(x) 69 | 70 | if data == "8gaussians": 71 | scale = 4.0 72 | centers = [ 73 | (1, 0), 74 | (-1, 0), 75 | (0, 1), 76 | (0, -1), 77 | (1.0 / np.sqrt(2), 1.0 / np.sqrt(2)), 78 | (1.0 / np.sqrt(2), -1.0 / np.sqrt(2)), 79 | (-1.0 / np.sqrt(2), 1.0 / np.sqrt(2)), 80 | (-1.0 / np.sqrt(2), -1.0 / np.sqrt(2)), 81 | ] 82 | centers = [(scale * x, scale * y) for x, y in centers] 83 | 84 | dataset = [] 85 | for i in range(npar): 86 | point = np.random.randn(2) * 0.5 87 | idx = np.random.randint(8) 88 | center = centers[idx] 89 | point[0] += center[0] 90 | point[1] += center[1] 91 | dataset.append(point) 92 | dataset = np.array(dataset, dtype="float32") 93 | dataset *= 3 94 | init_sample = torch.from_numpy(dataset) 95 | 96 | init_sample = init_sample.float() 97 | 98 | return init_sample 99 | 100 | 101 | def two_dim_ds(npar, data_tag): 102 | init_sample = data_distrib(npar, data_tag) 103 | init_ds = TensorDataset(init_sample) 104 | return init_ds 105 | -------------------------------------------------------------------------------- /runner/src/eval.py: -------------------------------------------------------------------------------- 1 | import pyrootutils 2 | 3 | root = pyrootutils.setup_root( 4 | search_from=__file__, 5 | indicator=[".git", "pyproject.toml"], 6 | pythonpath=True, 7 | dotenv=True, 8 | ) 9 | 10 | # ------------------------------------------------------------------------------------ # 11 | # `pyrootutils.setup_root(...)` above is optional line to make environment more convenient 12 | # should be placed at the top of each entry file 13 | # 14 | # main advantages: 15 | # - allows you to keep all entry files in "src/" without installing project as a package 16 | # - launching python file 
works no matter where is your current work dir 17 | # - automatically loads environment variables from ".env" if exists 18 | # 19 | # how it works: 20 | # - `setup_root()` above recursively searches for either ".git" or "pyproject.toml" in present 21 | # and parent dirs, to determine the project root dir 22 | # - adds root dir to the PYTHONPATH (if `pythonpath=True`), so this file can be run from 23 | # any place without installing project as a package 24 | # - sets PROJECT_ROOT environment variable which is used in "configs/paths/default.yaml" 25 | # to make all paths always relative to project root 26 | # - loads environment variables from ".env" in root dir (if `dotenv=True`) 27 | # 28 | # you can remove `pyrootutils.setup_root(...)` if you: 29 | # 1. either install project as a package or move each entry file to the project root dir 30 | # 2. remove PROJECT_ROOT variable from paths in "configs/paths/default.yaml" 31 | # 32 | # https://github.com/ashleve/pyrootutils 33 | # ------------------------------------------------------------------------------------ # 34 | 35 | from typing import List, Tuple 36 | 37 | import hydra 38 | from omegaconf import DictConfig 39 | from pytorch_lightning import LightningDataModule, LightningModule, Trainer 40 | from pytorch_lightning.loggers import LightningLoggerBase 41 | 42 | from src import utils 43 | 44 | log = utils.get_pylogger(__name__) 45 | 46 | 47 | @utils.task_wrapper 48 | def evaluate(cfg: DictConfig) -> Tuple[dict, dict]: 49 | """Evaluates given checkpoint on a datamodule testset. 50 | 51 | This method is wrapped in optional @task_wrapper decorator which applies extra utilities 52 | before and after the call. 53 | 54 | Args: 55 | cfg (DictConfig): Configuration composed by Hydra. 56 | 57 | Returns: 58 | Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects. 59 | """ 60 | 61 | assert cfg.ckpt_path 62 | 63 | log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>") 64 | datamodule: LightningDataModule = hydra.utils.instantiate(cfg.datamodule) 65 | 66 | log.info(f"Instantiating model <{cfg.model._target_}>") 67 | if hasattr(datamodule, "pass_to_model"): 68 | log.info("Passing full datamodule to model") 69 | model: LightningModule = hydra.utils.instantiate(cfg.model)(datamodule=datamodule) 70 | else: 71 | if hasattr(datamodule, "dim"): 72 | log.info("Passing datamodule.dim to model") 73 | model: LightningModule = hydra.utils.instantiate(cfg.model)(dim=datamodule.dim) 74 | else: 75 | model: LightningModule = hydra.utils.instantiate(cfg.model) 76 | 77 | log.info("Instantiating loggers...") 78 | logger: List[LightningLoggerBase] = utils.instantiate_loggers(cfg.get("logger")) 79 | 80 | log.info(f"Instantiating trainer <{cfg.trainer._target_}>") 81 | trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=logger) 82 | 83 | object_dict = { 84 | "cfg": cfg, 85 | "datamodule": datamodule, 86 | "model": model, 87 | "logger": logger, 88 | "trainer": trainer, 89 | } 90 | 91 | if logger: 92 | log.info("Logging hyperparameters!") 93 | utils.log_hyperparameters(object_dict) 94 | 95 | log.info("Starting testing!") 96 | trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path) 97 | 98 | # for predictions use trainer.predict(...) 
99 | # predictions = trainer.predict(model=model, dataloaders=dataloaders, ckpt_path=cfg.ckpt_path) 100 | 101 | metric_dict = trainer.callback_metrics 102 | 103 | return metric_dict, object_dict 104 | 105 | 106 | @hydra.main(version_base="1.2", config_path=root / "configs", config_name="eval.yaml") 107 | def main(cfg: DictConfig) -> None: 108 | evaluate(cfg) 109 | 110 | 111 | if __name__ == "__main__": 112 | main() 113 | -------------------------------------------------------------------------------- /runner/src/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/atong01/conditional-flow-matching/3fd278f9ef2f02e17e107e5769130b6cb44803e2/runner/src/models/__init__.py -------------------------------------------------------------------------------- /runner/src/models/components/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/atong01/conditional-flow-matching/3fd278f9ef2f02e17e107e5769130b6cb44803e2/runner/src/models/components/__init__.py -------------------------------------------------------------------------------- /runner/src/models/components/distribution_distances.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Union 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from .mmd import linear_mmd2, mix_rbf_mmd2, poly_mmd2 8 | from .optimal_transport import wasserstein 9 | 10 | 11 | def compute_distances(pred, true): 12 | """Computes distances between vectors.""" 13 | mse = torch.nn.functional.mse_loss(pred, true).item() 14 | me = math.sqrt(mse) 15 | mae = torch.mean(torch.abs(pred - true)).item() 16 | return mse, me, mae 17 | 18 | 19 | def compute_distribution_distances(pred: torch.Tensor, true: Union[torch.Tensor, list]): 20 | """computes distances between distributions. 21 | pred: [batch, times, dims] tensor 22 | true: [batch, times, dims] tensor or list[batch[i], dims] of length times 23 | 24 | This handles jagged times as a list of tensors. 
25 | """ 26 | NAMES = [ 27 | "1-Wasserstein", 28 | "2-Wasserstein", 29 | "Linear_MMD", 30 | "Poly_MMD", 31 | "RBF_MMD", 32 | "Mean_MSE", 33 | "Mean_L2", 34 | "Mean_L1", 35 | "Median_MSE", 36 | "Median_L2", 37 | "Median_L1", 38 | ] 39 | is_jagged = isinstance(true, list) 40 | pred_is_jagged = isinstance(pred, list) 41 | dists = [] 42 | to_return = [] 43 | names = [] 44 | filtered_names = [name for name in NAMES if not is_jagged or not name.endswith("MMD")] 45 | ts = len(pred) if pred_is_jagged else pred.shape[1] 46 | for t in np.arange(ts): 47 | if pred_is_jagged: 48 | a = pred[t] 49 | else: 50 | a = pred[:, t, :] 51 | if is_jagged: 52 | b = true[t] 53 | else: 54 | b = true[:, t, :] 55 | w1 = wasserstein(a, b, power=1) 56 | w2 = wasserstein(a, b, power=2) 57 | if not pred_is_jagged and not is_jagged: 58 | mmd_linear = linear_mmd2(a, b).item() 59 | mmd_poly = poly_mmd2(a, b, d=2, alpha=1.0, c=2.0).item() 60 | mmd_rbf = mix_rbf_mmd2(a, b, sigma_list=[0.01, 0.1, 1, 10, 100]).item() 61 | mean_dists = compute_distances(torch.mean(a, dim=0), torch.mean(b, dim=0)) 62 | median_dists = compute_distances(torch.median(a, dim=0)[0], torch.median(b, dim=0)[0]) 63 | if pred_is_jagged or is_jagged: 64 | dists.append((w1, w2, *mean_dists, *median_dists)) 65 | else: 66 | dists.append((w1, w2, mmd_linear, mmd_poly, mmd_rbf, *mean_dists, *median_dists)) 67 | # For multipoint datasets add timepoint specific distances 68 | if ts > 1: 69 | names.extend([f"t{t+1}/{name}" for name in filtered_names]) 70 | to_return.extend(dists[-1]) 71 | 72 | to_return.extend(np.array(dists).mean(axis=0)) 73 | names.extend(filtered_names) 74 | return names, to_return 75 | -------------------------------------------------------------------------------- /runner/src/models/components/emd.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import ot as pot # Python Optimal Transport package 3 | import scipy.sparse 4 | from sklearn.metrics.pairwise import pairwise_distances 5 | 6 | 7 | def earth_mover_distance( 8 | p, 9 | q, 10 | eigenvals=None, 11 | weights1=None, 12 | weights2=None, 13 | return_matrix=False, 14 | metric="sqeuclidean", 15 | ): 16 | """Returns the earth mover's distance between two point clouds. 
17 | 18 | Parameters 19 | ---------- 20 | cloud1 : 2-D array 21 | First point cloud 22 | cloud2 : 2-D array 23 | Second point cloud 24 | Returns 25 | ------- 26 | distance : float 27 | The distance between the two point clouds 28 | """ 29 | p = p.toarray() if scipy.sparse.isspmatrix(p) else p 30 | q = q.toarray() if scipy.sparse.isspmatrix(q) else q 31 | if eigenvals is not None: 32 | p = p.dot(eigenvals) 33 | q = q.dot(eigenvals) 34 | if weights1 is None: 35 | p_weights = np.ones(len(p)) / len(p) 36 | else: 37 | weights1 = weights1.astype("float64") 38 | p_weights = weights1 / weights1.sum() 39 | 40 | if weights2 is None: 41 | q_weights = np.ones(len(q)) / len(q) 42 | else: 43 | weights2 = weights2.astype("float64") 44 | q_weights = weights2 / weights2.sum() 45 | 46 | pairwise_dist = np.ascontiguousarray(pairwise_distances(p, Y=q, metric=metric, n_jobs=-1)) 47 | 48 | result = pot.emd2( 49 | p_weights, q_weights, pairwise_dist, numItermax=1e7, return_matrix=return_matrix 50 | ) 51 | if return_matrix: 52 | square_emd, log_dict = result 53 | return np.sqrt(square_emd), log_dict 54 | else: 55 | return np.sqrt(result) 56 | 57 | 58 | def interpolate_with_ot(p0, p1, tmap, interp_frac, size): 59 | """Interpolate between p0 and p1 at fraction t_interpolate knowing a transport map from p0 to 60 | p1. 61 | 62 | Parameters 63 | ---------- 64 | p0 : 2-D array 65 | The genes of each cell in the source population 66 | p1 : 2-D array 67 | The genes of each cell in the destination population 68 | tmap : 2-D array 69 | A transport map from p0 to p1 70 | t_interpolate : float 71 | The fraction at which to interpolate 72 | size : int 73 | The number of cells in the interpolated population 74 | Returns 75 | ------- 76 | p05 : 2-D array 77 | An interpolated population of 'size' cells 78 | """ 79 | p0 = p0.toarray() if scipy.sparse.isspmatrix(p0) else p0 80 | p1 = p1.toarray() if scipy.sparse.isspmatrix(p1) else p1 81 | p0 = np.asarray(p0, dtype=np.float64) 82 | p1 = np.asarray(p1, dtype=np.float64) 83 | tmap = np.asarray(tmap, dtype=np.float64) 84 | if p0.shape[1] != p1.shape[1]: 85 | raise ValueError("Unable to interpolate. Number of genes do not match") 86 | if p0.shape[0] != tmap.shape[0] or p1.shape[0] != tmap.shape[1]: 87 | raise ValueError( 88 | "Unable to interpolate. Tmap size is {}, expected {}".format( 89 | tmap.shape, (len(p0), len(p1)) 90 | ) 91 | ) 92 | I = len(p0) 93 | J = len(p1) 94 | # Assume growth is exponential and retrieve growth rate at t_interpolate 95 | # If all sums are the same then this does not change anything 96 | # This only matters if sum is not the same for all rows 97 | p = tmap / np.power(tmap.sum(axis=0), 1.0 - interp_frac) 98 | p = p.flatten(order="C") 99 | p = p / p.sum() 100 | choices = np.random.choice(I * J, p=p, size=size) 101 | return np.asarray( 102 | [p0[i // J] * (1 - interp_frac) + p1[i % J] * interp_frac for i in choices], 103 | dtype=np.float64, 104 | ) 105 | 106 | 107 | def interpolate_per_point_with_ot(p0, p1, tmap, interp_frac): 108 | """Interpolate between p0 and p1 at fraction t_interpolate knowing a transport map from p0 to 109 | p1. 
110 | 111 | Parameters 112 | ---------- 113 | p0 : 2-D array 114 | The genes of each cell in the source population 115 | p1 : 2-D array 116 | The genes of each cell in the destination population 117 | tmap : 2-D array 118 | A transport map from p0 to p1 119 | t_interpolate : float 120 | The fraction at which to interpolate 121 | Returns 122 | ------- 123 | p05 : 2-D array 124 | An interpolated population of 'size' cells 125 | """ 126 | assert len(p0) == len(p1) 127 | p0 = p0.toarray() if scipy.sparse.isspmatrix(p0) else p0 128 | p1 = p1.toarray() if scipy.sparse.isspmatrix(p1) else p1 129 | p0 = np.asarray(p0, dtype=np.float64) 130 | p1 = np.asarray(p1, dtype=np.float64) 131 | tmap = np.asarray(tmap, dtype=np.float64) 132 | if p0.shape[1] != p1.shape[1]: 133 | raise ValueError("Unable to interpolate. Number of genes do not match") 134 | if p0.shape[0] != tmap.shape[0] or p1.shape[0] != tmap.shape[1]: 135 | raise ValueError( 136 | "Unable to interpolate. Tmap size is {}, expected {}".format( 137 | tmap.shape, (len(p0), len(p1)) 138 | ) 139 | ) 140 | 141 | I = len(p0) 142 | J = len(p1) 143 | # Assume growth is exponential and retrieve growth rate at t_interpolate 144 | # If all sums are the same then this does not change anything 145 | # This only matters if sum is not the same for all rows 146 | p = tmap / (tmap.sum(axis=0) / 1.0 - interp_frac) 147 | # p = tmap / np.power(tmap.sum(axis=0), 1.0 - interp_frac) 148 | # p = p.flatten(order="C") 149 | p = p / p.sum(axis=0) 150 | choices = np.array([np.random.choice(I, p=p[i]) for i in range(I)]) 151 | return np.asarray( 152 | [p0[i] * (1 - interp_frac) + p1[j] * interp_frac for i, j in enumerate(choices)], 153 | dtype=np.float64, 154 | ) 155 | -------------------------------------------------------------------------------- /runner/src/models/components/icnn_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class ICNN(torch.nn.Module): 6 | """Input Convex Neural Network.""" 7 | 8 | def __init__(self, dim=2, dimh=64, num_hidden_layers=4): 9 | super().__init__() 10 | 11 | Wzs = [] 12 | Wzs.append(nn.Linear(dim, dimh)) 13 | for _ in range(num_hidden_layers - 1): 14 | Wzs.append(torch.nn.Linear(dimh, dimh, bias=False)) 15 | Wzs.append(torch.nn.Linear(dimh, 1, bias=False)) 16 | self.Wzs = torch.nn.ModuleList(Wzs) 17 | 18 | Wxs = [] 19 | for _ in range(num_hidden_layers - 1): 20 | Wxs.append(nn.Linear(dim, dimh)) 21 | Wxs.append(nn.Linear(dim, 1, bias=False)) 22 | self.Wxs = torch.nn.ModuleList(Wxs) 23 | self.act = nn.Softplus() 24 | 25 | def forward(self, x): 26 | z = self.act(self.Wzs[0](x)) 27 | for Wz, Wx in zip(self.Wzs[1:-1], self.Wxs[:-1]): 28 | z = self.act(Wz(z) + Wx(x)) 29 | return self.Wzs[-1](z) + self.Wxs[-1](x) 30 | -------------------------------------------------------------------------------- /runner/src/models/components/layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/atong01/conditional-flow-matching/3fd278f9ef2f02e17e107e5769130b6cb44803e2/runner/src/models/components/layers/__init__.py -------------------------------------------------------------------------------- /runner/src/models/components/layers/diffeq_layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .basic import * 2 | from .container import * 3 | from .resnet import * 4 | from .wrappers import * 5 | 
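Stepping back briefly to icnn_model.py above: in ICNN-based optimal transport (the setting model/icnn.yaml and icnn_module.py appear to target), the scalar network output is treated as a convex potential and the induced map is taken to be its input gradient. A small, hypothetical sketch of querying that gradient map with this ICNN class (not code from the repo; the untrained network is used only to show shapes):

import torch
from src.models.components.icnn_model import ICNN

f = ICNN(dim=2, dimh=64, num_hidden_layers=4)  # matches the defaults in model/icnn.yaml
x = torch.randn(256, 2, requires_grad=True)
potential = f(x)  # shape [256, 1]: one scalar potential value per input point
(grad_map,) = torch.autograd.grad(potential.sum(), x, create_graph=True)
print(grad_map.shape)  # torch.Size([256, 2]); candidate transport map x -> grad f(x)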
-------------------------------------------------------------------------------- /runner/src/models/components/layers/diffeq_layers/container.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .wrappers import diffeq_wrapper 5 | 6 | 7 | class SequentialDiffEq(nn.Module): 8 | """A container for a sequential chain of layers. 9 | 10 | Supports both regular and diffeq layers. 11 | """ 12 | 13 | def __init__(self, *layers): 14 | super().__init__() 15 | self.layers = nn.ModuleList([diffeq_wrapper(layer) for layer in layers]) 16 | 17 | def forward(self, t, x): 18 | for layer in self.layers: 19 | x = layer(t, x) 20 | return x 21 | 22 | 23 | class MixtureODELayer(nn.Module): 24 | """Produces a mixture of experts where output = sigma(t) * f(t, x). 25 | 26 | Time-dependent weights sigma(t) help learn to blend the experts without resorting to a highly 27 | stiff f. Supports both regular and diffeq experts. 28 | """ 29 | 30 | def __init__(self, experts): 31 | super().__init__() 32 | assert len(experts) > 1 33 | wrapped_experts = [diffeq_wrapper(ex) for ex in experts] 34 | self.experts = nn.ModuleList(wrapped_experts) 35 | self.mixture_weights = nn.Linear(1, len(self.experts)) 36 | 37 | def forward(self, t, y): 38 | dys = [] 39 | for f in self.experts: 40 | dys.append(f(t, y)) 41 | dys = torch.stack(dys, 0) 42 | weights = self.mixture_weights(t).view(-1, *([1] * (dys.ndimension() - 1))) 43 | 44 | dy = torch.sum(dys * weights, dim=0, keepdim=False) 45 | return dy 46 | -------------------------------------------------------------------------------- /runner/src/models/components/layers/diffeq_layers/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from . 
import basic, container 4 | 5 | NGROUPS = 16 6 | 7 | 8 | class ResNet(container.SequentialDiffEq): 9 | def __init__(self, dim, intermediate_dim, n_resblocks, conv_block=None): 10 | super().__init__() 11 | 12 | if conv_block is None: 13 | conv_block = basic.ConcatCoordConv2d 14 | 15 | self.dim = dim 16 | self.intermediate_dim = intermediate_dim 17 | self.n_resblocks = n_resblocks 18 | 19 | layers = [] 20 | layers.append(conv_block(dim, intermediate_dim, ksize=3, stride=1, padding=1, bias=False)) 21 | for _ in range(n_resblocks): 22 | layers.append(BasicBlock(intermediate_dim, conv_block)) 23 | layers.append(nn.GroupNorm(NGROUPS, intermediate_dim, eps=1e-4)) 24 | layers.append(nn.ReLU(inplace=True)) 25 | layers.append(conv_block(intermediate_dim, dim, ksize=1, bias=False)) 26 | 27 | super().__init__(*layers) 28 | 29 | def __repr__(self): 30 | return ( 31 | "{name}({dim}, intermediate_dim={intermediate_dim}, n_resblocks={n_resblocks})".format( 32 | name=self.__class__.__name__, **self.__dict__ 33 | ) 34 | ) 35 | 36 | 37 | class BasicBlock(nn.Module): 38 | expansion = 1 39 | 40 | def __init__(self, dim, conv_block=None): 41 | super().__init__() 42 | 43 | if conv_block is None: 44 | conv_block = basic.ConcatCoordConv2d 45 | 46 | self.norm1 = nn.GroupNorm(NGROUPS, dim, eps=1e-4) 47 | self.relu1 = nn.ReLU(inplace=True) 48 | self.conv1 = conv_block(dim, dim, ksize=3, stride=1, padding=1, bias=False) 49 | self.norm2 = nn.GroupNorm(NGROUPS, dim, eps=1e-4) 50 | self.relu2 = nn.ReLU(inplace=True) 51 | self.conv2 = conv_block(dim, dim, ksize=3, stride=1, padding=1, bias=False) 52 | 53 | def forward(self, t, x): 54 | residual = x 55 | 56 | out = self.norm1(x) 57 | out = self.relu1(out) 58 | out = self.conv1(t, out) 59 | 60 | out = self.norm2(out) 61 | out = self.relu2(out) 62 | out = self.conv2(t, out) 63 | 64 | out += residual 65 | 66 | return out 67 | -------------------------------------------------------------------------------- /runner/src/models/components/layers/diffeq_layers/wrappers.py: -------------------------------------------------------------------------------- 1 | from inspect import signature 2 | 3 | import torch.nn as nn 4 | 5 | __all__ = ["diffeq_wrapper", "reshape_wrapper"] 6 | 7 | 8 | class DiffEqWrapper(nn.Module): 9 | def __init__(self, module): 10 | super().__init__() 11 | self.module = module 12 | if len(signature(self.module.forward).parameters) == 1: 13 | self.diffeq = lambda t, y: self.module(y) 14 | elif len(signature(self.module.forward).parameters) == 2: 15 | self.diffeq = self.module 16 | else: 17 | raise ValueError("Differential equation needs to either take (t, y) or (y,) as input.") 18 | 19 | def forward(self, t, y): 20 | return self.diffeq(t, y) 21 | 22 | def __repr__(self): 23 | return self.diffeq.__repr__() 24 | 25 | 26 | def diffeq_wrapper(layer): 27 | return DiffEqWrapper(layer) 28 | 29 | 30 | class ReshapeDiffEq(nn.Module): 31 | def __init__(self, input_shape, net): 32 | super().__init__() 33 | assert ( 34 | len(signature(net.forward).parameters) == 2 35 | ), "use diffeq_wrapper before reshape_wrapper." 
36 | self.input_shape = input_shape 37 | self.net = net 38 | 39 | def forward(self, t, x): 40 | batchsize = x.shape[0] 41 | x = x.view(batchsize, *self.input_shape) 42 | return self.net(t, x).view(batchsize, -1) 43 | 44 | def __repr__(self): 45 | return self.diffeq.__repr__() 46 | 47 | 48 | def reshape_wrapper(input_shape, layer): 49 | return ReshapeDiffEq(input_shape, layer) 50 | -------------------------------------------------------------------------------- /runner/src/models/components/layers/squeeze.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | __all__ = ["SqueezeLayer"] 4 | 5 | 6 | class SqueezeLayer(nn.Module): 7 | def __init__(self, downscale_factor): 8 | super().__init__() 9 | self.downscale_factor = downscale_factor 10 | 11 | def forward(self, x, logpx=None, reverse=False): 12 | if reverse: 13 | return self._upsample(x, logpx) 14 | else: 15 | return self._downsample(x, logpx) 16 | 17 | def _downsample(self, x, logpx=None): 18 | squeeze_x = squeeze(x, self.downscale_factor) 19 | if logpx is None: 20 | return squeeze_x 21 | else: 22 | return squeeze_x, logpx 23 | 24 | def _upsample(self, y, logpy=None): 25 | unsqueeze_y = unsqueeze(y, self.downscale_factor) 26 | if logpy is None: 27 | return unsqueeze_y 28 | else: 29 | return unsqueeze_y, logpy 30 | 31 | 32 | def unsqueeze(input, upscale_factor=2): 33 | """[:, C*r^2, H, W] -> [:, C, H*r, W*r]""" 34 | batch_size, in_channels, in_height, in_width = input.size() 35 | out_channels = in_channels // (upscale_factor**2) 36 | 37 | out_height = in_height * upscale_factor 38 | out_width = in_width * upscale_factor 39 | 40 | input_view = input.contiguous().view( 41 | batch_size, out_channels, upscale_factor, upscale_factor, in_height, in_width 42 | ) 43 | 44 | output = input_view.permute(0, 1, 4, 2, 5, 3).contiguous() 45 | return output.view(batch_size, out_channels, out_height, out_width) 46 | 47 | 48 | def squeeze(input, downscale_factor=2): 49 | """[:, C, H*r, W*r] -> [:, C*r^2, H, W]""" 50 | batch_size, in_channels, in_height, in_width = input.size() 51 | out_channels = in_channels * (downscale_factor**2) 52 | 53 | out_height = in_height // downscale_factor 54 | out_width = in_width // downscale_factor 55 | 56 | input_view = input.contiguous().view( 57 | batch_size, 58 | in_channels, 59 | out_height, 60 | downscale_factor, 61 | out_width, 62 | downscale_factor, 63 | ) 64 | 65 | output = input_view.permute(0, 1, 3, 5, 2, 4).contiguous() 66 | return output.view(batch_size, out_channels, out_height, out_width) 67 | -------------------------------------------------------------------------------- /runner/src/models/components/mmd.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | 4 | import torch 5 | 6 | min_var_est = 1e-8 7 | 8 | 9 | # Consider linear time MMD with a linear kernel: 10 | # K(f(x), f(y)) = f(x)^Tf(y) 11 | # h(z_i, z_j) = k(x_i, x_j) + k(y_i, y_j) - k(x_i, y_j) - k(x_j, y_i) 12 | # = [f(x_i) - f(y_i)]^T[f(x_j) - f(y_j)] 13 | # 14 | # f_of_X: batch_size * k 15 | # f_of_Y: batch_size * k 16 | def linear_mmd2(f_of_X, f_of_Y): 17 | loss = 0.0 18 | delta = f_of_X - f_of_Y 19 | loss = torch.mean((delta[:-1] * delta[1:]).sum(1)) 20 | return loss 21 | 22 | 23 | # Consider linear time MMD with a polynomial kernel: 24 | # K(f(x), f(y)) = (alpha*f(x)^Tf(y) + c)^d 25 | # f_of_X: batch_size * k 26 | # f_of_Y: batch_size * k 27 | def poly_mmd2(f_of_X, f_of_Y, d=2, alpha=1.0, c=2.0): 28 | K_XX = alpha * 
(f_of_X[:-1] * f_of_X[1:]).sum(1) + c 29 | K_XX_mean = torch.mean(K_XX.pow(d)) 30 | 31 | K_YY = alpha * (f_of_Y[:-1] * f_of_Y[1:]).sum(1) + c 32 | K_YY_mean = torch.mean(K_YY.pow(d)) 33 | 34 | K_XY = alpha * (f_of_X[:-1] * f_of_Y[1:]).sum(1) + c 35 | K_XY_mean = torch.mean(K_XY.pow(d)) 36 | 37 | K_YX = alpha * (f_of_Y[:-1] * f_of_X[1:]).sum(1) + c 38 | K_YX_mean = torch.mean(K_YX.pow(d)) 39 | 40 | return K_XX_mean + K_YY_mean - K_XY_mean - K_YX_mean 41 | 42 | 43 | def _mix_rbf_kernel(X, Y, sigma_list): 44 | assert X.size(0) == Y.size(0) 45 | m = X.size(0) 46 | 47 | Z = torch.cat((X, Y), 0) 48 | ZZT = torch.mm(Z, Z.t()) 49 | diag_ZZT = torch.diag(ZZT).unsqueeze(1) 50 | Z_norm_sqr = diag_ZZT.expand_as(ZZT) 51 | exponent = Z_norm_sqr - 2 * ZZT + Z_norm_sqr.t() 52 | 53 | K = 0.0 54 | for sigma in sigma_list: 55 | gamma = 1.0 / (2 * sigma**2) 56 | K += torch.exp(-gamma * exponent) 57 | 58 | return K[:m, :m], K[:m, m:], K[m:, m:], len(sigma_list) 59 | 60 | 61 | def mix_rbf_mmd2(X, Y, sigma_list, biased=True): 62 | K_XX, K_XY, K_YY, d = _mix_rbf_kernel(X, Y, sigma_list) 63 | # return _mmd2(K_XX, K_XY, K_YY, const_diagonal=d, biased=biased) 64 | return _mmd2(K_XX, K_XY, K_YY, const_diagonal=False, biased=biased) 65 | 66 | 67 | def mix_rbf_mmd2_and_ratio(X, Y, sigma_list, biased=True): 68 | K_XX, K_XY, K_YY, d = _mix_rbf_kernel(X, Y, sigma_list) 69 | # return _mmd2_and_ratio(K_XX, K_XY, K_YY, const_diagonal=d, biased=biased) 70 | return _mmd2_and_ratio(K_XX, K_XY, K_YY, const_diagonal=False, biased=biased) 71 | 72 | 73 | ################################################################################ 74 | # Helper functions to compute variances based on kernel matrices 75 | ################################################################################ 76 | 77 | 78 | def _mmd2(K_XX, K_XY, K_YY, const_diagonal=False, biased=False): 79 | m = K_XX.size(0) # assume X, Y are same shape 80 | 81 | # Get the various sums of kernels that we'll use 82 | # Kts drop the diagonal, but we don't need to compute them explicitly 83 | if const_diagonal is not False: 84 | diag_X = diag_Y = const_diagonal 85 | sum_diag_X = sum_diag_Y = m * const_diagonal 86 | else: 87 | diag_X = torch.diag(K_XX) # (m,) 88 | diag_Y = torch.diag(K_YY) # (m,) 89 | sum_diag_X = torch.sum(diag_X) 90 | sum_diag_Y = torch.sum(diag_Y) 91 | 92 | Kt_XX_sums = K_XX.sum(dim=1) - diag_X # \tilde{K}_XX * e = K_XX * e - diag_X 93 | Kt_YY_sums = K_YY.sum(dim=1) - diag_Y # \tilde{K}_YY * e = K_YY * e - diag_Y 94 | K_XY_sums_0 = K_XY.sum(dim=0) # K_{XY}^T * e 95 | 96 | Kt_XX_sum = Kt_XX_sums.sum() # e^T * \tilde{K}_XX * e 97 | Kt_YY_sum = Kt_YY_sums.sum() # e^T * \tilde{K}_YY * e 98 | K_XY_sum = K_XY_sums_0.sum() # e^T * K_{XY} * e 99 | 100 | if biased: 101 | mmd2 = ( 102 | (Kt_XX_sum + sum_diag_X) / (m * m) 103 | + (Kt_YY_sum + sum_diag_Y) / (m * m) 104 | - 2.0 * K_XY_sum / (m * m) 105 | ) 106 | else: 107 | mmd2 = Kt_XX_sum / (m * (m - 1)) + Kt_YY_sum / (m * (m - 1)) - 2.0 * K_XY_sum / (m * m) 108 | 109 | return mmd2 110 | 111 | 112 | def _mmd2_and_ratio(K_XX, K_XY, K_YY, const_diagonal=False, biased=False): 113 | mmd2, var_est = _mmd2_and_variance( 114 | K_XX, K_XY, K_YY, const_diagonal=const_diagonal, biased=biased 115 | ) 116 | loss = mmd2 / torch.sqrt(torch.clamp(var_est, min=min_var_est)) 117 | return loss, mmd2, var_est 118 | 119 | 120 | def _mmd2_and_variance(K_XX, K_XY, K_YY, const_diagonal=False, biased=False): 121 | m = K_XX.size(0) # assume X, Y are same shape 122 | 123 | # Get the various sums of kernels that we'll use 124 | # Kts drop 
the diagonal, but we don't need to compute them explicitly 125 | if const_diagonal is not False: 126 | diag_X = diag_Y = const_diagonal 127 | sum_diag_X = sum_diag_Y = m * const_diagonal 128 | sum_diag2_X = sum_diag2_Y = m * const_diagonal**2 129 | else: 130 | diag_X = torch.diag(K_XX) # (m,) 131 | diag_Y = torch.diag(K_YY) # (m,) 132 | sum_diag_X = torch.sum(diag_X) 133 | sum_diag_Y = torch.sum(diag_Y) 134 | sum_diag2_X = diag_X.dot(diag_X) 135 | sum_diag2_Y = diag_Y.dot(diag_Y) 136 | 137 | Kt_XX_sums = K_XX.sum(dim=1) - diag_X # \tilde{K}_XX * e = K_XX * e - diag_X 138 | Kt_YY_sums = K_YY.sum(dim=1) - diag_Y # \tilde{K}_YY * e = K_YY * e - diag_Y 139 | K_XY_sums_0 = K_XY.sum(dim=0) # K_{XY}^T * e 140 | K_XY_sums_1 = K_XY.sum(dim=1) # K_{XY} * e 141 | 142 | Kt_XX_sum = Kt_XX_sums.sum() # e^T * \tilde{K}_XX * e 143 | Kt_YY_sum = Kt_YY_sums.sum() # e^T * \tilde{K}_YY * e 144 | K_XY_sum = K_XY_sums_0.sum() # e^T * K_{XY} * e 145 | 146 | Kt_XX_2_sum = (K_XX**2).sum() - sum_diag2_X # \| \tilde{K}_XX \|_F^2 147 | Kt_YY_2_sum = (K_YY**2).sum() - sum_diag2_Y # \| \tilde{K}_YY \|_F^2 148 | K_XY_2_sum = (K_XY**2).sum() # \| K_{XY} \|_F^2 149 | 150 | if biased: 151 | mmd2 = ( 152 | (Kt_XX_sum + sum_diag_X) / (m * m) 153 | + (Kt_YY_sum + sum_diag_Y) / (m * m) 154 | - 2.0 * K_XY_sum / (m * m) 155 | ) 156 | else: 157 | mmd2 = Kt_XX_sum / (m * (m - 1)) + Kt_YY_sum / (m * (m - 1)) - 2.0 * K_XY_sum / (m * m) 158 | 159 | var_est = ( 160 | 2.0 161 | / (m**2 * (m - 1.0) ** 2) 162 | * ( 163 | 2 * Kt_XX_sums.dot(Kt_XX_sums) 164 | - Kt_XX_2_sum 165 | + 2 * Kt_YY_sums.dot(Kt_YY_sums) 166 | - Kt_YY_2_sum 167 | ) 168 | - (4.0 * m - 6.0) / (m**3 * (m - 1.0) ** 3) * (Kt_XX_sum**2 + Kt_YY_sum**2) 169 | + 4.0 170 | * (m - 2.0) 171 | / (m**3 * (m - 1.0) ** 2) 172 | * (K_XY_sums_1.dot(K_XY_sums_1) + K_XY_sums_0.dot(K_XY_sums_0)) 173 | - 4.0 * (m - 3.0) / (m**3 * (m - 1.0) ** 2) * (K_XY_2_sum) 174 | - (8 * m - 12) / (m**5 * (m - 1)) * K_XY_sum**2 175 | + 8.0 176 | / (m**3 * (m - 1.0)) 177 | * ( 178 | 1.0 / m * (Kt_XX_sum + Kt_YY_sum) * K_XY_sum 179 | - Kt_XX_sums.dot(K_XY_sums_1) 180 | - Kt_YY_sums.dot(K_XY_sums_0) 181 | ) 182 | ) 183 | return mmd2, var_est 184 | -------------------------------------------------------------------------------- /runner/src/models/components/nn.py: -------------------------------------------------------------------------------- 1 | """Various utilities for neural networks.""" 2 | 3 | import math 4 | 5 | import torch as th 6 | import torch.nn as nn 7 | 8 | 9 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 
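# SiLU(x) = x * sigmoid(x), a.k.a. "swish"; on PyTorch >= 1.7 this class is interchangeable with torch.nn.SiLU.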
10 | class SiLU(nn.Module): 11 | def forward(self, x): 12 | return x * th.sigmoid(x) 13 | 14 | 15 | class GroupNorm32(nn.GroupNorm): 16 | def forward(self, x): 17 | return super().forward(x.float()).type(x.dtype) 18 | 19 | 20 | def conv_nd(dims, *args, **kwargs): 21 | """Create a 1D, 2D, or 3D convolution module.""" 22 | if dims == 1: 23 | return nn.Conv1d(*args, **kwargs) 24 | elif dims == 2: 25 | return nn.Conv2d(*args, **kwargs) 26 | elif dims == 3: 27 | return nn.Conv3d(*args, **kwargs) 28 | raise ValueError(f"unsupported dimensions: {dims}") 29 | 30 | 31 | def linear(*args, **kwargs): 32 | """Create a linear module.""" 33 | return nn.Linear(*args, **kwargs) 34 | 35 | 36 | def avg_pool_nd(dims, *args, **kwargs): 37 | """Create a 1D, 2D, or 3D average pooling module.""" 38 | if dims == 1: 39 | return nn.AvgPool1d(*args, **kwargs) 40 | elif dims == 2: 41 | return nn.AvgPool2d(*args, **kwargs) 42 | elif dims == 3: 43 | return nn.AvgPool3d(*args, **kwargs) 44 | raise ValueError(f"unsupported dimensions: {dims}") 45 | 46 | 47 | def update_ema(target_params, source_params, rate=0.99): 48 | """Update target parameters to be closer to those of source parameters using an exponential 49 | moving average. 50 | 51 | :param target_params: the target parameter sequence. 52 | :param source_params: the source parameter sequence. 53 | :param rate: the EMA rate (closer to 1 means slower). 54 | """ 55 | for targ, src in zip(target_params, source_params): 56 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 57 | 58 | 59 | def zero_module(module): 60 | """Zero out the parameters of a module and return it.""" 61 | for p in module.parameters(): 62 | p.detach().zero_() 63 | return module 64 | 65 | 66 | def scale_module(module, scale): 67 | """Scale the parameters of a module and return it.""" 68 | for p in module.parameters(): 69 | p.detach().mul_(scale) 70 | return module 71 | 72 | 73 | def mean_flat(tensor): 74 | """Take the mean over all non-batch dimensions.""" 75 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 76 | 77 | 78 | def normalization(channels): 79 | """Make a standard normalization layer. 80 | 81 | :param channels: number of input channels. 82 | :return: an nn.Module for normalization. 83 | """ 84 | return GroupNorm32(32, channels) 85 | 86 | 87 | def timestep_embedding(timesteps, dim, max_period=10000): 88 | """Create sinusoidal timestep embeddings. 89 | 90 | :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional. 91 | :param dim: the dimension of the output. 92 | :param max_period: controls the minimum frequency of the embeddings. 93 | :return: an [N x dim] Tensor of positional embeddings. 94 | """ 95 | half = dim // 2 96 | freqs = th.exp( 97 | -math.log(max_period) 98 | * th.arange(start=0, end=half, dtype=th.float32, device=timesteps.device) 99 | / half 100 | ) 101 | args = timesteps[:, None].float() * freqs[None] 102 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 103 | if dim % 2: 104 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 105 | return embedding 106 | 107 | 108 | def checkpoint(func, inputs, params, flag): 109 | """Evaluate a function without caching intermediate activations, allowing for reduced memory at 110 | the expense of extra compute in the backward pass. 111 | 112 | :param func: the function to evaluate. 113 | :param inputs: the argument sequence to pass to `func`. 114 | :param params: a sequence of parameters `func` depends on but does not 115 | explicitly take as arguments. 
116 | :param flag: if False, disable gradient checkpointing. 117 | """ 118 | if flag: 119 | args = tuple(inputs) + tuple(params) 120 | return CheckpointFunction.apply(func, len(inputs), *args) 121 | else: 122 | return func(*inputs) 123 | 124 | 125 | class CheckpointFunction(th.autograd.Function): 126 | @staticmethod 127 | def forward(ctx, run_function, length, *args): 128 | ctx.run_function = run_function 129 | ctx.input_tensors = list(args[:length]) 130 | ctx.input_params = list(args[length:]) 131 | with th.no_grad(): 132 | output_tensors = ctx.run_function(*ctx.input_tensors) 133 | return output_tensors 134 | 135 | @staticmethod 136 | def backward(ctx, *output_grads): 137 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 138 | with th.enable_grad(): 139 | # Fixes a bug where the first op in run_function modifies the 140 | # Tensor storage in place, which is not allowed for detach()'d 141 | # Tensors. 142 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 143 | output_tensors = ctx.run_function(*shallow_copies) 144 | input_grads = th.autograd.grad( 145 | output_tensors, 146 | ctx.input_tensors + ctx.input_params, 147 | output_grads, 148 | allow_unused=True, 149 | ) 150 | del ctx.input_tensors 151 | del ctx.input_params 152 | del output_tensors 153 | return (None, None) + input_grads 154 | -------------------------------------------------------------------------------- /runner/src/models/components/optimal_transport.py: -------------------------------------------------------------------------------- 1 | import math 2 | from functools import partial 3 | from typing import Optional 4 | 5 | import numpy as np 6 | import ot as pot 7 | import torch 8 | 9 | 10 | class OTPlanSampler: 11 | """OTPlanSampler implements sampling coordinates according to an squared L2 OT plan with 12 | different implementations of the plan calculation.""" 13 | 14 | def __init__( 15 | self, 16 | method: str, 17 | reg: float = 0.05, 18 | reg_m: float = 1.0, 19 | normalize_cost=False, 20 | **kwargs, 21 | ): 22 | # ot_fn should take (a, b, M) as arguments where a, b are marginals and 23 | # M is a cost matrix 24 | if method == "exact": 25 | self.ot_fn = pot.emd 26 | elif method == "sinkhorn": 27 | self.ot_fn = partial(pot.sinkhorn, reg=reg) 28 | elif method == "unbalanced": 29 | self.ot_fn = partial(pot.unbalanced.sinkhorn_knopp_unbalanced, reg=reg, reg_m=reg_m) 30 | elif method == "partial": 31 | self.ot_fn = partial(pot.partial.entropic_partial_wasserstein, reg=reg) 32 | else: 33 | raise ValueError(f"Unknown method: {method}") 34 | self.reg = reg 35 | self.reg_m = reg_m 36 | self.normalize_cost = normalize_cost 37 | self.kwargs = kwargs 38 | 39 | def get_map(self, x0, x1): 40 | a, b = pot.unif(x0.shape[0]), pot.unif(x1.shape[0]) 41 | if x0.dim() > 2: 42 | x0 = x0.reshape(x0.shape[0], -1) 43 | if x1.dim() > 2: 44 | x1 = x1.reshape(x1.shape[0], -1) 45 | x1 = x1.reshape(x1.shape[0], -1) 46 | M = torch.cdist(x0, x1) ** 2 47 | if self.normalize_cost: 48 | M = M / M.max() 49 | p = self.ot_fn(a, b, M.detach().cpu().numpy()) 50 | if not np.all(np.isfinite(p)): 51 | print("ERROR: p is not finite") 52 | print(p) 53 | print("Cost mean, max", M.mean(), M.max()) 54 | print(x0, x1) 55 | return p 56 | 57 | def sample_map(self, pi, batch_size): 58 | p = pi.flatten() 59 | p = p / p.sum() 60 | choices = np.random.choice(pi.shape[0] * pi.shape[1], p=p, size=batch_size) 61 | return np.divmod(choices, pi.shape[1]) 62 | 63 | def sample_plan(self, x0, x1): 64 | pi = self.get_map(x0, x1) 65 | i, j = 
self.sample_map(pi, x0.shape[0]) 66 | return x0[i], x1[j] 67 | 68 | def sample_trajectory(self, X): 69 | # Assume X is [batch, times, dim] 70 | times = X.shape[1] 71 | pis = [] 72 | for t in range(times - 1): 73 | pis.append(self.get_map(X[:, t], X[:, t + 1])) 74 | 75 | indices = [np.arange(X.shape[0])] 76 | for pi in pis: 77 | j = [] 78 | for i in indices[-1]: 79 | j.append(np.random.choice(pi.shape[1], p=pi[i] / pi[i].sum())) 80 | indices.append(np.array(j)) 81 | 82 | to_return = [] 83 | for t in range(times): 84 | to_return.append(X[:, t][indices[t]]) 85 | to_return = np.stack(to_return, axis=1) 86 | return to_return 87 | 88 | 89 | def wasserstein( 90 | x0: torch.Tensor, 91 | x1: torch.Tensor, 92 | method: Optional[str] = None, 93 | reg: float = 0.05, 94 | power: int = 2, 95 | **kwargs, 96 | ) -> float: 97 | assert power == 1 or power == 2 98 | # ot_fn should take (a, b, M) as arguments where a, b are marginals and 99 | # M is a cost matrix 100 | if method == "exact" or method is None: 101 | ot_fn = pot.emd2 102 | elif method == "sinkhorn": 103 | ot_fn = partial(pot.sinkhorn2, reg=reg) 104 | else: 105 | raise ValueError(f"Unknown method: {method}") 106 | 107 | a, b = pot.unif(x0.shape[0]), pot.unif(x1.shape[0]) 108 | if x0.dim() > 2: 109 | x0 = x0.reshape(x0.shape[0], -1) 110 | if x1.dim() > 2: 111 | x1 = x1.reshape(x1.shape[0], -1) 112 | M = torch.cdist(x0, x1) 113 | if power == 2: 114 | M = M**2 115 | ret = ot_fn(a, b, M.detach().cpu().numpy(), numItermax=1e7) 116 | if power == 2: 117 | ret = math.sqrt(ret) 118 | return ret 119 | -------------------------------------------------------------------------------- /runner/src/models/components/regularizers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class Regularizer(nn.Module): 6 | def __init__(self): 7 | pass 8 | 9 | 10 | def _batch_root_mean_squared(tensor): 11 | tensor = tensor.view(tensor.shape[0], -1) 12 | return torch.norm(tensor, p=2, dim=1) / tensor.shape[1] ** 0.5 13 | 14 | 15 | class RegularizationFunc(nn.Module): 16 | def forward(self, t, x, dx, context) -> torch.Tensor: 17 | """Outputs a batch of scaler regularizations.""" 18 | raise NotImplementedError 19 | 20 | 21 | class L1Reg(RegularizationFunc): 22 | def forward(self, t, x, dx, context) -> torch.Tensor: 23 | return torch.mean(torch.abs(dx), dim=1) 24 | 25 | 26 | class L2Reg(RegularizationFunc): 27 | def forward(self, t, x, dx, context) -> torch.Tensor: 28 | return _batch_root_mean_squared(dx) 29 | 30 | 31 | class SquaredL2Reg(RegularizationFunc): 32 | def forward(self, t, x, dx, context) -> torch.Tensor: 33 | to_return = dx.view(dx.shape[0], -1) 34 | return torch.pow(torch.norm(to_return, p=2, dim=1), 2) 35 | 36 | 37 | def _get_minibatch_jacobian(y, x, create_graph=True): 38 | """Computes the Jacobian of y wrt x assuming minibatch-mode. 39 | 40 | Args: 41 | y: (N, ...) with a total of D_y elements in ... 42 | x: (N, ...) with a total of D_x elements in ... 43 | Returns: 44 | The minibatch Jacobian matrix of shape (N, D_y, D_x) 45 | """ 46 | # assert y.shape[0] == x.shape[0] 47 | y = y.view(y.shape[0], -1) 48 | 49 | # Compute Jacobian row by row. 
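    # Each call below differentiates one output coordinate y[:, j] with respect to x, using a
    # vector of ones as the grad-output so every sample in the batch is handled at once; building
    # the full Jacobian therefore costs one backward pass per output dimension.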
50 | jac = [] 51 | for j in range(y.shape[1]): 52 | dy_j_dx = torch.autograd.grad( 53 | y[:, j], 54 | x, 55 | torch.ones_like(y[:, j]), 56 | retain_graph=True, 57 | create_graph=create_graph, 58 | )[0] 59 | jac.append(torch.unsqueeze(dy_j_dx, -1)) 60 | jac = torch.cat(jac, -1) 61 | return jac 62 | 63 | 64 | class JacobianFrobeniusReg(RegularizationFunc): 65 | def forward(self, t, x, dx, context) -> torch.Tensor: 66 | if hasattr(context, "jac"): 67 | jac = context.jac 68 | else: 69 | jac = _get_minibatch_jacobian(dx, x) 70 | context.jac = jac 71 | jac = _get_minibatch_jacobian(dx, x) 72 | context.jac = jac 73 | return _batch_root_mean_squared(jac) 74 | 75 | 76 | class JacobianDiagFrobeniusReg(RegularizationFunc): 77 | def forward(self, t, x, dx, context) -> torch.Tensor: 78 | if hasattr(context, "jac"): 79 | jac = context.jac 80 | else: 81 | jac = _get_minibatch_jacobian(dx, x) 82 | context.jac = jac 83 | diagonal = jac.view(jac.shape[0], -1)[ 84 | :, :: jac.shape[1] 85 | ] # assumes jac is minibatch square, ie. (N, M, M). 86 | return _batch_root_mean_squared(diagonal) 87 | 88 | 89 | class JacobianOffDiagFrobeniusReg(RegularizationFunc): 90 | def forward(self, t, x, dx, context) -> torch.Tensor: 91 | if hasattr(context, "jac"): 92 | jac = context.jac 93 | else: 94 | jac = _get_minibatch_jacobian(dx, x) 95 | context.jac = jac 96 | diagonal = jac.view(jac.shape[0], -1)[ 97 | :, :: jac.shape[1] 98 | ] # assumes jac is minibatch square, ie. (N, M, M). 99 | ss_offdiag = torch.sum(jac.view(jac.shape[0], -1) ** 2, dim=1) - torch.sum( 100 | diagonal**2, dim=1 101 | ) 102 | ms_offdiag = ss_offdiag / (diagonal.shape[1] * (diagonal.shape[1] - 1)) 103 | return ms_offdiag 104 | 105 | 106 | def autograd_trace(x_out, x_in, **kwargs): 107 | """Standard brute-force means of obtaining trace of the Jacobian, O(d) calls to autograd.""" 108 | trJ = 0.0 109 | for i in range(x_in.shape[1]): 110 | trJ += torch.autograd.grad(x_out[:, i].sum(), x_in, allow_unused=False, create_graph=True)[ 111 | 0 112 | ][:, i] 113 | return trJ 114 | 115 | 116 | class CNFReg(RegularizationFunc): 117 | def __init__(self, trace_estimator=None, noise_dist=None): 118 | super().__init__() 119 | self.trace_estimator = trace_estimator if trace_estimator is not None else autograd_trace 120 | self.noise_dist, self.noise = noise_dist, None 121 | 122 | def forward(self, t, x, dx, context): 123 | # TODO we could check if jac is in the context to speed up 124 | return -self.trace_estimator(dx, x, noise=self.noise) 125 | 126 | 127 | class AugmentationModule(nn.Module): 128 | """Class orchestrating augmentations. 129 | 130 | Also establishes order. 
131 | """ 132 | 133 | def __init__( 134 | self, 135 | cnf_estimator: str = None, 136 | l1_reg: float = 0.0, 137 | l2_reg: float = 0.0, 138 | squared_l2_reg: float = 0.0, 139 | jacobian_frobenius_reg: float = 0.0, 140 | jacobian_diag_frobenius_reg: float = 0.0, 141 | jacobian_off_diag_frobenius_reg: float = 0.0, 142 | ) -> None: 143 | super().__init__() 144 | coeffs = [] 145 | regs = [] 146 | if cnf_estimator == "exact": 147 | coeffs.append(1) 148 | regs.append(CNFReg(None, noise_dist=None)) 149 | if l1_reg > 0.0: 150 | coeffs.append(l1_reg) 151 | regs.append(L1Reg()) 152 | if l2_reg > 0.0: 153 | coeffs.append(l2_reg) 154 | regs.append(L2Reg()) 155 | if squared_l2_reg > 0.0: 156 | coeffs.append(squared_l2_reg) 157 | regs.append(SquaredL2Reg()) 158 | if jacobian_frobenius_reg > 0.0: 159 | coeffs.append(jacobian_frobenius_reg) 160 | regs.append(JacobianFrobeniusReg()) 161 | if jacobian_diag_frobenius_reg > 0.0: 162 | coeffs.append(jacobian_diag_frobenius_reg) 163 | regs.append(JacobianDiagFrobeniusReg()) 164 | if jacobian_off_diag_frobenius_reg > 0.0: 165 | coeffs.append(jacobian_off_diag_frobenius_reg) 166 | regs.append(JacobianOffDiagFrobeniusReg()) 167 | 168 | self.coeffs = torch.tensor(coeffs) 169 | self.regs = nn.ModuleList(regs) 170 | 171 | 172 | if __name__ == "__main__": 173 | # Test Shapes 174 | class SharedContext: 175 | pass 176 | 177 | for reg in [ 178 | L1Reg, 179 | L2Reg, 180 | SquaredL2Reg, 181 | JacobianFrobeniusReg, 182 | JacobianDiagFrobeniusReg, 183 | JacobianOffDiagFrobeniusReg, 184 | ]: 185 | x = torch.ones(2, 3).requires_grad_(True) 186 | dx = x * 2 187 | out = reg().forward(torch.ones(1), x, dx, SharedContext) 188 | assert out.dim() == 1 189 | assert out.shape[0] == 2 190 | -------------------------------------------------------------------------------- /runner/src/models/components/schedule.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class NoiseScheduler: 6 | """Base Class for noise schedule. 7 | 8 | The noise schedule is a function that maps time to reference process noise level. We can use 9 | this to determine the Brownian bridge noise schedule. 10 | 11 | We define the noise schedule with __call__ and the Brownian bridge noise schedule with sigma_t. 12 | We define F as the integral of the squared reference process noise schedule which is a useful 13 | intermediate quantity. 14 | """ 15 | 16 | def __call__(self, t): 17 | """Calculate the reference process noise schedule. 18 | 19 | g(t) in the paper.
20 | """ 21 | raise NotImplementedError 22 | 23 | def F(self, t): 24 | """Calculate the integral of the squared reference process noise schedule.""" 25 | raise NotImplementedError 26 | 27 | def sigma_t(self, t): 28 | """Given the reference process noise schedule, calculate the brownian bridge noise 29 | schedule.""" 30 | return torch.sqrt(self.F(t) - self.F(t) ** 2 / self.F(1)) 31 | 32 | 33 | class ConstantNoiseScheduler(NoiseScheduler): 34 | def __init__(self, sigma: float): 35 | self.sigma = sigma 36 | 37 | def __call__(self, t): 38 | return self.sigma 39 | 40 | def F(self, t): 41 | return self.sigma**2 * t 42 | 43 | 44 | class LinearDecreasingNoiseScheduler(NoiseScheduler): 45 | def __init__(self, sigma_min: float, sigma_max: float): 46 | self.sigma_min = sigma_min 47 | self.sigma_max = sigma_max 48 | 49 | def __call__(self, t): 50 | return torch.sqrt(t * self.sigma_min + (1 - t) * self.sigma_max) 51 | 52 | def F(self, t): 53 | return (t**2) * self.sigma_min / 2 - (t**2) * self.sigma_max / 2 + self.sigma_max * t 54 | 55 | 56 | class CosineNoiseScheduler(NoiseScheduler): 57 | def __init__(self, sigma_min: float, scale: float): 58 | self.sigma_min = sigma_min 59 | self.scale = scale 60 | 61 | def __call__(self, t): 62 | return self.scale * (1 - (t * np.pi * 2).cos()) + self.sigma_min 63 | 64 | def F(self, t): 65 | antider = t - (t * 2 * np.pi).sin() / (2 * np.pi) 66 | antider2 = t - 2 * (t * 2 * np.pi).sin() / (2 * np.pi) 67 | antider2 += t / 2 + (t * 4 * np.pi).sin() / (8 * np.pi) 68 | return ( 69 | self.scale**2 * antider2 70 | + t * self.sigma_min**2 71 | + self.scale * 2 * self.sigma_min * antider 72 | ) 73 | -------------------------------------------------------------------------------- /runner/src/models/components/simple_dense_net.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class SimpleDenseNet(nn.Module): 5 | def __init__( 6 | self, 7 | input_size: int = 784, 8 | lin1_size: int = 256, 9 | lin2_size: int = 256, 10 | lin3_size: int = 256, 11 | output_size: int = 10, 12 | ): 13 | super().__init__() 14 | 15 | self.model = nn.Sequential( 16 | nn.Linear(input_size, lin1_size), 17 | nn.BatchNorm1d(lin1_size), 18 | nn.ReLU(), 19 | nn.Linear(lin1_size, lin2_size), 20 | nn.BatchNorm1d(lin2_size), 21 | nn.ReLU(), 22 | nn.Linear(lin2_size, lin3_size), 23 | nn.BatchNorm1d(lin3_size), 24 | nn.ReLU(), 25 | nn.Linear(lin3_size, output_size), 26 | ) 27 | 28 | def forward(self, x): 29 | batch_size, channels, width, height = x.size() 30 | 31 | # (batch, 1, width, height) -> (batch, 1*width*height) 32 | x = x.view(batch_size, -1) 33 | 34 | return self.model(x) 35 | 36 | 37 | if __name__ == "__main__": 38 | _ = SimpleDenseNet() 39 | -------------------------------------------------------------------------------- /runner/src/models/components/simple_mlp.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | from torch import nn 5 | 6 | ACTIVATION_MAP = { 7 | "relu": nn.ReLU, 8 | "sigmoid": nn.Sigmoid, 9 | "tanh": nn.Tanh, 10 | "selu": nn.SELU, 11 | "elu": nn.ELU, 12 | "lrelu": nn.LeakyReLU, 13 | "softplus": nn.Softplus, 14 | "silu": nn.SiLU, 15 | } 16 | 17 | 18 | class SimpleDenseNet(nn.Module): 19 | def __init__( 20 | self, 21 | input_size: int, 22 | target_size: int, 23 | activation: str, 24 | batch_norm: bool = True, 25 | hidden_dims: Optional[List[int]] = None, 26 | ): 27 | super().__init__() 28 | if hidden_dims is None: 29 | 
hidden_dims = [256, 256, 256] 30 | dims = [input_size, *hidden_dims, target_size] 31 | layers = [] 32 | for i in range(len(dims) - 2): 33 | layers.append(nn.Linear(dims[i], dims[i + 1])) 34 | if batch_norm: 35 | layers.append(nn.BatchNorm1d(dims[i + 1])) 36 | layers.append(ACTIVATION_MAP[activation]()) 37 | layers.append(nn.Linear(dims[-2], dims[-1])) 38 | self.model = nn.Sequential(*layers) 39 | 40 | def forward(self, x): 41 | return self.model(x) 42 | 43 | 44 | class DivergenceFreeNet(SimpleDenseNet): 45 | """Implements a divergence free network as the gradient of a scalar potential function.""" 46 | 47 | def __init__(self, dim: int, *args, **kwargs): 48 | super().__init__(input_size=dim + 1, target_size=1, *args, **kwargs) 49 | 50 | def energy(self, x): 51 | return self.model(x) 52 | 53 | def forward(self, t, x, *args, **kwargs): 54 | """Ignore t run model.""" 55 | if t.dim() < 2: 56 | t = t.repeat(x.shape[0])[:, None] 57 | x = torch.cat([t, x], dim=-1) 58 | x = x.requires_grad_(True) 59 | grad = torch.autograd.grad(torch.sum(self.model(x)), x, create_graph=True)[0] 60 | return grad[:, :-1] 61 | 62 | 63 | class TimeInvariantVelocityNet(SimpleDenseNet): 64 | def __init__(self, dim: int, *args, **kwargs): 65 | super().__init__(input_size=dim, target_size=dim, *args, **kwargs) 66 | 67 | def forward(self, t, x, *args, **kwargs): 68 | """Ignore t run model.""" 69 | del t 70 | return self.model(x) 71 | 72 | 73 | class VelocityNet(SimpleDenseNet): 74 | def __init__(self, dim: int, *args, **kwargs): 75 | super().__init__(input_size=dim + 1, target_size=dim, *args, **kwargs) 76 | 77 | def forward(self, t, x, *args, **kwargs): 78 | """Ignore t run model.""" 79 | if t.dim() < 1 or t.shape[0] != x.shape[0]: 80 | t = t.repeat(x.shape[0])[:, None] 81 | if t.dim() < 2: 82 | t = t[:, None] 83 | x = torch.cat([t, x], dim=-1) 84 | return self.model(x) 85 | 86 | 87 | if __name__ == "__main__": 88 | _ = SimpleDenseNet() 89 | _ = TimeInvariantVelocityNet() 90 | -------------------------------------------------------------------------------- /runner/src/models/components/sinkhorn_knopp_unbalanced.py: -------------------------------------------------------------------------------- 1 | """Implements unbalanced sinkhorn knopp optimization for unbalanced ot. 2 | 3 | This is from the package python optimal transport but modified to take three regularization 4 | parameters instead of two. This is necessary to find growth rates of the source distribution that 5 | best match the target distribution or vis versa. by setting reg_m_1 to something low and reg_m_2 to 6 | something large we can compute an unbalanced optimal transport where all the scaling is done on the 7 | source distribution and none is done on the target distribution. 8 | """ 9 | import warnings 10 | 11 | import numpy as np 12 | 13 | 14 | def sinkhorn_knopp_unbalanced( 15 | a, 16 | b, 17 | M, 18 | reg, 19 | reg_m_1, 20 | reg_m_2, 21 | numItermax=1000, 22 | stopThr=1e-6, 23 | verbose=False, 24 | log=False, 25 | **kwargs, 26 | ): 27 | """Solve the entropic regularization unbalanced optimal transport problem. 28 | 29 | The function solves the following optimization problem: 30 | 31 | .. math:: 32 | W = \\min_\\gamma <\\gamma,M>_F + reg\\cdot\\Omega(\\gamma) + \ 33 | \reg_m_1 KL(\\gamma 1, a) + \reg_m_2 KL(\\gamma^T 1, b) 34 | 35 | s.t. 
36 | \\gamma\\geq 0 37 | where : 38 | 39 | - M is the (dim_a, dim_b) metric cost matrix 40 | - :math:`\\Omega` is the entropic regularization term 41 | :math:`\\Omega(\\gamma)=\\sum_{i,j} \\gamma_{i,j}\\log(\\gamma_{i,j})` 42 | - a and b are source and target unbalanced distributions 43 | - KL is the Kullback-Leibler divergence 44 | 45 | The algorithm used for solving the problem is the generalized 46 | Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_ 47 | 48 | 49 | Parameters 50 | ---------- 51 | a : np.ndarray (dim_a,) 52 | Unnormalized histogram of dimension dim_a 53 | b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists) 54 | One or multiple unnormalized histograms of dimension dim_b 55 | If many, compute all the OT distances (a, b_i) 56 | M : np.ndarray (dim_a, dim_b) 57 | loss matrix 58 | reg : float 59 | Entropy regularization term > 0 60 | reg_m: float 61 | Marginal relaxation term > 0 62 | numItermax : int, optional 63 | Max number of iterations 64 | stopThr : float, optional 65 | Stop threshold on error (> 0) 66 | verbose : bool, optional 67 | Print information along iterations 68 | log : bool, optional 69 | record log if True 70 | 71 | 72 | Returns 73 | ------- 74 | if n_hists == 1: 75 | gamma : (dim_a x dim_b) ndarray 76 | Optimal transportation matrix for the given parameters 77 | log : dict 78 | log dictionary returned only if `log` is `True` 79 | else: 80 | ot_distance : (n_hists,) ndarray 81 | the OT distance between `a` and each of the histograms `b_i` 82 | log : dict 83 | log dictionary returned only if `log` is `True` 84 | Examples 85 | -------- 86 | 87 | >>> import ot 88 | >>> a=[.5, .5] 89 | >>> b=[.5, .5] 90 | >>> M=[[0., 1.],[1., 0.]] 91 | >>> ot.unbalanced.sinkhorn_knopp_unbalanced(a, b, M, 1., 1.) 92 | array([[0.51122814, 0.18807032], 93 | [0.18807032, 0.51122814]]) 94 | 95 | References 96 | ---------- 97 | 98 | .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). 99 | Scaling algorithms for unbalanced transport problems. arXiv preprint 100 | arXiv:1607.05816. 101 | 102 | .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. 
: 103 | Learning with a Wasserstein Loss, Advances in Neural Information 104 | Processing Systems (NIPS) 2015 105 | 106 | See Also 107 | -------- 108 | ot.lp.emd : Unregularized OT 109 | ot.optim.cg : General regularized OT 110 | """ 111 | 112 | a = np.asarray(a, dtype=np.float64) 113 | b = np.asarray(b, dtype=np.float64) 114 | M = np.asarray(M, dtype=np.float64) 115 | 116 | dim_a, dim_b = M.shape 117 | 118 | if len(a) == 0: 119 | a = np.ones(dim_a, dtype=np.float64) / dim_a 120 | if len(b) == 0: 121 | b = np.ones(dim_b, dtype=np.float64) / dim_b 122 | 123 | if len(b.shape) > 1: 124 | n_hists = b.shape[1] 125 | else: 126 | n_hists = 0 127 | 128 | if log: 129 | log = {"err": []} 130 | 131 | # we assume that no distances are null except those of the diagonal of 132 | # distances 133 | if n_hists: 134 | u = np.ones((dim_a, 1)) / dim_a 135 | v = np.ones((dim_b, n_hists)) / dim_b 136 | a = a.reshape(dim_a, 1) 137 | else: 138 | u = np.ones(dim_a) / dim_a 139 | v = np.ones(dim_b) / dim_b 140 | 141 | # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute 142 | K = np.empty(M.shape, dtype=M.dtype) 143 | np.divide(M, -reg, out=K) 144 | np.exp(K, out=K) 145 | 146 | cpt = 0 147 | err = 1.0 148 | 149 | while err > stopThr and cpt < numItermax: 150 | uprev = u 151 | vprev = v 152 | 153 | Kv = K.dot(v) 154 | u = (a / Kv) ** (reg_m_1 / (reg_m_1 + reg)) 155 | Ktu = K.T.dot(u) 156 | v = (b / Ktu) ** (reg_m_2 / (reg_m_2 + reg)) 157 | 158 | if ( 159 | np.any(Ktu == 0.0) 160 | or np.any(np.isnan(u)) 161 | or np.any(np.isnan(v)) 162 | or np.any(np.isinf(u)) 163 | or np.any(np.isinf(v)) 164 | ): 165 | # we have reached the machine precision 166 | # come back to previous solution and quit loop 167 | warnings.warn("Numerical errors at iteration %s" % cpt) 168 | u = uprev 169 | v = vprev 170 | break 171 | if cpt % 10 == 0: 172 | # we can speed up the process by checking for the error only all 173 | # the 10th iterations 174 | err_u = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(), 1.0) 175 | err_v = abs(v - vprev).max() / max(abs(v).max(), abs(vprev).max(), 1.0) 176 | err = 0.5 * (err_u + err_v) 177 | if log: 178 | log["err"].append(err) 179 | if verbose: 180 | if cpt % 200 == 0: 181 | print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19) 182 | print(f"{cpt:5d}|{err:8e}|") 183 | cpt += 1 184 | 185 | if log: 186 | log["logu"] = np.log(u + 1e-16) 187 | log["logv"] = np.log(v + 1e-16) 188 | 189 | if n_hists: # return only loss 190 | res = np.einsum("ik,ij,jk,ij->k", u, K, v, M) 191 | if log: 192 | return res, log 193 | else: 194 | return res 195 | 196 | else: # return OT matrix 197 | if log: 198 | return u[:, None] * K * v[None, :], log 199 | else: 200 | return u[:, None] * K * v[None, :] 201 | -------------------------------------------------------------------------------- /runner/src/models/components/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import torch 6 | 7 | 8 | def plot_trajectories(data, pred, graph, dataset, title=[1, 2.1]): 9 | fig, axs = plt.subplots(1, 3, figsize=(10, 2.3)) 10 | fig.tight_layout(pad=0.2, w_pad=2, h_pad=3) 11 | assert data.shape[-1] == pred.shape[-1] 12 | for i in range(data.shape[-1]): 13 | axs[0].plot(data[0, :, i].squeeze()) 14 | axs[1].plot(pred[0, :, i].squeeze()) 15 | title = f"{dataset}: Epoch = {title[0]}, Loss = {title[1]:1.3f}" 16 | axs[1].set_title(title) 17 | cax = axs[2].matshow(graph) 18 | fig.colorbar(cax) 19 | if 
not os.path.exists("figs"): 20 | os.mkdir("figs") 21 | plt.savefig(f"figs/{title}.png") 22 | plt.close() 23 | 24 | 25 | def plot_graph_dist(graph_mu, graph_thresh, graph_std, ground_truth, path): 26 | fig, axs = plt.subplots(1, 4, figsize=(13, 4.5)) 27 | # fig.tight_layout(pad=0.2, w_pad=2, h_pad=3) 28 | axs[0].set_title("Ground Truth") 29 | axs[1].set_title("Graph means") 30 | axs[2].set_title("Graph post-threshold") 31 | axs[3].set_title("Graph std") 32 | 33 | print(graph_mu.shape, ground_truth.shape) 34 | 35 | g = [ground_truth, graph_mu, graph_thresh, graph_std] 36 | for col in range(4): 37 | ax = axs[col] 38 | pcm = ax.matshow(g[col], cmap="viridis") 39 | fig.colorbar(pcm, ax=ax) 40 | 41 | if not os.path.exists(path + "/figs"): 42 | os.mkdir(path + "/figs") 43 | plt.savefig(f"{path}/figs/graph_dist_plot.png") 44 | plt.close() 45 | 46 | 47 | def plot_traj_dist(data, pred, dataset, title=[1, 2.1]): 48 | fig, axs = plt.subplots(1, 2, figsize=(10, 2.3)) 49 | fig.tight_layout(pad=0.2, w_pad=2, h_pad=3) 50 | assert data.shape[-1] == pred.shape[-1] 51 | for i in range(data.shape[-1]): 52 | axs[0].plot(data[0, :, i].squeeze()) 53 | axs[1].plot(pred[0, :, i].squeeze()) 54 | title = f"{dataset}: Epoch = {title[0]}, Loss = {title[1]:1.3f}" 55 | axs[1].set_title(title) 56 | if not os.path.exists("figs"): 57 | os.mkdir("figs") 58 | plt.savefig(f"figs/{title}.png") 59 | plt.close() 60 | 61 | 62 | def plot_cnf(data, traj, graph, dataset, title): 63 | n = 1000 64 | fig, axes = plt.subplots(1, 2, figsize=(10, 5)) 65 | ax = axes[0] 66 | data = data.reshape([-1, *data.shape[2:]]) 67 | ax.scatter(data[:, 0], data[:, 1], alpha=0.5) 68 | ax.scatter(traj[:n, -1, 0], traj[:n, -1, 1], s=10, alpha=0.8, c="black") 69 | # ax.scatter(traj[0, :n, 0], traj[0, :n, 1], s=10, alpha=0.8, c="black") 70 | # ax.scatter(traj[:, :n, 0], traj[:, :n, 1], s=0.2, alpha=0.2, c="olive") 71 | ax.scatter(traj[:n, :, 0], traj[:n, :, 1], s=0.2, alpha=0.2, c="olive") 72 | # ax.scatter(traj[-1, :n, 0], traj[-1, :n, 1], s=4, alpha=1, c="blue") 73 | ax.scatter(traj[:n, 0, 0], traj[:n, 0, 1], s=4, alpha=1, c="blue") 74 | ax.legend(["data", "Last Timepoint", "Flow", "Posterior"]) 75 | 76 | ax = axes[1] 77 | cax = ax.matshow(graph) 78 | fig.colorbar(cax) 79 | title = f"{dataset}: Epoch = {title[0]}, Loss = {title[1]:1.3f}" 80 | ax.set_title(title) 81 | if not os.path.exists("figs"): 82 | os.mkdir("figs") 83 | plt.savefig(f"figs/{title}.png") 84 | plt.close() 85 | 86 | 87 | def plot_pca_traj(data, traj, graph, adata, dataset, title): 88 | """ 89 | Args: 90 | data: np.array [N, T, D] 91 | traj: np.array [N, T, D] 92 | graph: np.array [D, D] 93 | """ 94 | n = 1000 95 | fig, axes = plt.subplots(1, 2, figsize=(10, 5)) 96 | ax = axes[0] 97 | # data = data.reshape([-1, *data.shape[2:]]) 98 | 99 | def pca_transform(x, d=2): 100 | return (x - adata.var["means"].values) @ adata.varm["PCs"][:, :d] 101 | 102 | traj = pca_transform(traj) 103 | 104 | for t in range(data.shape[1]): 105 | pcd = pca_transform(data[:, t]) 106 | ax.scatter(pcd[:, 0], pcd[:, 1], alpha=0.5) 107 | ax.scatter(traj[:n, -1, 0], traj[:n, -1, 1], s=10, alpha=0.8, c="black") 108 | ax.scatter(traj[:n, :, 0], traj[:n, :, 1], s=0.2, alpha=0.2, c="olive") 109 | ax.scatter(traj[:n, 0, 0], traj[:n, 0, 1], s=4, alpha=1, c="blue") 110 | ax.legend( 111 | [ 112 | *[f"T={i}" for i in range(data.shape[1])], 113 | "Last Timepoint", 114 | "Flow", 115 | "Posterior", 116 | ] 117 | ) 118 | 119 | ax = axes[1] 120 | cax = ax.matshow(graph) 121 | fig.colorbar(cax) 122 | title = f"{dataset}: Epoch = 
{title[0]}, Loss = {title[1]:1.3f}" 123 | ax.set_title(title) 124 | if not os.path.exists("figs_pca"): 125 | os.mkdir("figs_pca") 126 | plt.savefig(f"figs_pca/{title}.png") 127 | np.save(f"figs_pca/{title}.npy", graph) 128 | plt.close() 129 | 130 | 131 | def to_torch(arr): 132 | if isinstance(arr, list): 133 | return torch.tensor(np.array(arr)).float() 134 | elif isinstance(arr, (np.ndarray, np.generic)): 135 | return torch.tensor(arr).float() 136 | else: 137 | raise NotImplementedError(f"to_torch not implemented for type: {type(arr)}") 138 | -------------------------------------------------------------------------------- /runner/src/models/runner.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List, Optional, Union 2 | 3 | import numpy as np 4 | import torch 5 | from pytorch_lightning import LightningDataModule, LightningModule 6 | 7 | from torchcfm import ConditionalFlowMatcher 8 | 9 | from .components.augmentation import AugmentationModule 10 | from .components.distribution_distances import compute_distribution_distances 11 | from .components.plotting import plot_trajectory, store_trajectories 12 | from .components.solver import FlowSolver 13 | from .utils import get_wandb_logger 14 | 15 | 16 | class CFMLitModule(LightningModule): 17 | def __init__( 18 | self, 19 | net: Any, 20 | optimizer: Any, 21 | datamodule: LightningDataModule, 22 | flow_matcher: ConditionalFlowMatcher, 23 | solver: FlowSolver, 24 | scheduler: Optional[Any] = None, 25 | plot: bool = False, 26 | ) -> None: 27 | super().__init__() 28 | self.save_hyperparameters( 29 | ignore=[ 30 | "net", 31 | "optimizer", 32 | "scheduler", 33 | "datamodule", 34 | "augmentations", 35 | "flow_matcher", 36 | "solver", 37 | ], 38 | logger=False, 39 | ) 40 | self.datamodule = datamodule 41 | self.is_trajectory = False 42 | if hasattr(datamodule, "IS_TRAJECTORY"): 43 | self.is_trajectory = datamodule.IS_TRAJECTORY 44 | # dims is either an integer or a tuple. This helps us to decide whether to process things as 45 | # a vector or as an image. 
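        # e.g. dim = 2 for the two-dimensional toy datamodules, while the image datamodules
        # expose a dims tuple such as (3, 32, 32) for CIFAR-10.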
46 | if hasattr(datamodule, "dim"): 47 | self.dim = datamodule.dim 48 | self.is_image = False 49 | elif hasattr(datamodule, "dims"): 50 | self.dim = datamodule.dims 51 | self.is_image = True 52 | else: 53 | raise NotImplementedError("Datamodule must have either dim or dims") 54 | self.net = net(dim=self.dim) 55 | self.solver = solver 56 | self.optimizer = optimizer 57 | self.flow_matcher = flow_matcher 58 | self.scheduler = scheduler 59 | self.criterion = torch.nn.MSELoss() 60 | self.val_augmentations = AugmentationModule( 61 | # cnf_estimator=None, 62 | l1_reg=1, 63 | l2_reg=1, 64 | squared_l2_reg=1, 65 | ) 66 | 67 | def unpack_batch(self, batch): 68 | """Unpacks a batch of data to a single tensor.""" 69 | if not isinstance(self.dim, int): 70 | # Assume this is an image classification dataset where we need to strip the targets 71 | return batch[0] 72 | return batch 73 | 74 | def preprocess_batch(self, batch, training=False): 75 | """Converts a batch of data into matched a random pair of (x0, x1)""" 76 | X = self.unpack_batch(batch) 77 | # If no trajectory assume generate from standard normal 78 | x0 = torch.randn_like(X) 79 | x1 = X 80 | return x0, x1 81 | 82 | def step(self, batch: Any, training: bool = False): 83 | """Computes the loss on a batch of data.""" 84 | x0, x1 = self.preprocess_batch(batch, training) 85 | t, xt, ut = self.flow_matcher.sample_location_and_conditional_flow(x0, x1) 86 | vt = self.net(t, xt) 87 | return torch.nn.functional.mse_loss(vt, ut) 88 | 89 | def training_step(self, batch: Any, batch_idx: int): 90 | loss = self.step(batch, training=True) 91 | self.log("train/loss", loss, on_step=True, prog_bar=True) 92 | return loss 93 | 94 | def eval_step(self, batch: Any, batch_idx: int, prefix: str): 95 | loss = self.step(batch, training=True) 96 | self.log(f"{prefix}/loss", loss) 97 | return {"loss": loss, "x": batch} 98 | 99 | def preprocess_epoch_end(self, outputs: List[Any], prefix: str): 100 | """Preprocess the outputs of the epoch end function.""" 101 | v = {k: torch.cat([d[k] for d in outputs]) for k in ["x"]} 102 | x = v["x"] 103 | 104 | # Sample some random points for the plotting function 105 | rand = torch.randn_like(x) 106 | x = torch.stack([rand, x], dim=1) 107 | ts = x.shape[1] 108 | x0 = x[:, 0] 109 | x_rest = x[:, 1:] 110 | return ts, x, x0, x_rest 111 | 112 | def forward_eval_integrate(self, ts, x0, x_rest, outputs, prefix): 113 | # Build a trajectory 114 | t_span = torch.linspace(0, 1, 101) 115 | aug_dims = self.val_augmentations.aug_dims 116 | solver = self.solver(self.net, self.dim) 117 | solver.augmentations = self.val_augmentations 118 | traj, aug = solver.odeint(x0, t_span) 119 | full_trajs = [traj] 120 | traj, aug = traj[-1], aug[-1] 121 | regs = [torch.mean(aug, dim=0).detach().cpu().numpy()] 122 | trajs = [traj] 123 | nfe = solver.nfe 124 | full_trajs = torch.cat(full_trajs) 125 | 126 | regs = np.stack(regs).mean(axis=0) 127 | names = [f"{prefix}/{name}" for name in self.val_augmentations.names] 128 | self.log_dict(dict(zip(names, regs)), sync_dist=True) 129 | 130 | names, dists = compute_distribution_distances(trajs, x_rest) 131 | names = [f"{prefix}/{name}" for name in names] 132 | d = dict(zip(names, dists)) 133 | d[f"{prefix}/nfe"] = nfe 134 | self.log_dict(d, sync_dist=True) 135 | return trajs, full_trajs 136 | 137 | def eval_epoch_end(self, outputs: List[Any], prefix: str): 138 | wandb_logger = get_wandb_logger(self.loggers) 139 | ts, x, x0, x_rest = self.preprocess_epoch_end(outputs, prefix) 140 | trajs, full_trajs = 
self.forward_eval_integrate(ts, x0, x_rest, outputs, prefix) 141 | 142 | if self.hparams.plot: 143 | plot_trajectory( 144 | x, 145 | full_trajs, 146 | title=f"{self.current_epoch}_ode", 147 | key="ode_path", 148 | wandb_logger=wandb_logger, 149 | ) 150 | store_trajectories(x, self.net) 151 | 152 | def validation_step(self, batch: Any, batch_idx: int): 153 | return self.eval_step(batch, batch_idx, "val") 154 | 155 | def validation_epoch_end(self, outputs: List[Any]): 156 | self.eval_epoch_end(outputs, "val") 157 | 158 | def test_step(self, batch: Any, batch_idx: int): 159 | return self.eval_step(batch, batch_idx, "test") 160 | 161 | def test_epoch_end(self, outputs: List[Any]): 162 | self.eval_epoch_end(outputs, "test") 163 | 164 | def configure_optimizers(self): 165 | """Pass model parameters to optimizer.""" 166 | optimizer = self.optimizer(params=self.parameters()) 167 | if self.scheduler is None: 168 | return optimizer 169 | 170 | scheduler = self.scheduler(optimizer) 171 | return [optimizer], [{"scheduler": scheduler, "interval": "epoch"}] 172 | 173 | def lr_scheduler_step(self, scheduler, optimizer_idx, metric): 174 | scheduler.step(epoch=self.current_epoch) 175 | -------------------------------------------------------------------------------- /runner/src/models/utils.py: -------------------------------------------------------------------------------- 1 | from pytorch_lightning.loggers import WandbLogger 2 | 3 | 4 | def get_wandb_logger(loggers): 5 | """Gets the wandb logger if it is the list of loggers otherwise returns None.""" 6 | wandb_logger = None 7 | for logger in loggers: 8 | if isinstance(logger, WandbLogger): 9 | wandb_logger = logger 10 | return wandb_logger 11 | -------------------------------------------------------------------------------- /runner/src/train.py: -------------------------------------------------------------------------------- 1 | import pyrootutils 2 | 3 | root = pyrootutils.setup_root( 4 | search_from=__file__, 5 | indicator=[".git", "pyproject.toml", "README.md"], 6 | pythonpath=True, 7 | dotenv=True, 8 | ) 9 | 10 | # ------------------------------------------------------------------------------------ # 11 | # `pyrootutils.setup_root(...)` above is optional line to make environment more convenient 12 | # should be placed at the top of each entry file 13 | # 14 | # main advantages: 15 | # - allows you to keep all entry files in "src/" without installing project as a package 16 | # - launching python file works no matter where is your current work dir 17 | # - automatically loads environment variables from ".env" if exists 18 | # 19 | # how it works: 20 | # - `setup_root()` above recursively searches for either ".git" or "pyproject.toml" in present 21 | # and parent dirs, to determine the project root dir 22 | # - adds root dir to the PYTHONPATH (if `pythonpath=True`), so this file can be run from 23 | # any place without installing project as a package 24 | # - sets PROJECT_ROOT environment variable which is used in "configs/paths/default.yaml" 25 | # to make all paths always relative to project root 26 | # - loads environment variables from ".env" in root dir (if `dotenv=True`) 27 | # 28 | # you can remove `pyrootutils.setup_root(...)` if you: 29 | # 1. either install project as a package or move each entry file to the project root dir 30 | # 2. 
remove PROJECT_ROOT variable from paths in "configs/paths/default.yaml" 31 | # 32 | # https://github.com/ashleve/pyrootutils 33 | # ------------------------------------------------------------------------------------ # 34 | 35 | from typing import List, Optional, Tuple 36 | 37 | import hydra 38 | import pytorch_lightning as pl 39 | from omegaconf import DictConfig 40 | from pytorch_lightning import Callback, LightningDataModule, LightningModule, Trainer 41 | from pytorch_lightning.loggers import LightningLoggerBase 42 | 43 | from src import utils 44 | 45 | log = utils.get_pylogger(__name__) 46 | 47 | 48 | @utils.task_wrapper 49 | def train(cfg: DictConfig) -> Tuple[dict, dict]: 50 | """Trains the model. Can additionally evaluate on a testset, using best weights obtained during 51 | training. 52 | 53 | This method is wrapped in optional @task_wrapper decorator which applies extra utilities 54 | before and after the call. 55 | 56 | Args: 57 | cfg (DictConfig): Configuration composed by Hydra. 58 | 59 | Returns: 60 | Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects. 61 | """ 62 | 63 | # set seed for random number generators in pytorch, numpy and python.random 64 | if cfg.get("seed"): 65 | pl.seed_everything(cfg.seed, workers=True) 66 | 67 | log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>") 68 | datamodule: LightningDataModule = hydra.utils.instantiate(cfg.datamodule) 69 | 70 | log.info(f"Instantiating model <{cfg.model._target_}>") 71 | if hasattr(datamodule, "pass_to_model"): 72 | log.info("Passing full datamodule to model") 73 | model: LightningModule = hydra.utils.instantiate(cfg.model)(datamodule=datamodule) 74 | else: 75 | if hasattr(datamodule, "dim"): 76 | log.info("Passing datamodule.dim to model") 77 | model: LightningModule = hydra.utils.instantiate(cfg.model)(dim=datamodule.dim) 78 | else: 79 | model: LightningModule = hydra.utils.instantiate(cfg.model) 80 | 81 | log.info("Instantiating callbacks...") 82 | callbacks: List[Callback] = utils.instantiate_callbacks(cfg.get("callbacks")) 83 | 84 | log.info("Instantiating loggers...") 85 | logger: List[LightningLoggerBase] = utils.instantiate_loggers(cfg.get("logger")) 86 | 87 | log.info(f"Instantiating trainer <{cfg.trainer._target_}>") 88 | trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger) 89 | 90 | object_dict = { 91 | "cfg": cfg, 92 | "datamodule": datamodule, 93 | "model": model, 94 | "callbacks": callbacks, 95 | "logger": logger, 96 | "trainer": trainer, 97 | } 98 | 99 | if logger: 100 | log.info("Logging hyperparameters!") 101 | utils.log_hyperparameters(object_dict) 102 | 103 | if cfg.get("train"): 104 | log.info("Starting training!") 105 | trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path")) 106 | 107 | train_metrics = trainer.callback_metrics 108 | 109 | if cfg.get("test"): 110 | log.info("Starting testing!") 111 | ckpt_path = trainer.checkpoint_callback.best_model_path 112 | if ckpt_path == "": 113 | log.warning("Best ckpt not found! 
Using current weights for testing...") 114 | ckpt_path = None 115 | trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path) 116 | log.info(f"Best ckpt path: {ckpt_path}") 117 | 118 | test_metrics = trainer.callback_metrics 119 | 120 | # merge train and test metrics 121 | metric_dict = {**train_metrics, **test_metrics} 122 | 123 | return metric_dict, object_dict 124 | 125 | 126 | @hydra.main(version_base="1.2", config_path=root / "configs", config_name="train.yaml") 127 | def main(cfg: DictConfig) -> Optional[float]: 128 | # train the model 129 | metric_dict, _ = train(cfg) 130 | 131 | # safely retrieve metric value for hydra-based hyperparameter optimization 132 | metric_value = utils.get_metric_value( 133 | metric_dict=metric_dict, metric_name=cfg.get("optimized_metric") 134 | ) 135 | 136 | # return optimized metric 137 | return metric_value 138 | 139 | 140 | if __name__ == "__main__": 141 | main() 142 | -------------------------------------------------------------------------------- /runner/src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from src.utils.pylogger import get_pylogger 2 | from src.utils.rich_utils import enforce_tags, print_config_tree 3 | from src.utils.utils import ( 4 | close_loggers, 5 | extras, 6 | get_metric_value, 7 | instantiate_callbacks, 8 | instantiate_loggers, 9 | log_hyperparameters, 10 | save_file, 11 | task_wrapper, 12 | ) 13 | -------------------------------------------------------------------------------- /runner/src/utils/pylogger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from pytorch_lightning.utilities import rank_zero_only 4 | 5 | 6 | def get_pylogger(name=__name__) -> logging.Logger: 7 | """Initializes multi-GPU-friendly python command line logger.""" 8 | 9 | logger = logging.getLogger(name) 10 | 11 | # this ensures all logging levels get marked with the rank zero decorator 12 | # otherwise logs would get multiplied for each GPU process in multi-GPU setup 13 | logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical") 14 | for level in logging_levels: 15 | setattr(logger, level, rank_zero_only(getattr(logger, level))) 16 | 17 | return logger 18 | -------------------------------------------------------------------------------- /runner/src/utils/rich_utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Sequence 3 | 4 | import rich 5 | import rich.syntax 6 | import rich.tree 7 | from hydra.core.hydra_config import HydraConfig 8 | from omegaconf import DictConfig, OmegaConf, open_dict 9 | from pytorch_lightning.utilities import rank_zero_only 10 | from rich.prompt import Prompt 11 | 12 | from src.utils import pylogger 13 | 14 | log = pylogger.get_pylogger(__name__) 15 | 16 | 17 | @rank_zero_only 18 | def print_config_tree( 19 | cfg: DictConfig, 20 | print_order: Sequence[str] = ( 21 | "datamodule", 22 | "model", 23 | "callbacks", 24 | "logger", 25 | "trainer", 26 | "paths", 27 | "extras", 28 | ), 29 | resolve: bool = False, 30 | save_to_file: bool = False, 31 | ) -> None: 32 | """Prints content of DictConfig using Rich library and its tree structure. 33 | 34 | Args: 35 | cfg (DictConfig): Configuration composed by Hydra. 36 | print_order (Sequence[str], optional): Determines in what order config components are printed. 
37 | resolve (bool, optional): Whether to resolve reference fields of DictConfig. 38 | save_to_file (bool, optional): Whether to export config to the hydra output folder. 39 | """ 40 | 41 | style = "dim" 42 | tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) 43 | 44 | queue = [] 45 | 46 | # add fields from `print_order` to queue 47 | for field in print_order: 48 | queue.append(field) if field in cfg else log.warning( 49 | f"Field '{field}' not found in config. Skipping '{field}' config printing..." 50 | ) 51 | 52 | # add all the other fields to queue (not specified in `print_order`) 53 | for field in cfg: 54 | if field not in queue: 55 | queue.append(field) 56 | 57 | # generate config tree from queue 58 | for field in queue: 59 | branch = tree.add(field, style=style, guide_style=style) 60 | 61 | config_group = cfg[field] 62 | if isinstance(config_group, DictConfig): 63 | branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) 64 | else: 65 | branch_content = str(config_group) 66 | 67 | branch.add(rich.syntax.Syntax(branch_content, "yaml")) 68 | 69 | # print config tree 70 | rich.print(tree) 71 | 72 | # save config tree to file 73 | if save_to_file: 74 | with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file: 75 | rich.print(tree, file=file) 76 | 77 | 78 | @rank_zero_only 79 | def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: 80 | """Prompts user to input tags from command line if no tags are provided in config.""" 81 | 82 | if not cfg.get("tags"): 83 | if "id" in HydraConfig().cfg.hydra.job: 84 | raise ValueError("Specify tags before launching a multirun!") 85 | 86 | log.warning("No tags provided in config. Prompting user to input tags...") 87 | tags = Prompt.ask("Enter a list of comma separated tags", default="dev") 88 | tags = [t.strip() for t in tags.split(",") if t != ""] 89 | 90 | with open_dict(cfg): 91 | cfg.tags = tags 92 | 93 | log.info(f"Tags: {cfg.tags}") 94 | 95 | if save_to_file: 96 | with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file: 97 | rich.print(cfg.tags, file=file) 98 | 99 | 100 | if __name__ == "__main__": 101 | from hydra import compose, initialize 102 | 103 | with initialize(version_base="1.2", config_path="../../configs"): 104 | cfg = compose(config_name="train.yaml", return_hydra_config=False, overrides=[]) 105 | print_config_tree(cfg, resolve=False, save_to_file=False) 106 | -------------------------------------------------------------------------------- /runner/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/atong01/conditional-flow-matching/3fd278f9ef2f02e17e107e5769130b6cb44803e2/runner/tests/__init__.py -------------------------------------------------------------------------------- /runner/tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pyrootutils 2 | import pytest 3 | from hydra import compose, initialize 4 | from hydra.core.global_hydra import GlobalHydra 5 | from omegaconf import DictConfig, open_dict 6 | 7 | 8 | @pytest.fixture(scope="package") 9 | def cfg_train_global() -> DictConfig: 10 | with initialize(version_base="1.2", config_path="../configs"): 11 | cfg = compose(config_name="train.yaml", return_hydra_config=True, overrides=[]) 12 | 13 | # set defaults for all tests 14 | with open_dict(cfg): 15 | cfg.paths.root_dir = str(pyrootutils.find_root()) 16 | cfg.trainer.max_epochs = 1 17 | cfg.trainer.limit_train_batches = 0.02 18 | 
cfg.trainer.limit_val_batches = 0.2 19 | cfg.trainer.limit_test_batches = 0.2 20 | cfg.trainer.accelerator = "cpu" 21 | cfg.trainer.devices = 1 22 | cfg.datamodule.num_workers = 0 23 | cfg.datamodule.pin_memory = False 24 | cfg.extras.print_config = False 25 | cfg.extras.enforce_tags = False 26 | cfg.logger = None 27 | cfg.launcher = None 28 | 29 | return cfg 30 | 31 | 32 | @pytest.fixture(scope="package") 33 | def cfg_eval_global() -> DictConfig: 34 | with initialize(version_base="1.2", config_path="../configs"): 35 | cfg = compose(config_name="eval.yaml", return_hydra_config=True, overrides=["ckpt_path=."]) 36 | 37 | # set defaults for all tests 38 | with open_dict(cfg): 39 | cfg.paths.root_dir = str(pyrootutils.find_root()) 40 | cfg.trainer.max_epochs = 1 41 | cfg.trainer.limit_test_batches = 0.2 42 | cfg.trainer.accelerator = "cpu" 43 | cfg.trainer.devices = 1 44 | cfg.datamodule.num_workers = 0 45 | cfg.datamodule.pin_memory = False 46 | cfg.extras.print_config = False 47 | cfg.extras.enforce_tags = False 48 | cfg.logger = None 49 | 50 | return cfg 51 | 52 | 53 | # this is called by each test which uses `cfg_train` arg 54 | # each test generates its own temporary logging path 55 | @pytest.fixture(scope="function") 56 | def cfg_train(cfg_train_global, tmp_path) -> DictConfig: 57 | cfg = cfg_train_global.copy() 58 | 59 | with open_dict(cfg): 60 | cfg.paths.data_dir = str(tmp_path) 61 | cfg.paths.output_dir = str(tmp_path) 62 | cfg.paths.log_dir = str(tmp_path) 63 | 64 | yield cfg 65 | 66 | GlobalHydra.instance().clear() 67 | 68 | 69 | # this is called by each test which uses `cfg_eval` arg 70 | # each test generates its own temporary logging path 71 | @pytest.fixture(scope="function") 72 | def cfg_eval(cfg_eval_global, tmp_path) -> DictConfig: 73 | cfg = cfg_eval_global.copy() 74 | 75 | with open_dict(cfg): 76 | cfg.paths.data_dir = str(tmp_path) 77 | cfg.paths.output_dir = str(tmp_path) 78 | cfg.paths.log_dir = str(tmp_path) 79 | 80 | yield cfg 81 | 82 | GlobalHydra.instance().clear() 83 | -------------------------------------------------------------------------------- /runner/tests/helpers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/atong01/conditional-flow-matching/3fd278f9ef2f02e17e107e5769130b6cb44803e2/runner/tests/helpers/__init__.py -------------------------------------------------------------------------------- /runner/tests/helpers/package_available.py: -------------------------------------------------------------------------------- 1 | import platform 2 | 3 | import pkg_resources 4 | from pytorch_lightning.accelerators import TPUAccelerator 5 | 6 | 7 | def _package_available(package_name: str) -> bool: 8 | """Check if a package is available in your environment.""" 9 | try: 10 | return pkg_resources.require(package_name) is not None 11 | except pkg_resources.DistributionNotFound: 12 | return False 13 | 14 | 15 | _TPU_AVAILABLE = TPUAccelerator.is_available() 16 | 17 | _IS_WINDOWS = platform.system() == "Windows" 18 | 19 | _SH_AVAILABLE = not _IS_WINDOWS and _package_available("sh") 20 | 21 | _DEEPSPEED_AVAILABLE = not _IS_WINDOWS and _package_available("deepspeed") 22 | _FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _package_available("fairscale") 23 | 24 | _WANDB_AVAILABLE = _package_available("wandb") 25 | _NEPTUNE_AVAILABLE = _package_available("neptune") 26 | _COMET_AVAILABLE = _package_available("comet_ml") 27 | _MLFLOW_AVAILABLE = _package_available("mlflow") 28 | 
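Note: the availability flags above are plain booleans, so they can also be consumed directly by `pytest.mark.skipif` without the `RunIf` wrapper defined in the next file. A minimal sketch under that assumption (the test name and body are illustrative only, not part of the suite):

import pytest

from tests.helpers.package_available import _WANDB_AVAILABLE


# Skip the whole test when wandb is not installed in the environment.
@pytest.mark.skipif(not _WANDB_AVAILABLE, reason="wandb is not installed")
def test_wandb_import_smoke():
    # The guard above ensures this import only runs when wandb is available.
    import wandb

    assert wandb is not None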
-------------------------------------------------------------------------------- /runner/tests/helpers/run_if.py: -------------------------------------------------------------------------------- 1 | """Adapted from: 2 | 3 | https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/helpers/runif.py 4 | """ 5 | 6 | import sys 7 | from typing import Optional 8 | 9 | import pytest 10 | import torch 11 | from packaging.version import Version 12 | from pkg_resources import get_distribution 13 | 14 | from tests.helpers.package_available import ( 15 | _COMET_AVAILABLE, 16 | _DEEPSPEED_AVAILABLE, 17 | _FAIRSCALE_AVAILABLE, 18 | _IS_WINDOWS, 19 | _MLFLOW_AVAILABLE, 20 | _NEPTUNE_AVAILABLE, 21 | _SH_AVAILABLE, 22 | _TPU_AVAILABLE, 23 | _WANDB_AVAILABLE, 24 | ) 25 | 26 | 27 | class RunIf: 28 | """RunIf wrapper for conditional skipping of tests. 29 | 30 | Fully compatible with `@pytest.mark`. 31 | 32 | Example: 33 | 34 | @RunIf(min_torch="1.8") 35 | @pytest.mark.parametrize("arg1", [1.0, 2.0]) 36 | def test_wrapper(arg1): 37 | assert arg1 > 0 38 | """ 39 | 40 | def __new__( 41 | self, 42 | min_gpus: int = 0, 43 | min_torch: Optional[str] = None, 44 | max_torch: Optional[str] = None, 45 | min_python: Optional[str] = None, 46 | skip_windows: bool = False, 47 | sh: bool = False, 48 | tpu: bool = False, 49 | fairscale: bool = False, 50 | deepspeed: bool = False, 51 | wandb: bool = False, 52 | neptune: bool = False, 53 | comet: bool = False, 54 | mlflow: bool = False, 55 | **kwargs, 56 | ): 57 | """ 58 | Args: 59 | min_gpus: min number of GPUs required to run test 60 | min_torch: minimum pytorch version to run test 61 | max_torch: maximum pytorch version to run test 62 | min_python: minimum python version required to run test 63 | skip_windows: skip test for Windows platform 64 | tpu: if TPU is available 65 | sh: if `sh` module is required to run the test 66 | fairscale: if `fairscale` module is required to run the test 67 | deepspeed: if `deepspeed` module is required to run the test 68 | wandb: if `wandb` module is required to run the test 69 | neptune: if `neptune` module is required to run the test 70 | comet: if `comet` module is required to run the test 71 | mlflow: if `mlflow` module is required to run the test 72 | kwargs: native pytest.mark.skipif keyword arguments 73 | """ 74 | conditions = [] 75 | reasons = [] 76 | 77 | if min_gpus: 78 | conditions.append(torch.cuda.device_count() < min_gpus) 79 | reasons.append(f"GPUs>={min_gpus}") 80 | 81 | if min_torch: 82 | torch_version = get_distribution("torch").version 83 | conditions.append(Version(torch_version) < Version(min_torch)) 84 | reasons.append(f"torch>={min_torch}") 85 | 86 | if max_torch: 87 | torch_version = get_distribution("torch").version 88 | conditions.append(Version(torch_version) >= Version(max_torch)) 89 | reasons.append(f"torch<{max_torch}") 90 | 91 | if min_python: 92 | py_version = ( 93 | f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" 94 | ) 95 | conditions.append(Version(py_version) < Version(min_python)) 96 | reasons.append(f"python>={min_python}") 97 | 98 | if skip_windows: 99 | conditions.append(_IS_WINDOWS) 100 | reasons.append("does not run on Windows") 101 | 102 | if tpu: 103 | conditions.append(not _TPU_AVAILABLE) 104 | reasons.append("TPU") 105 | 106 | if sh: 107 | conditions.append(not _SH_AVAILABLE) 108 | reasons.append("sh") 109 | 110 | if fairscale: 111 | conditions.append(not _FAIRSCALE_AVAILABLE) 112 | reasons.append("fairscale") 113 | 114 | if deepspeed: 115 | 
conditions.append(not _DEEPSPEED_AVAILABLE) 116 | reasons.append("deepspeed") 117 | 118 | if wandb: 119 | conditions.append(not _WANDB_AVAILABLE) 120 | reasons.append("wandb") 121 | 122 | if neptune: 123 | conditions.append(not _NEPTUNE_AVAILABLE) 124 | reasons.append("neptune") 125 | 126 | if comet: 127 | conditions.append(not _COMET_AVAILABLE) 128 | reasons.append("comet") 129 | 130 | if mlflow: 131 | conditions.append(not _MLFLOW_AVAILABLE) 132 | reasons.append("mlflow") 133 | 134 | reasons = [rs for cond, rs in zip(conditions, reasons) if cond] 135 | return pytest.mark.skipif( 136 | condition=any(conditions), 137 | reason=f"Requires: [{' + '.join(reasons)}]", 138 | **kwargs, 139 | ) 140 | -------------------------------------------------------------------------------- /runner/tests/helpers/run_sh_command.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import pytest 4 | 5 | from tests.helpers.package_available import _SH_AVAILABLE 6 | 7 | if _SH_AVAILABLE: 8 | import sh 9 | 10 | 11 | def run_sh_command(command: List[str]): 12 | """Default method for executing shell commands with pytest and sh package.""" 13 | msg = None 14 | try: 15 | sh.python(command) 16 | except sh.ErrorReturnCode as e: 17 | msg = e.stderr.decode() 18 | if msg: 19 | pytest.fail(msg=msg) 20 | -------------------------------------------------------------------------------- /runner/tests/test_configs.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | from hydra.core.hydra_config import HydraConfig 3 | from omegaconf import DictConfig 4 | 5 | 6 | def test_train_config(cfg_train: DictConfig): 7 | assert cfg_train 8 | assert cfg_train.datamodule 9 | assert cfg_train.model 10 | assert cfg_train.trainer 11 | 12 | HydraConfig().set_config(cfg_train) 13 | 14 | hydra.utils.instantiate(cfg_train.datamodule) 15 | hydra.utils.instantiate(cfg_train.model) 16 | hydra.utils.instantiate(cfg_train.trainer) 17 | 18 | 19 | def test_eval_config(cfg_eval: DictConfig): 20 | assert cfg_eval 21 | assert cfg_eval.datamodule 22 | assert cfg_eval.model 23 | assert cfg_eval.trainer 24 | 25 | HydraConfig().set_config(cfg_eval) 26 | 27 | hydra.utils.instantiate(cfg_eval.datamodule) 28 | hydra.utils.instantiate(cfg_eval.model) 29 | hydra.utils.instantiate(cfg_eval.trainer) 30 | -------------------------------------------------------------------------------- /runner/tests/test_datamodule.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import pytest 4 | import torch 5 | 6 | from src.datamodules.distribution_datamodule import ( 7 | SKLearnDataModule, 8 | TorchDynDataModule, 9 | TwoDimDataModule, 10 | ) 11 | 12 | 13 | @pytest.mark.parametrize("batch_size", [32, 128]) 14 | @pytest.mark.parametrize("train_val_test_split", [400, [1000, 100, 100]]) 15 | @pytest.mark.parametrize( 16 | "datamodule,system", 17 | [ 18 | (SKLearnDataModule, "scurve"), 19 | (SKLearnDataModule, "moons"), 20 | (TorchDynDataModule, "gaussians"), 21 | ], 22 | ) 23 | def test_single_datamodule(batch_size, train_val_test_split, datamodule, system): 24 | dm = datamodule( 25 | batch_size=batch_size, train_val_test_split=train_val_test_split, system=system 26 | ) 27 | 28 | assert dm.data_train is not None and dm.data_val is not None and dm.data_test is not None 29 | assert dm.train_dataloader() and dm.val_dataloader() and dm.test_dataloader() 30 | 31 | num_datapoints = len(dm.data_train) + 
len(dm.data_val) + len(dm.data_test) 32 | assert num_datapoints == 1200 33 | 34 | batch = next(iter(dm.train_dataloader())) 35 | x = batch 36 | assert x.dim() == 2 37 | assert x.shape[0] == batch_size 38 | assert x.shape[-1] == 2 39 | assert dm.dim == 2 40 | assert x.dtype == torch.float32 41 | 42 | 43 | @pytest.mark.parametrize("batch_size", [32, 128]) 44 | @pytest.mark.parametrize("train_val_test_split", [300, [200, 50, 50]]) 45 | @pytest.mark.parametrize( 46 | "datamodule,system", 47 | [ 48 | (TwoDimDataModule, "moon-8gaussians"), 49 | ], 50 | ) 51 | def test_trajectory_datamodule(batch_size, train_val_test_split, datamodule, system): 52 | dm = datamodule( 53 | batch_size=batch_size, train_val_test_split=train_val_test_split, system=system 54 | ) 55 | # assert dm.train_dataloader() and dm.val_dataloader() and dm.test_dataloader() 56 | 57 | batch = next(iter(dm.train_dataloader())) 58 | x = batch 59 | assert len(x) == 2 60 | for t in range(len(dm.timepoint_data)): 61 | xt = x[t] 62 | assert xt.dim() == 2 63 | assert xt.shape[0] == batch_size 64 | assert xt.shape[-1] == 2 65 | assert xt.dtype == torch.float32 66 | assert dm.dim == 2 67 | -------------------------------------------------------------------------------- /runner/tests/test_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | from hydra.core.hydra_config import HydraConfig 5 | from omegaconf import open_dict 6 | 7 | from src.eval import evaluate 8 | from src.train import train 9 | 10 | 11 | @pytest.mark.slow 12 | def test_train_eval(tmp_path, cfg_train, cfg_eval): 13 | """Train for 1 epoch with `train.py` and evaluate with `eval.py`""" 14 | assert str(tmp_path) == cfg_train.paths.output_dir == cfg_eval.paths.output_dir 15 | 16 | with open_dict(cfg_train): 17 | cfg_train.trainer.max_epochs = 1 18 | cfg_train.test = True 19 | 20 | HydraConfig().set_config(cfg_train) 21 | train_metric_dict, _ = train(cfg_train) 22 | 23 | assert "last.ckpt" in os.listdir(tmp_path / "checkpoints") 24 | 25 | with open_dict(cfg_eval): 26 | cfg_eval.ckpt_path = str(tmp_path / "checkpoints" / "last.ckpt") 27 | 28 | HydraConfig().set_config(cfg_eval) 29 | test_metric_dict, _ = evaluate(cfg_eval) 30 | 31 | assert test_metric_dict["test/2-Wasserstein"] > 0.0 32 | -------------------------------------------------------------------------------- /runner/tests/test_sweeps.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from tests.helpers.run_if import RunIf 4 | from tests.helpers.run_sh_command import run_sh_command 5 | 6 | startfile = "runner/src/train.py" 7 | overrides = ["logger=[]"] 8 | dir_overrides = ["paths.data_dir", "hydra.sweep.dir"] 9 | 10 | 11 | @RunIf(sh=True) 12 | @pytest.mark.slow 13 | @pytest.mark.xfail( 14 | reason="Currently failing experiments with fast_dev_run which messes with gradients" 15 | ) 16 | def test_xfail_fast_dev_experiments(tmp_path): 17 | """Test running all available experiment configs with fast_dev_run=True.""" 18 | command = ( 19 | [ 20 | startfile, 21 | "-m", 22 | "experiment=glob(*)", 23 | "++trainer.fast_dev_run=true", 24 | ] 25 | + overrides 26 | + [f"{d}={tmp_path}" for d in dir_overrides] 27 | ) 28 | run_sh_command(command) 29 | 30 | 31 | @RunIf(sh=True) 32 | @pytest.mark.slow 33 | def test_experiments(tmp_path): 34 | """Test running all available experiment configs with fast_dev_run=True.""" 35 | command = ( 36 | [ 37 | startfile, 38 | "-m", 39 | "experiment=cfm", 40 | 
"model=cfm,otcfm,sbcfm,fm", 41 | "++trainer.fast_dev_run=true", 42 | "++trainer.limit_val_batches=0.25", 43 | ] 44 | + overrides 45 | + [f"{d}={tmp_path}" for d in dir_overrides] 46 | ) 47 | run_sh_command(command) 48 | 49 | 50 | @RunIf(sh=True) 51 | @pytest.mark.slow 52 | def test_hydra_sweep(tmp_path): 53 | """Test default hydra sweep.""" 54 | command = ( 55 | [ 56 | startfile, 57 | "-m", 58 | "hydra.sweep.dir=" + str(tmp_path), 59 | "model.optimizer.lr=0.005,0.01", 60 | "++trainer.fast_dev_run=true", 61 | ] 62 | + overrides 63 | + [f"{d}={tmp_path}" for d in dir_overrides] 64 | ) 65 | 66 | run_sh_command(command) 67 | 68 | 69 | @RunIf(sh=True) 70 | @pytest.mark.slow 71 | @pytest.mark.xfail(reason="DDP is not working yet") 72 | def test_hydra_sweep_ddp_sim(tmp_path): 73 | """Test default hydra sweep with ddp sim.""" 74 | command = ( 75 | [ 76 | startfile, 77 | "-m", 78 | "trainer=ddp_sim", 79 | "trainer.max_epochs=3", 80 | "+trainer.limit_train_batches=0.01", 81 | "+trainer.limit_val_batches=0.1", 82 | "+trainer.limit_test_batches=0.1", 83 | "model.optimizer.lr=0.005,0.01,0.02", 84 | ] 85 | + overrides 86 | + [f"{d}={tmp_path}" for d in dir_overrides] 87 | ) 88 | run_sh_command(command) 89 | 90 | 91 | @RunIf(sh=True) 92 | @pytest.mark.slow 93 | @pytest.mark.skip(reason="Too slow for easy esting, pathway currently not used") 94 | def test_optuna_sweep(tmp_path): 95 | """Test optuna sweep.""" 96 | command = ( 97 | [ 98 | startfile, 99 | "-m", 100 | "hparams_search=optuna", 101 | "hydra.sweep.dir=" + str(tmp_path), 102 | "hydra.sweeper.n_trials=3", 103 | "hydra.sweeper.sampler.n_startup_trials=2", 104 | # "++trainer.fast_dev_run=true", 105 | ] 106 | + overrides 107 | + [f"{d}={tmp_path}" for d in dir_overrides] 108 | ) 109 | run_sh_command(command) 110 | 111 | 112 | @RunIf(wandb=True, sh=True) 113 | @pytest.mark.slow 114 | @pytest.mark.xfail(reason="wandb import is still bad without API key") 115 | def test_optuna_sweep_ddp_sim_wandb(tmp_path): 116 | """Test optuna sweep with wandb and ddp sim.""" 117 | command = [ 118 | startfile, 119 | "-m", 120 | "hparams_search=optuna", 121 | "hydra.sweeper.n_trials=5", 122 | "trainer=ddp_sim", 123 | "trainer.max_epochs=3", 124 | "+trainer.limit_train_batches=0.01", 125 | "+trainer.limit_val_batches=0.1", 126 | "+trainer.limit_test_batches=0.1", 127 | "logger=wandb", 128 | ] + [f"{d}={tmp_path}" for d in dir_overrides] 129 | run_sh_command(command) 130 | -------------------------------------------------------------------------------- /runner/tests/test_train.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | from hydra.core.hydra_config import HydraConfig 5 | from omegaconf import open_dict 6 | 7 | from src.train import train 8 | from tests.helpers.run_if import RunIf 9 | 10 | 11 | def test_train_fast_dev_run(cfg_train): 12 | """Run for 1 train, val and test step.""" 13 | HydraConfig().set_config(cfg_train) 14 | with open_dict(cfg_train): 15 | cfg_train.trainer.fast_dev_run = True 16 | cfg_train.trainer.accelerator = "cpu" 17 | train(cfg_train) 18 | 19 | 20 | @RunIf(min_gpus=1) 21 | def test_train_fast_dev_run_gpu(cfg_train): 22 | """Run for 1 train, val and test step on GPU.""" 23 | HydraConfig().set_config(cfg_train) 24 | with open_dict(cfg_train): 25 | cfg_train.trainer.fast_dev_run = True 26 | cfg_train.trainer.accelerator = "gpu" 27 | train(cfg_train) 28 | 29 | 30 | @RunIf(min_gpus=1) 31 | @pytest.mark.slow 32 | def test_train_epoch_gpu_amp(cfg_train): 33 | """Train 1 epoch 
on GPU with mixed-precision.""" 34 | HydraConfig().set_config(cfg_train) 35 | with open_dict(cfg_train): 36 | cfg_train.trainer.max_epochs = 1 37 | cfg_train.trainer.accelerator = "cpu" 38 | cfg_train.trainer.precision = 16 39 | train(cfg_train) 40 | 41 | 42 | @pytest.mark.slow 43 | def test_train_epoch_double_val_loop(cfg_train): 44 | """Train 1 epoch with validation loop twice per epoch.""" 45 | HydraConfig().set_config(cfg_train) 46 | with open_dict(cfg_train): 47 | cfg_train.trainer.max_epochs = 1 48 | cfg_train.trainer.val_check_interval = 0.5 49 | train(cfg_train) 50 | 51 | 52 | @pytest.mark.slow 53 | @pytest.mark.xfail(reason="DDP currently failing") 54 | def test_train_ddp_sim(cfg_train): 55 | """Simulate DDP (Distributed Data Parallel) on 2 CPU processes.""" 56 | HydraConfig().set_config(cfg_train) 57 | with open_dict(cfg_train): 58 | cfg_train.trainer.max_epochs = 2 59 | cfg_train.trainer.accelerator = "cpu" 60 | cfg_train.trainer.devices = 2 61 | cfg_train.trainer.strategy = "ddp_spawn" 62 | train(cfg_train) 63 | 64 | 65 | @pytest.mark.slow 66 | def test_train_resume(tmp_path, cfg_train): 67 | """Run 1 epoch, finish, and resume for another epoch.""" 68 | with open_dict(cfg_train): 69 | cfg_train.trainer.max_epochs = 1 70 | cfg_train.callbacks.model_checkpoint.save_top_k = 2 71 | print(cfg_train) 72 | 73 | HydraConfig().set_config(cfg_train) 74 | metric_dict_1, _ = train(cfg_train) 75 | 76 | files = os.listdir(tmp_path / "checkpoints") 77 | assert "last.ckpt" in files 78 | assert "epoch_0000.ckpt" in files 79 | 80 | with open_dict(cfg_train): 81 | cfg_train.ckpt_path = str(tmp_path / "checkpoints" / "last.ckpt") 82 | cfg_train.trainer.max_epochs = 2 83 | 84 | metric_dict_2, _ = train(cfg_train) 85 | 86 | files = os.listdir(tmp_path / "checkpoints") 87 | assert "epoch_0001.ckpt" in files 88 | assert "epoch_0002.ckpt" not in files 89 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import os 4 | 5 | from setuptools import find_packages, setup 6 | 7 | install_requires = [ 8 | "torch>=1.11.0", 9 | "matplotlib", 10 | "numpy", # Due to pandas incompatibility 11 | "scipy", 12 | "scikit-learn", 13 | "torchdyn>=1.0.6", 14 | "pot", 15 | "torchdiffeq", 16 | "absl-py", 17 | "pandas>=2.2.2", 18 | ] 19 | 20 | version_py = os.path.join(os.path.dirname(__file__), "torchcfm", "version.py") 21 | version = open(version_py).read().strip().split("=")[-1].replace('"', "").strip() 22 | readme = open("README.md", encoding="utf8").read() 23 | setup( 24 | name="torchcfm", 25 | version=version, 26 | description="Conditional Flow Matching for Fast Continuous Normalizing Flow Training.", 27 | author="Alexander Tong, Kilian Fatras", 28 | author_email="alexandertongdev@gmail.com", 29 | url="https://github.com/atong01/conditional-flow-matching", 30 | install_requires=install_requires, 31 | license="MIT", 32 | long_description=readme, 33 | long_description_content_type="text/markdown", 34 | packages=find_packages(exclude=["tests", "tests.*"]), 35 | extras_require={"forest-flow": ["xgboost", "scikit-learn", "ForestDiffusion"]}, 36 | ) 37 | -------------------------------------------------------------------------------- /tests/test_conditional_flow_matcher.py: -------------------------------------------------------------------------------- 1 | """Tests for Conditional Flow Matcher classers.""" 2 | 3 | # Author: Kilian Fatras 4 | 5 | import math 6 | 7 | 
import numpy as np 8 | import pytest 9 | import torch 10 | 11 | from torchcfm.conditional_flow_matching import ( 12 | ConditionalFlowMatcher, 13 | ExactOptimalTransportConditionalFlowMatcher, 14 | SchrodingerBridgeConditionalFlowMatcher, 15 | TargetConditionalFlowMatcher, 16 | VariancePreservingConditionalFlowMatcher, 17 | pad_t_like_x, 18 | ) 19 | from torchcfm.optimal_transport import OTPlanSampler 20 | 21 | TEST_SEED = 1994 22 | TEST_BATCH_SIZE = 128 23 | SIGMA_CONDITION = { 24 | "sb_cfm": lambda x: x <= 0, 25 | } 26 | 27 | 28 | def random_samples(shape, batch_size=TEST_BATCH_SIZE): 29 | """Generate random samples of different dimensions.""" 30 | if isinstance(shape, int): 31 | shape = [shape] 32 | return [torch.randn(batch_size, *shape), torch.randn(batch_size, *shape)] 33 | 34 | 35 | def compute_xt_ut(method, x0, x1, t_given, sigma, epsilon): 36 | if method == "vp_cfm": 37 | sigma_t = sigma 38 | mu_t = torch.cos(math.pi / 2 * t_given) * x0 + torch.sin(math.pi / 2 * t_given) * x1 39 | computed_xt = mu_t + sigma_t * epsilon 40 | computed_ut = ( 41 | math.pi 42 | / 2 43 | * (torch.cos(math.pi / 2 * t_given) * x1 - torch.sin(math.pi / 2 * t_given) * x0) 44 | ) 45 | elif method == "t_cfm": 46 | sigma_t = 1 - (1 - sigma) * t_given 47 | mu_t = t_given * x1 48 | computed_xt = mu_t + sigma_t * epsilon 49 | computed_ut = (x1 - (1 - sigma) * computed_xt) / sigma_t 50 | 51 | elif method == "sb_cfm": 52 | sigma_t = sigma * torch.sqrt(t_given * (1 - t_given)) 53 | mu_t = t_given * x1 + (1 - t_given) * x0 54 | computed_xt = mu_t + sigma_t * epsilon 55 | computed_ut = ( 56 | (1 - 2 * t_given) 57 | / (2 * t_given * (1 - t_given) + 1e-8) 58 | * (computed_xt - (t_given * x1 + (1 - t_given) * x0)) 59 | + x1 60 | - x0 61 | ) 62 | elif method in ["exact_ot_cfm", "i_cfm"]: 63 | sigma_t = sigma 64 | mu_t = t_given * x1 + (1 - t_given) * x0 65 | computed_xt = mu_t + sigma_t * epsilon 66 | computed_ut = x1 - x0 67 | 68 | return computed_xt, computed_ut 69 | 70 | 71 | def get_flow_matcher(method, sigma): 72 | if method == "vp_cfm": 73 | fm = VariancePreservingConditionalFlowMatcher(sigma=sigma) 74 | elif method == "t_cfm": 75 | fm = TargetConditionalFlowMatcher(sigma=sigma) 76 | elif method == "sb_cfm": 77 | fm = SchrodingerBridgeConditionalFlowMatcher(sigma=sigma, ot_method="sinkhorn") 78 | elif method == "exact_ot_cfm": 79 | fm = ExactOptimalTransportConditionalFlowMatcher(sigma=sigma) 80 | elif method == "i_cfm": 81 | fm = ConditionalFlowMatcher(sigma=sigma) 82 | return fm 83 | 84 | 85 | def sample_plan(method, x0, x1, sigma): 86 | if method == "sb_cfm": 87 | x0, x1 = OTPlanSampler(method="sinkhorn", reg=2 * (sigma**2)).sample_plan(x0, x1) 88 | elif method == "exact_ot_cfm": 89 | x0, x1 = OTPlanSampler(method="exact").sample_plan(x0, x1) 90 | return x0, x1 91 | 92 | 93 | @pytest.mark.parametrize("method", ["vp_cfm", "t_cfm", "sb_cfm", "exact_ot_cfm", "i_cfm"]) 94 | # Test both integer and floating sigma 95 | @pytest.mark.parametrize("sigma", [0.0, 5e-4, 0.5, 1.5, 0, 1]) 96 | @pytest.mark.parametrize("shape", [[1], [2], [1, 2], [3, 4, 5]]) 97 | def test_fm(method, sigma, shape): 98 | batch_size = TEST_BATCH_SIZE 99 | 100 | if method in SIGMA_CONDITION.keys() and SIGMA_CONDITION[method](sigma): 101 | with pytest.raises(ValueError): 102 | get_flow_matcher(method, sigma) 103 | return 104 | 105 | FM = get_flow_matcher(method, sigma) 106 | x0, x1 = random_samples(shape, batch_size=batch_size) 107 | torch.manual_seed(TEST_SEED) 108 | np.random.seed(TEST_SEED) 109 | t, xt, ut, eps = 
FM.sample_location_and_conditional_flow(x0, x1, return_noise=True) 110 | _ = FM.compute_lambda(t) 111 | 112 | if method in ["sb_cfm", "exact_ot_cfm"]: 113 | torch.manual_seed(TEST_SEED) 114 | np.random.seed(TEST_SEED) 115 | x0, x1 = sample_plan(method, x0, x1, sigma) 116 | 117 | torch.manual_seed(TEST_SEED) 118 | t_given_init = torch.rand(batch_size) 119 | t_given = t_given_init.reshape(-1, *([1] * (x0.dim() - 1))) 120 | sigma_pad = pad_t_like_x(sigma, x0) 121 | epsilon = torch.randn_like(x0) 122 | computed_xt, computed_ut = compute_xt_ut(method, x0, x1, t_given, sigma_pad, epsilon) 123 | 124 | assert torch.all(ut.eq(computed_ut)) 125 | assert torch.all(xt.eq(computed_xt)) 126 | assert torch.all(eps.eq(epsilon)) 127 | assert any(t_given_init == t) 128 | -------------------------------------------------------------------------------- /tests/test_models.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from torchcfm.models import MLP 4 | from torchcfm.models.unet import UNetModel 5 | 6 | 7 | def test_initialize_models(): 8 | model = UNetModel( 9 | dim=(1, 28, 28), 10 | num_channels=32, 11 | num_res_blocks=1, 12 | num_classes=10, 13 | class_cond=True, 14 | ) 15 | model = MLP(dim=2, time_varying=True, w=64) 16 | -------------------------------------------------------------------------------- /tests/test_optimal_transport.py: -------------------------------------------------------------------------------- 1 | """Tests for Conditional Flow Matcher classers.""" 2 | 3 | # Author: Kilian Fatras 4 | 5 | import math 6 | 7 | import numpy as np 8 | import ot 9 | import pytest 10 | import torch 11 | 12 | from torchcfm.optimal_transport import OTPlanSampler, wasserstein 13 | 14 | ot_sampler = OTPlanSampler(method="exact") 15 | 16 | 17 | def test_sample_map(batch_size=128): 18 | # Build sparse random OT map 19 | map = np.eye(batch_size) 20 | rng = np.random.default_rng() 21 | permuted_map = rng.permutation(map, axis=1) 22 | 23 | # Sample elements from the OT plan 24 | # All elements should be sampled only once 25 | indices = ot_sampler.sample_map(permuted_map, batch_size=batch_size, replace=False) 26 | 27 | # Reconstruct the coupling from the sampled elements 28 | reconstructed_map = np.zeros((batch_size, batch_size)) 29 | for i in range(batch_size): 30 | reconstructed_map[indices[0][i], indices[1][i]] = 1 31 | assert np.array_equal(reconstructed_map, permuted_map) 32 | 33 | 34 | def test_get_map(batch_size=128): 35 | x0 = torch.randn(batch_size, 2, 2, 2) 36 | x1 = torch.randn(batch_size, 2, 2, 2) 37 | 38 | M = torch.cdist(x0.reshape(x0.shape[0], -1), x1.reshape(x1.shape[0], -1)) ** 2 39 | pot_pi = ot.emd(ot.unif(x0.shape[0]), ot.unif(x1.shape[0]), M.numpy()) 40 | 41 | pi = ot_sampler.get_map(x0, x1) 42 | 43 | assert np.array_equal(pi, pot_pi) 44 | 45 | 46 | def test_sample_plan(batch_size=128, seed=1980): 47 | torch.manual_seed(seed) 48 | np.random.seed(seed) 49 | x0 = torch.randn(batch_size, 2, 2, 2) 50 | x1 = torch.randn(batch_size, 2, 2, 2) 51 | 52 | pi = ot_sampler.get_map(x0, x1) 53 | indices_i, indices_j = ot_sampler.sample_map(pi, batch_size=batch_size, replace=True) 54 | new_x0, new_x1 = x0[indices_i], x1[indices_j] 55 | 56 | torch.manual_seed(seed) 57 | np.random.seed(seed) 58 | 59 | sampled_x0, sampled_x1 = ot_sampler.sample_plan(x0, x1, replace=True) 60 | 61 | assert torch.equal(new_x0, sampled_x0) 62 | assert torch.equal(new_x1, sampled_x1) 63 | 64 | 65 | def test_wasserstein(batch_size=128, seed=1980): 66 | torch.manual_seed(seed) 
67 | np.random.seed(seed) 68 | x0 = torch.randn(batch_size, 2, 2, 2) 69 | x1 = torch.randn(batch_size, 2, 2, 2) 70 | 71 | M = torch.cdist(x0.reshape(x0.shape[0], -1), x1.reshape(x1.shape[0], -1)) 72 | pot_W22 = ot.emd2(ot.unif(x0.shape[0]), ot.unif(x1.shape[0]), (M**2).numpy()) 73 | pot_W2 = np.sqrt(pot_W22) 74 | W2 = wasserstein(x0, x1, "exact") 75 | 76 | pot_W1 = ot.emd2(ot.unif(x0.shape[0]), ot.unif(x1.shape[0]), M.numpy()) 77 | W1 = wasserstein(x0, x1, "exact", power=1) 78 | 79 | pot_eot = ot.sinkhorn2( 80 | ot.unif(x0.shape[0]), ot.unif(x1.shape[0]), M.numpy(), reg=0.01, numItermax=int(1e7) 81 | ) 82 | eot = wasserstein(x0, x1, "sinkhorn", reg=0.01, power=1) 83 | 84 | with pytest.raises(ValueError) as excinfo: 85 | eot = wasserstein(x0, x1, "noname", reg=0.01, power=1) 86 | 87 | assert pot_W2 == W2 88 | assert pot_W1 == W1 89 | assert pot_eot == eot 90 | -------------------------------------------------------------------------------- /tests/test_time_t.py: -------------------------------------------------------------------------------- 1 | """Tests for time Tensor t.""" 2 | 3 | # Author: Kilian Fatras 4 | 5 | import pytest 6 | import torch 7 | 8 | from torchcfm.conditional_flow_matching import ( 9 | ConditionalFlowMatcher, 10 | ExactOptimalTransportConditionalFlowMatcher, 11 | SchrodingerBridgeConditionalFlowMatcher, 12 | TargetConditionalFlowMatcher, 13 | VariancePreservingConditionalFlowMatcher, 14 | ) 15 | 16 | seed = 1994 17 | batch_size = 128 18 | 19 | 20 | @pytest.mark.parametrize( 21 | "FM", 22 | [ 23 | ConditionalFlowMatcher(sigma=0.0), 24 | ExactOptimalTransportConditionalFlowMatcher(sigma=0.0), 25 | TargetConditionalFlowMatcher(sigma=0.0), 26 | SchrodingerBridgeConditionalFlowMatcher(sigma=0.1), 27 | VariancePreservingConditionalFlowMatcher(sigma=0.0), 28 | ], 29 | ) 30 | def test_random_Tensor_t(FM): 31 | # Test sample_location_and_conditional_flow functions 32 | x0 = torch.randn(batch_size, 2) 33 | x1 = torch.randn(batch_size, 2) 34 | 35 | torch.manual_seed(seed) 36 | t_given = torch.rand(batch_size) 37 | t_given, xt, ut = FM.sample_location_and_conditional_flow(x0, x1, t=t_given) 38 | 39 | torch.manual_seed(seed) 40 | t_random, xt, ut = FM.sample_location_and_conditional_flow(x0, x1, t=None) 41 | 42 | assert any(t_given == t_random) 43 | 44 | 45 | @pytest.mark.parametrize( 46 | "FM", 47 | [ 48 | ExactOptimalTransportConditionalFlowMatcher(sigma=0.0), 49 | SchrodingerBridgeConditionalFlowMatcher(sigma=0.1), 50 | ], 51 | ) 52 | @pytest.mark.parametrize("return_noise", [True, False]) 53 | def test_guided_random_Tensor_t(FM, return_noise): 54 | # Test guided_sample_location_and_conditional_flow functions 55 | x0 = torch.randn(batch_size, 2) 56 | y0 = torch.randint(high=10, size=(batch_size, 1)) 57 | x1 = torch.randn(batch_size, 2) 58 | y1 = torch.randint(high=10, size=(batch_size, 1)) 59 | 60 | torch.manual_seed(seed) 61 | t_given = torch.rand(batch_size) 62 | t_given = FM.guided_sample_location_and_conditional_flow( 63 | x0, x1, y0=y0, y1=y1, t=t_given, return_noise=return_noise 64 | )[0] 65 | 66 | torch.manual_seed(seed) 67 | t_random = FM.guided_sample_location_and_conditional_flow( 68 | x0, x1, y0=y0, y1=y1, t=None, return_noise=return_noise 69 | )[0] 70 | 71 | assert any(t_given == t_random) 72 | -------------------------------------------------------------------------------- /torchcfm/__init__.py: -------------------------------------------------------------------------------- 1 | from .conditional_flow_matching import * 2 | from .version import __version__ 3 | 
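Note: the classes re-exported above are driven the same way throughout the tests and tutorials: sample a (t, x_t, u_t) triplet and regress a network onto u_t. A minimal training-loop sketch in that style, with toy data, sigma, and optimizer settings chosen only for illustration:

import torch

from torchcfm.conditional_flow_matching import ConditionalFlowMatcher
from torchcfm.models import MLP

# Toy source and target batches; any pair of (batch, dim) tensors works the same way.
x0 = torch.randn(256, 2)
x1 = torch.randn(256, 2) + 4.0

model = MLP(dim=2, time_varying=True, w=64)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
FM = ConditionalFlowMatcher(sigma=0.1)

for step in range(1000):
    optimizer.zero_grad()
    # Sample t ~ U[0, 1], a point x_t on the conditional path, and its target velocity u_t.
    t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1)
    # The time-varying MLP expects [x_t, t] concatenated along the feature dimension.
    vt = model(torch.cat([xt, t[:, None]], dim=-1))
    loss = torch.mean((vt - ut) ** 2)
    loss.backward()
    optimizer.step()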
-------------------------------------------------------------------------------- /torchcfm/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import MLP 2 | -------------------------------------------------------------------------------- /torchcfm/models/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class MLP(torch.nn.Module): 5 | def __init__(self, dim, out_dim=None, w=64, time_varying=False): 6 | super().__init__() 7 | self.time_varying = time_varying 8 | if out_dim is None: 9 | out_dim = dim 10 | self.net = torch.nn.Sequential( 11 | torch.nn.Linear(dim + (1 if time_varying else 0), w), 12 | torch.nn.SELU(), 13 | torch.nn.Linear(w, w), 14 | torch.nn.SELU(), 15 | torch.nn.Linear(w, w), 16 | torch.nn.SELU(), 17 | torch.nn.Linear(w, out_dim), 18 | ) 19 | 20 | def forward(self, x): 21 | return self.net(x) 22 | 23 | 24 | class GradModel(torch.nn.Module): 25 | def __init__(self, action): 26 | super().__init__() 27 | self.action = action 28 | 29 | def forward(self, x): 30 | x = x.requires_grad_(True) 31 | grad = torch.autograd.grad(torch.sum(self.action(x)), x, create_graph=True)[0] 32 | return grad[:, :-1] 33 | -------------------------------------------------------------------------------- /torchcfm/models/unet/__init__.py: -------------------------------------------------------------------------------- 1 | from .unet import UNetModelWrapper as UNetModel 2 | -------------------------------------------------------------------------------- /torchcfm/models/unet/nn.py: -------------------------------------------------------------------------------- 1 | """Various utilities for neural networks.""" 2 | 3 | import math 4 | 5 | import torch as th 6 | import torch.nn as nn 7 | 8 | 9 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 10 | class SiLU(nn.Module): 11 | def forward(self, x): 12 | return x * th.sigmoid(x) 13 | 14 | 15 | class GroupNorm32(nn.GroupNorm): 16 | def forward(self, x): 17 | return super().forward(x.float()).type(x.dtype) 18 | 19 | 20 | def conv_nd(dims, *args, **kwargs): 21 | """Create a 1D, 2D, or 3D convolution module.""" 22 | if dims == 1: 23 | return nn.Conv1d(*args, **kwargs) 24 | elif dims == 2: 25 | return nn.Conv2d(*args, **kwargs) 26 | elif dims == 3: 27 | return nn.Conv3d(*args, **kwargs) 28 | raise ValueError(f"unsupported dimensions: {dims}") 29 | 30 | 31 | def linear(*args, **kwargs): 32 | """Create a linear module.""" 33 | return nn.Linear(*args, **kwargs) 34 | 35 | 36 | def avg_pool_nd(dims, *args, **kwargs): 37 | """Create a 1D, 2D, or 3D average pooling module.""" 38 | if dims == 1: 39 | return nn.AvgPool1d(*args, **kwargs) 40 | elif dims == 2: 41 | return nn.AvgPool2d(*args, **kwargs) 42 | elif dims == 3: 43 | return nn.AvgPool3d(*args, **kwargs) 44 | raise ValueError(f"unsupported dimensions: {dims}") 45 | 46 | 47 | def update_ema(target_params, source_params, rate=0.99): 48 | """Update target parameters to be closer to those of source parameters using an exponential 49 | moving average. 50 | 51 | :param target_params: the target parameter sequence. 52 | :param source_params: the source parameter sequence. 53 | :param rate: the EMA rate (closer to 1 means slower). 
54 | """ 55 | for targ, src in zip(target_params, source_params): 56 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 57 | 58 | 59 | def zero_module(module): 60 | """Zero out the parameters of a module and return it.""" 61 | for p in module.parameters(): 62 | p.detach().zero_() 63 | return module 64 | 65 | 66 | def scale_module(module, scale): 67 | """Scale the parameters of a module and return it.""" 68 | for p in module.parameters(): 69 | p.detach().mul_(scale) 70 | return module 71 | 72 | 73 | def mean_flat(tensor): 74 | """Take the mean over all non-batch dimensions.""" 75 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 76 | 77 | 78 | def normalization(channels): 79 | """Make a standard normalization layer. 80 | 81 | :param channels: number of input channels. 82 | :return: an nn.Module for normalization. 83 | """ 84 | return GroupNorm32(32, channels) 85 | 86 | 87 | def timestep_embedding(timesteps, dim, max_period=10000): 88 | """Create sinusoidal timestep embeddings. 89 | 90 | :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional. 91 | :param dim: the dimension of the output. 92 | :param max_period: controls the minimum frequency of the embeddings. 93 | :return: an [N x dim] Tensor of positional embeddings. 94 | """ 95 | half = dim // 2 96 | freqs = th.exp( 97 | -math.log(max_period) 98 | * th.arange(start=0, end=half, dtype=th.float32, device=timesteps.device) 99 | / half 100 | ) 101 | args = timesteps[:, None].float() * freqs[None] 102 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 103 | if dim % 2: 104 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 105 | return embedding 106 | 107 | 108 | def checkpoint(func, inputs, params, flag): 109 | """Evaluate a function without caching intermediate activations, allowing for reduced memory at 110 | the expense of extra compute in the backward pass. 111 | 112 | :param func: the function to evaluate. 113 | :param inputs: the argument sequence to pass to `func`. 114 | :param params: a sequence of parameters `func` depends on but does not 115 | explicitly take as arguments. 116 | :param flag: if False, disable gradient checkpointing. 117 | """ 118 | if flag: 119 | args = tuple(inputs) + tuple(params) 120 | return CheckpointFunction.apply(func, len(inputs), *args) 121 | else: 122 | return func(*inputs) 123 | 124 | 125 | class CheckpointFunction(th.autograd.Function): 126 | @staticmethod 127 | def forward(ctx, run_function, length, *args): 128 | ctx.run_function = run_function 129 | ctx.input_tensors = list(args[:length]) 130 | ctx.input_params = list(args[length:]) 131 | with th.no_grad(): 132 | output_tensors = ctx.run_function(*ctx.input_tensors) 133 | return output_tensors 134 | 135 | @staticmethod 136 | def backward(ctx, *output_grads): 137 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 138 | with th.enable_grad(): 139 | # Fixes a bug where the first op in run_function modifies the 140 | # Tensor storage in place, which is not allowed for detach()'d 141 | # Tensors. 
142 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 143 | output_tensors = ctx.run_function(*shallow_copies) 144 | input_grads = th.autograd.grad( 145 | output_tensors, 146 | ctx.input_tensors + ctx.input_params, 147 | output_grads, 148 | allow_unused=True, 149 | ) 150 | del ctx.input_tensors 151 | del ctx.input_params 152 | del output_tensors 153 | return (None, None) + input_grads 154 | -------------------------------------------------------------------------------- /torchcfm/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import torch 6 | import torchdyn 7 | from torchdyn.datasets import generate_moons 8 | 9 | # Implement some helper functions 10 | 11 | 12 | def eight_normal_sample(n, dim, scale=1, var=1): 13 | m = torch.distributions.multivariate_normal.MultivariateNormal( 14 | torch.zeros(dim), math.sqrt(var) * torch.eye(dim) 15 | ) 16 | centers = [ 17 | (1, 0), 18 | (-1, 0), 19 | (0, 1), 20 | (0, -1), 21 | (1.0 / np.sqrt(2), 1.0 / np.sqrt(2)), 22 | (1.0 / np.sqrt(2), -1.0 / np.sqrt(2)), 23 | (-1.0 / np.sqrt(2), 1.0 / np.sqrt(2)), 24 | (-1.0 / np.sqrt(2), -1.0 / np.sqrt(2)), 25 | ] 26 | centers = torch.tensor(centers) * scale 27 | noise = m.sample((n,)) 28 | multi = torch.multinomial(torch.ones(8), n, replacement=True) 29 | data = [] 30 | for i in range(n): 31 | data.append(centers[multi[i]] + noise[i]) 32 | data = torch.stack(data) 33 | return data 34 | 35 | 36 | def sample_moons(n): 37 | x0, _ = generate_moons(n, noise=0.2) 38 | return x0 * 3 - 1 39 | 40 | 41 | def sample_8gaussians(n): 42 | return eight_normal_sample(n, 2, scale=5, var=0.1).float() 43 | 44 | 45 | class torch_wrapper(torch.nn.Module): 46 | """Wraps model to torchdyn compatible format.""" 47 | 48 | def __init__(self, model): 49 | super().__init__() 50 | self.model = model 51 | 52 | def forward(self, t, x, *args, **kwargs): 53 | return self.model(torch.cat([x, t.repeat(x.shape[0])[:, None]], 1)) 54 | 55 | 56 | def plot_trajectories(traj): 57 | """Plot trajectories of some selected samples.""" 58 | n = 2000 59 | plt.figure(figsize=(6, 6)) 60 | plt.scatter(traj[0, :n, 0], traj[0, :n, 1], s=10, alpha=0.8, c="black") 61 | plt.scatter(traj[:, :n, 0], traj[:, :n, 1], s=0.2, alpha=0.2, c="olive") 62 | plt.scatter(traj[-1, :n, 0], traj[-1, :n, 1], s=4, alpha=1, c="blue") 63 | plt.legend(["Prior sample z(S)", "Flow", "z(0)"]) 64 | plt.xticks([]) 65 | plt.yticks([]) 66 | plt.show() 67 | -------------------------------------------------------------------------------- /torchcfm/version.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.0.7" 2 | --------------------------------------------------------------------------------
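Note: to visualise a trained flow with the helpers in torchcfm/utils.py, the tutorials pair `torch_wrapper` with torchdyn's `NeuralODE`. A rough sketch assuming that API and an already-trained `model` (an untrained MLP is instantiated here only so the snippet runs end to end):

import torch
from torchdyn.core import NeuralODE

from torchcfm.models import MLP
from torchcfm.utils import plot_trajectories, sample_8gaussians, torch_wrapper

# Assume this model has been trained as a velocity field (e.g. with a CFM loss).
model = MLP(dim=2, time_varying=True, w=64)

# torch_wrapper adapts the model to torchdyn's (t, x) calling convention.
node = NeuralODE(torch_wrapper(model), solver="dopri5", sensitivity="adjoint")

with torch.no_grad():
    traj = node.trajectory(
        sample_8gaussians(1024),           # prior samples z(S)
        t_span=torch.linspace(0, 1, 100),  # integration times from source to target
    )

# traj has shape (len(t_span), batch, dim): start points, flow, and end points.
plot_trajectories(traj.cpu().numpy())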