├── .gitignore
├── .python-version
├── LICENSE
├── README.md
├── configs
│   └── conv-8.yaml
├── media
│   ├── step0.png
│   ├── step128.png
│   └── step64.png
├── notebooks
│   └── sample.ipynb
├── pyproject.toml
├── src
│   └── discrete_flow_matching_pytorch
│       ├── __init__.py
│       ├── data.py
│       ├── flops.py
│       ├── lightning.py
│       ├── model.py
│       └── train.py
└── uv.lock
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 |
164 | /discrete-flow-matching
165 | /lightning_logs
166 | /wandb
167 |
--------------------------------------------------------------------------------
/.python-version:
--------------------------------------------------------------------------------
1 | 3.12
2 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Robin Kahlow
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Discrete Flow Matching implemented in PyTorch
2 |
3 | Implementation of Discrete Flow Matching [[1]](https://arxiv.org/abs/2402.04997)[[2]](https://arxiv.org/abs/2407.15595), a generative model that applies flow matching to discrete data such as text, implemented in PyTorch.
4 |
5 | | Step 0 of 128 (input) | Step 64 of 128 | Step 128 of 128 (output) |
6 | | -------------------------- | ---------------------------- | ------------------------------ |
7 | |  |  |  |
8 |
9 | ## How to run
10 |
11 | ### Environment setup
12 |
13 | 1. Install [uv](https://github.com/astral-sh/uv) for package management, e.g. `pip install uv`
14 | 2. Make sure Python 3.12 is installed: `uv python install 3.12`
15 | 3. Install the dependencies: `uv sync --group jupyter`
16 |
17 | Run `python -m discrete_flow_matching_pytorch.train --config configs/conv-8.yaml` to start training a text generation model; training metrics are logged to wandb.
18 |
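The values in the config correspond to the parameters of `main` in [train.py](src/discrete_flow_matching_pytorch/train.py). Since the CLI is built with jsonargparse, individual values can also be overridden on the command line, e.g. `python -m discrete_flow_matching_pytorch.train --config configs/conv-8.yaml --learning_rate 1e-4` (the override shown here is illustrative).
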
19 | The [sample notebook](notebooks/sample.ipynb) demonstrates the sampling process.
20 |
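For programmatic sampling, here is a minimal sketch mirroring the notebook (the checkpoint path is illustrative):

```python
import torch

from discrete_flow_matching_pytorch.lightning import DiscreteFlowMatchingNet

# Load a trained Lightning checkpoint (path is illustrative)
model = DiscreteFlowMatchingNet.load_from_checkpoint("path/to/checkpoint.ckpt")
model.eval()

# sample() is a generator: it yields (timestep, tokens) pairs,
# and the final yielded tokens are the fully unmasked sample
with torch.inference_mode():
    for t, x in model.sample(
        num_samples=1,
        sequence_length=128,
        num_sampling_steps=128,
        stochasticity=5.0,
        yield_intermediate=True,
    ):
        pass
print(model.tokenizer.decode(x[0]))
```
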
21 | **Note**: Instead of using uv, it is also possible to install the dependencies in [pyproject.toml](pyproject.toml) with pip.
22 |
23 | ## Summary of discrete flow matching compared to continuous flow matching
24 |
25 | - During training, we mask out text tokens with a probability determined by the timestep
26 | - The model is trained to predict the original unmasked tokens with a cross-entropy loss
27 | - During sampling, we start from fully masked text and gradually unmask it with the sampled tokens (see the sketch below)
28 |
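A schematic sketch of these three steps, simplified from [lightning.py](src/discrete_flow_matching_pytorch/lightning.py) (all sizes and the stand-in logits are illustrative):

```python
import torch
import torch.nn.functional as F

# Toy sizes: batch, length, vocab, and the mask token id (illustrative)
B, L, V, mask_id = 2, 16, 100, 99
x = torch.randint(0, V - 1, (B, L))  # clean token ids
mask_prob = torch.full((B, 1), 0.5)  # scheduler[t]: masking probability at timestep t

# 1. Forward noising: mask each token independently with probability scheduler[t]
will_mask = torch.bernoulli(mask_prob.expand(B, L)).bool()
noised_x = x.clone()
noised_x[will_mask] = mask_id

# 2. Training: predict the original tokens, with loss only on masked positions
logits = torch.randn(B, L, V)  # stand-in for model(noised_x, t)
target = x.clone()
target[noised_x != mask_id] = -100  # -100 is ignored by F.cross_entropy
loss = F.cross_entropy(logits.transpose(-1, -2), target)

# 3. Sampling: start fully masked, then at each step sample tokens from the
#    model's logits and unmask each still-masked position with probability dt / t
```
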
29 | ## References
30 |
31 | - [[1] Generative Flows on Discrete State-Spaces: Enabling Multimodal Flows with Applications to Protein Co-Design](https://arxiv.org/abs/2402.04997) ([YouTube presentation](https://www.youtube.com/watch?v=yzc29vhM2Aw)): Combines discrete and continuous flow matching. Originally introduced Discrete Flow Matching. Appendix F was very useful for the implementation
32 | - [[2] Discrete Flow Matching](https://arxiv.org/abs/2407.15595): Builds on Multiflow's Discrete Flow Matching
33 |
--------------------------------------------------------------------------------
/configs/conv-8.yaml:
--------------------------------------------------------------------------------
1 | dataset: github_code_python_mit
2 | shuffle_train: false
3 | val_split_name: train
4 | train_workers: 4
5 | hidden_dim: 1024
6 | num_layers: 8
7 | train_step_flops: 62895031320576
8 | learning_rate: 1e-3
9 |
--------------------------------------------------------------------------------
/media/step0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RobinKa/discrete-flow-matching-pytorch/590f2db6c8c46e1b0138ed6a40bd603e4d2c77a8/media/step0.png
--------------------------------------------------------------------------------
/media/step128.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RobinKa/discrete-flow-matching-pytorch/590f2db6c8c46e1b0138ed6a40bd603e4d2c77a8/media/step128.png
--------------------------------------------------------------------------------
/media/step64.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RobinKa/discrete-flow-matching-pytorch/590f2db6c8c46e1b0138ed6a40bd603e4d2c77a8/media/step64.png
--------------------------------------------------------------------------------
/notebooks/sample.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 2,
6 | "metadata": {},
7 | "outputs": [
8 | {
9 | "name": "stdout",
10 | "output_type": "stream",
11 | "text": [
12 | "The autoreload extension is already loaded. To reload it, use:\n",
13 | " %reload_ext autoreload\n"
14 | ]
15 | }
16 | ],
17 | "source": [
18 | "%load_ext autoreload\n",
19 | "%autoreload 2"
20 | ]
21 | },
22 | {
23 | "cell_type": "code",
24 | "execution_count": 3,
25 | "metadata": {},
26 | "outputs": [],
27 | "source": [
28 | "import ipywidgets as widgets\n",
29 | "import torch\n",
30 | "\n",
31 | "from discrete_flow_matching_pytorch.model import DiscreteFlowMatchingNet"
32 | ]
33 | },
34 | {
35 | "cell_type": "code",
36 | "execution_count": 4,
37 | "metadata": {},
38 | "outputs": [
39 | {
40 | "data": {
41 | "text/plain": [
42 | "DiscreteFlowMatchingNet(\n",
43 | " (input_projection): Embedding(50259, 768)\n",
44 | " (embed_timestep): Sequential(\n",
45 | " (0): Embedding(1024, 768)\n",
46 | " (1): Unsqueeze()\n",
47 | " )\n",
48 | " (blocks): ModuleList(\n",
49 | " (0-5): 6 x Sequential(\n",
50 | " (0): Transpose()\n",
51 | " (1): Conv1d(768, 768, kernel_size=(31,), stride=(1,), padding=same)\n",
52 | " (2): Transpose()\n",
53 | " (3): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
54 | " (4): GELU(approximate='none')\n",
55 | " (5): Linear(in_features=768, out_features=768, bias=True)\n",
56 | " (6): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
57 | " (7): GELU(approximate='none')\n",
58 | " )\n",
59 | " )\n",
60 | " (timestep_embedding_norms): ModuleList(\n",
61 | " (0-5): 6 x LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
62 | " )\n",
63 | " (output_projection): Linear(in_features=768, out_features=50259, bias=True)\n",
64 | ")"
65 | ]
66 | },
67 | "execution_count": 4,
68 | "metadata": {},
69 | "output_type": "execute_result"
70 | }
71 | ],
72 | "source": [
73 | "model = DiscreteFlowMatchingNet.load_from_checkpoint(\n",
74 | " \"../flow-matching-tiny-stories/1pkkmee9/checkpoints/epoch=0-step=3000.ckpt\"\n",
75 | ")\n",
76 | "model.freeze()\n",
77 | "model.eval()\n",
78 | "model.to(dtype=torch.bfloat16, device=\"cuda:0\")\n",
79 | "model"
80 | ]
81 | },
82 | {
83 | "cell_type": "code",
84 | "execution_count": 5,
85 | "metadata": {},
86 | "outputs": [],
87 | "source": [
88 | "step_tokens: list[list[str]] = []\n",
89 | "step_texts: list[str] = []\n",
90 | "with torch.inference_mode():\n",
91 | " for t, samples in model.sample(\n",
92 | " num_samples=1,\n",
93 | " sequence_length=128,\n",
94 | " num_sampling_steps=128,\n",
95 | " stochasticity=5.0,\n",
96 | " yield_intermediate=True,\n",
97 | " ):\n",
98 | " step_tokens.append(model.tokenizer.batch_decode(samples[0]))\n",
99 | " step_texts.append(model.tokenizer.decode(samples[0]))"
100 | ]
101 | },
102 | {
103 | "cell_type": "code",
104 | "execution_count": 6,
105 | "metadata": {},
106 | "outputs": [
107 | {
108 | "data": {
109 | "application/vnd.jupyter.widget-view+json": {
110 | "model_id": "30ad5d91ccda466da693e68c74289aaa",
111 | "version_major": 2,
112 | "version_minor": 0
113 | },
114 | "text/plain": [
115 | "interactive(children=(IntSlider(value=64, description='step', max=128), Output()), _dom_classes=('widget-inter…"
116 | ]
117 | },
118 | "metadata": {},
119 | "output_type": "display_data"
120 | }
121 | ],
122 | "source": [
123 | "def get_step_text(step):\n",
124 | " rows = []\n",
125 | " column_size = 16\n",
126 | " for i in range(0, len(step_tokens[step]), column_size):\n",
127 | " rows.append(step_tokens[step][i : i + column_size])\n",
128 | "\n",
129 | " # Create html table\n",
130 | " html_table = '
'\n",
131 | " for row in rows:\n",
132 | " html_table += \"\"\n",
133 | " for token in row:\n",
134 | " style = \"border: 1px solid black; width: 50px; text-align: center;\"\n",
135 | " if token == \"[MASK]\":\n",
136 | " style += \"background-color: #cccccc;\"\n",
137 | " else:\n",
138 | " style += \"background-color: #eeeeee;\"\n",
139 | " html_table += f'{token} | '\n",
140 | " html_table += \"
\"\n",
141 | " html_table += \"
\"\n",
142 | "\n",
143 | " return widgets.HTML(html_table)\n",
144 | "\n",
145 | "\n",
146 | "interact_widget = widgets.interact(get_step_text, step=(0, len(step_tokens) - 1, 1))\n"
147 | ]
148 | }
149 | ],
150 | "metadata": {
151 | "kernelspec": {
152 | "display_name": ".venv",
153 | "language": "python",
154 | "name": "python3"
155 | },
156 | "language_info": {
157 | "codemirror_mode": {
158 | "name": "ipython",
159 | "version": 3
160 | },
161 | "file_extension": ".py",
162 | "mimetype": "text/x-python",
163 | "name": "python",
164 | "nbconvert_exporter": "python",
165 | "pygments_lexer": "ipython3",
166 | "version": "3.12.7"
167 | }
168 | },
169 | "nbformat": 4,
170 | "nbformat_minor": 2
171 | }
172 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "discrete-flow-matching-pytorch"
3 | version = "0.1.0"
4 | description = "Implementation of Discrete Flow Matching in PyTorch"
5 | readme = "README.md"
6 | requires-python = ">=3.12"
7 | dependencies = [
8 | "datasets>=3.1.0",
9 | "jsonargparse>=4.34.0",
10 | "lightning>=2.4.0",
11 | "more-itertools>=10.5.0",
12 | "pyinstrument>=5.0.0",
13 | "structlog>=24.4.0",
14 | "torch>=2.5.1",
15 | "transformers>=4.46.2",
16 | "wandb>=0.18.6",
17 | ]
18 |
19 | [dependency-groups]
20 | jupyter = [
21 | "ipykernel>=6.29.5",
22 | "ipywidgets>=8.1.5",
23 | "matplotlib>=3.9.2",
24 | ]
25 |
26 | [build-system]
27 | requires = ["hatchling"]
28 | build-backend = "hatchling.build"
29 |
--------------------------------------------------------------------------------
/src/discrete_flow_matching_pytorch/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RobinKa/discrete-flow-matching-pytorch/590f2db6c8c46e1b0138ed6a40bd603e4d2c77a8/src/discrete_flow_matching_pytorch/__init__.py
--------------------------------------------------------------------------------
/src/discrete_flow_matching_pytorch/data.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import datasets
4 | import torch
5 | from jsonargparse import CLI
6 | from more_itertools import chunked
7 | from tqdm import tqdm
8 | from transformers import AutoTokenizer
9 |
10 |
11 | def get_default_tokenizer():
12 | tokenizer = AutoTokenizer.from_pretrained("gpt2")
13 | tokenizer.add_special_tokens({"pad_token": "[PAD]", "mask_token": "[MASK]"})
14 | return tokenizer
15 |
16 |
17 | def load_tiny_stories(tokenizer, split: str, max_length: int = 128):
18 | def tokenize_function(examples):
19 | return dict(
20 | input_ids=tokenizer(
21 | examples["text"],
22 | truncation=True,
23 | padding="max_length",
24 | max_length=max_length,
25 | )["input_ids"]
26 | )
27 |
28 | # Load dataset
29 | def load_split(split):
30 | dataset = datasets.load_dataset("roneneldan/TinyStories", split=split)
31 | dataset = dataset.map(tokenize_function, batched=True, num_proc=os.cpu_count())
32 | dataset.set_format(type="torch", columns=["input_ids"])
33 | return dataset
34 |
35 | return load_split(split)
36 |
37 |
38 | def load_squad(
39 | tokenizer, split, max_length_question: int = 32, max_length_answer: int = 8
40 | ):
41 | def tokenize_function(examples):
42 | question_tokens = tokenizer(
43 | examples["question"],
44 | truncation=True,
45 | padding="max_length",
46 | max_length=max_length_question,
47 | )
48 |
49 | answer_tokens = tokenizer(
50 | [row["text"][0] for row in examples["answers"]],
51 | truncation=True,
52 | padding="max_length",
53 | max_length=max_length_answer,
54 | )
55 |
56 | question_tokens = torch.tensor(question_tokens["input_ids"], dtype=torch.long)
57 | answer_tokens = torch.tensor(answer_tokens["input_ids"], dtype=torch.long)
58 |
59 | input_ids = torch.cat([question_tokens, answer_tokens], dim=-1)
60 | should_noise = torch.cat(
61 | [
62 | torch.zeros_like(question_tokens, dtype=torch.bool),
63 | torch.ones_like(answer_tokens, dtype=torch.bool),
64 | ],
65 | dim=-1,
66 | )
67 |
68 | return dict(input_ids=input_ids, should_noise=should_noise)
69 |
70 | def load_split(split):
71 | dataset = datasets.load_dataset("rajpurkar/squad", split=split)
72 | dataset = dataset.map(tokenize_function, batched=True, num_proc=os.cpu_count())
73 | dataset.set_format(type="torch", columns=["input_ids", "should_noise"])
74 | return dataset
75 |
76 | return load_split(split)
77 |
78 |
79 | def load_github_code(
80 | tokenizer,
81 | split: str,
82 | languages: list[str] | None,
83 | licenses: list[str] | None,
84 | max_length: int = 128,
85 | ):
86 | assert split in ["train", "validation"], split
87 |
88 | # github-code has no validation split, so both splits use train with different shuffle seeds
89 | if split == "validation":
90 | shuffle_seed = 0
91 | split = "train"
92 | else:
93 | shuffle_seed = 1
94 |
95 | def tokenize_function(examples):
96 | # List to store the chunks for all examples in a batch
97 | all_chunks = {"input_ids": []}
98 |
99 | # Tokenize each code example and split into chunks of max_length
100 | for code in examples["code"]:
101 | # Tokenize the entire code snippet without truncation
102 | tokens = tokenizer(code, truncation=False, padding=False)["input_ids"]
103 |
104 | # Split tokens into chunks of max_length
105 | for chunk in chunked(tokens, max_length):
106 | # Pad the chunk to max_length if needed
107 | if len(chunk) < max_length:
108 | chunk += [tokenizer.pad_token_id] * (max_length - len(chunk))
109 |
110 | # Append each chunk to the list under "input_ids"
111 | all_chunks["input_ids"].append(chunk)
112 |
113 | return all_chunks
114 |
115 | # Load dataset
116 | def load_split(split):
117 | dataset: datasets.IterableDataset = datasets.load_dataset(
118 | "codeparrot/github-code",
119 | split=split,
120 | streaming=True,
121 | languages=languages,
122 | licenses=licenses,
123 | filter_languages=languages is not None,
124 | filter_licenses=licenses is not None,
125 | )
126 | dataset = dataset.select_columns(["code"])
127 | dataset = dataset.map(tokenize_function, batched=True, remove_columns=["code"])
128 | dataset = dataset.with_format(type="torch")
129 | dataset = dataset.shuffle(seed=shuffle_seed)
130 | return dataset
131 |
132 | return load_split(split)
133 |
134 |
135 | def load_dataset_by_name(dataset: str, tokenizer, split: str):
136 | match dataset:
137 | case "squad":
138 | return load_squad(tokenizer, split)
139 | case "tiny_stories":
140 | return load_tiny_stories(tokenizer, split)
141 | case "github_code":
142 | return load_github_code(tokenizer, split, languages=None, licenses=None)
143 | case "github_code_dockerfile_mit":
144 | return load_github_code(
145 | tokenizer, split, languages=["Dockerfile"], licenses=["mit"]
146 | )
147 | case "github_code_python_mit":
148 | return load_github_code(
149 | tokenizer, split, languages=["Python"], licenses=["mit"]
150 | )
151 | case _:
152 | raise ValueError(f"Unknown dataset {dataset}")
153 |
154 |
155 | def main(
156 | dataset: str = "squad",
157 | split: str = "train",
158 | rows_to_print: int = 1,
159 | benchmark: bool = False,
160 | ):
161 | tokenizer = get_default_tokenizer()
162 | dataset = load_dataset_by_name(dataset=dataset, tokenizer=tokenizer, split=split)
163 | for i, row in tqdm(enumerate(dataset)):
164 | should_print = i < rows_to_print
165 | if should_print:
166 | print(row)
167 | if not should_print and not benchmark:
168 | break
169 |
170 |
171 | if __name__ == "__main__":
172 | CLI(main)
173 |
--------------------------------------------------------------------------------
/src/discrete_flow_matching_pytorch/flops.py:
--------------------------------------------------------------------------------
1 | from time import perf_counter
2 |
3 | import lightning.pytorch as pl
4 | from structlog import get_logger
5 | from torch.utils.flop_counter import FlopCounterMode
6 |
7 | logger = get_logger()
8 |
9 |
10 | class FlopCounterCallback(pl.Callback):
11 | def __init__(self, train_step_flops: float | None = None):
12 | # Constants
13 | self.device_flops_per_second = 142.5e12 # RTX 3090 142.5 TFLOPS bf16
14 | self.train_step_flops = train_step_flops
15 |
16 | # Needed for calculating durations
17 | self.t_start = None
18 | self.previous_t_train_batch_end = None
19 | self.train_batch_t_start = None
20 |
21 | # Dynamic flop counter
22 | self.flop_counter = None
23 |
24 | # State
25 | self.trained_flops = 0
26 | self.trained_optimal_flops = 0
27 | self.trained_duration = 0
28 | self.duration = 0
29 |
30 | def load_state_dict(self, state_dict):
31 | self.trained_flops = state_dict.get("trained_flops", 0)
32 | self.trained_optimal_flops = state_dict.get("trained_optimal_flops", 0)
33 | self.trained_duration = state_dict.get("trained_duration", 0)
34 | self.duration = state_dict.get("duration", 0)
35 |
36 | def state_dict(self):
37 | return dict(
38 | trained_flops=self.trained_flops,
39 | trained_optimal_flops=self.trained_optimal_flops,
40 | trained_duration=self.trained_duration,
41 | duration=self.duration,
42 | )
43 |
44 | def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
45 | if self.train_step_flops is None:
46 | self.flop_counter = FlopCounterMode(depth=None, display=False)
47 | else:
48 | self.flop_counter = None
49 |
50 | self.t_start = perf_counter()
51 |
52 | def on_train_batch_start(self, *args, **kwargs) -> None:
53 | self.train_batch_t_start = perf_counter()
54 |
55 | if self.flop_counter is not None:
56 | self.flop_counter.__enter__()
57 |
58 | def on_train_batch_end(self, *args, **kwargs) -> None:
59 | if self.flop_counter is not None:
60 | self.flop_counter.__exit__(None, None, None)
61 | train_step_flops = self.flop_counter.get_total_flops()
62 | if self.train_step_flops is None:
63 | logger.info(
64 | "Estimated train step flops", train_step_flops=train_step_flops
65 | )
66 | self.train_step_flops = train_step_flops
67 | t = perf_counter()
68 |
69 | # Train step flops
70 | self.trained_flops += self.train_step_flops
71 |
72 | # Optimal train step flops
73 | train_step_duration = t - self.train_batch_t_start
74 | train_step_optimal_flops = train_step_duration * self.device_flops_per_second
75 | self.trained_optimal_flops += train_step_optimal_flops
76 |
77 | # Accumulate total runtime
78 | if self.previous_t_train_batch_end is not None:
79 | end_to_end_time = t - self.previous_t_train_batch_end
80 | self.duration += end_to_end_time
81 | end_to_start_time = (
82 | self.train_batch_t_start - self.previous_t_train_batch_end
83 | )
84 | else:
85 | self.duration += t - self.t_start
86 | end_to_end_time = None
87 | end_to_start_time = None
88 |
89 | self.previous_t_train_batch_end = t
90 |
91 | # Optimal total flops
92 | optimal_flops = self.duration * self.device_flops_per_second
93 |
94 | # Model flops utilization (MFU): achieved flops relative to the device's theoretical peak
95 | trained_mfu = self.trained_flops / self.trained_optimal_flops
96 | mfu = self.trained_flops / optimal_flops
97 |
98 | # Trained duration fraction
99 | self.trained_duration += train_step_duration
100 | trained_duration_fraction = self.trained_duration / self.duration
101 |
102 | self.log_dict(
103 | {
104 | "flops/train_step_pflops": self.train_step_flops / 1e15,
105 | "flops/train_step_optimal_pflops": train_step_optimal_flops / 1e15,
106 | "flops/trained_pflops": self.trained_flops / 1e15,
107 | "flops/trained_optimal_pflops": self.trained_optimal_flops / 1e15,
108 | "flops/trained_mfu": trained_mfu,
109 | "flops/trained_duration_fraction": trained_duration_fraction,
110 | "flops/train_duration": self.trained_duration,
111 | "flops/duration": self.duration,
112 | "flops/mfu": mfu,
113 | "flops/train_start_to_end_time": train_step_duration,
114 | **(
115 | {
116 | "flops/train_end_to_start_time": end_to_start_time,
117 | "flops/train_end_to_end_time": end_to_end_time,
118 | }
119 | if end_to_start_time is not None and end_to_end_time is not None
120 | else {}
121 | ),
122 | }
123 | )
124 |
125 | def on_validation_batch_end(self, *args, **kwargs) -> None:
126 | if self.train_step_flops is not None:
127 | assert self.trained_flops is not None
128 | assert self.trained_optimal_flops is not None
129 | self.log_dict(
130 | {
131 | "flops/train_step_pflops": self.train_step_flops / 1e15,
132 | "flops/trained_pflops": self.trained_flops / 1e15,
133 | "flops/trained_optimal_pflops": self.trained_optimal_flops / 1e15,
134 | "flops/train_duration": self.trained_duration,
135 | "flops/duration": self.duration,
136 | }
137 | )
138 |
--------------------------------------------------------------------------------
/src/discrete_flow_matching_pytorch/lightning.py:
--------------------------------------------------------------------------------
1 | from typing import Literal
2 |
3 | import lightning.pytorch as pl
4 | import torch
5 | import torch._dynamo.cache_size
6 | import torch.nn.functional as F
7 | from lightning.pytorch.loggers import WandbLogger
8 | from transformers import PreTrainedTokenizerBase
9 |
10 | from discrete_flow_matching_pytorch.model import ConvNet
11 |
12 |
13 | def get_timestep_step_sizes(timesteps: torch.Tensor) -> torch.Tensor:
14 | return -torch.diff(
15 | timesteps,
16 | append=torch.zeros([1], device=timesteps.device, dtype=timesteps.dtype),
17 | )
18 |
19 |
20 | class DiscreteFlowMatchingNet(pl.LightningModule):
21 | def __init__(
22 | self,
23 | vocab_size: int,
24 | hidden_dim: int,
25 | num_timesteps: int,
26 | num_layers: int,
27 | tokenizer: PreTrainedTokenizerBase,
28 | val_num_sampling_steps: int = 8,
29 | scheduler_type: Literal["linear", "square"] = "square",
30 | learning_rate: float = 1e-2,
31 | ):
32 | super().__init__()
33 |
34 | self.num_layers = num_layers
35 | self.tokenizer = tokenizer
36 | self.pad_token_id = self.tokenizer.pad_token_id
37 | self.mask_token_id = self.tokenizer.mask_token_id
38 | self.val_num_sampling_steps = val_num_sampling_steps
39 | self.learning_rate = learning_rate
40 |
41 | self.scheduler = torch.linspace(
42 | 1 / num_timesteps, 1, steps=num_timesteps, dtype=torch.float32
43 | ) # Probability path scheduler
44 |
45 | match scheduler_type:
46 | case "linear":
47 | pass
48 | case "square":
49 | # Put more weight on higher (=more noisy) timesteps.
50 | # Examples:
51 | # 0 -> 0 (no noise)
52 | # 0.5 -> 0.75 (50% noise moved to 75% noise)
53 | # 1 -> 1 (all noise)
54 | self.scheduler = 1 - torch.square(1 - self.scheduler)
55 | case _:
56 | raise ValueError(f"Invalid scheduler type: {scheduler_type}")
57 |
58 | self.model = ConvNet(
59 | vocab_size=vocab_size,
60 | hidden_dim=hidden_dim,
61 | num_timesteps=num_timesteps,
62 | num_layers=num_layers,
63 | )
64 |
65 | self.save_hyperparameters()
66 |
67 | def on_fit_start(self) -> None:
68 | self.scheduler = self.scheduler.to(self.device)
69 |
70 | print("Setting learning rate to", self.learning_rate)
71 | opt = self.optimizers(False)
72 | assert isinstance(opt, torch.optim.Optimizer)
73 | opt.param_groups[0]["lr"] = self.learning_rate
74 |
75 | def on_validation_model_eval(self) -> None:
76 | self.scheduler = self.scheduler.to(self.device)
77 | return super().on_validation_model_eval()
78 |
79 | def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
80 | # x: B, L (token ids)
81 | # t: B
82 | # returns logits: B, L, V
83 | return self.model(x, t)
84 |
85 | def forward_noising(
86 | self, x: torch.Tensor, t: torch.Tensor, should_noise: torch.Tensor | None
87 | ) -> torch.Tensor:
88 | """Mask x (BL) depending on time step t (BL)."""
89 |
90 | # t is the masking probability. t=0%: dont mask anything, t=100%: mask everything
91 | mask_prob = self.scheduler[t].expand(-1, x.shape[1])
92 | will_mask = torch.bernoulli(mask_prob).to(dtype=torch.bool)
93 |
94 | # Don't mask padding tokens
95 | will_mask &= x != self.pad_token_id
96 |
97 | # Don't mask tokens that should not be noised
98 | if should_noise is not None:
99 | will_mask &= should_noise
100 |
101 | noised_x = x.clone()
102 | noised_x[will_mask] = self.mask_token_id
103 |
104 | return noised_x
105 |
106 | @torch._dynamo.disable
107 | def log_training_step(self, log_dict):
108 | self.log_dict(
109 | {
110 | **log_dict,
111 | "train/learning_rate": self.trainer.optimizers[0].param_groups[0]["lr"],
112 | }
113 | )
114 |
115 | def training_step(self, batch, batch_idx: int):
116 | # B L
117 | x: torch.Tensor = batch["input_ids"]
118 | should_noise: torch.Tensor | None = batch.get("should_noise")
119 |
120 | # t: B
121 | t = torch.randint(0, len(self.scheduler), [x.size(0)], device=x.device)
122 |
123 | # noised_x: B L
124 | noised_x = self.forward_noising(
125 | x=x, t=t.unsqueeze(1), should_noise=should_noise
126 | )
127 |
128 | # Unmasking logits: B L V
129 | logits = self(noised_x, t) # .to(torch.float32)
130 |
131 | target = x.clone()
132 | # Only calculate loss on tokens that were masked
133 | target[noised_x != self.mask_token_id] = -100
134 |
135 | loss = F.cross_entropy(
136 | # CE expects input BVL, target BL
137 | input=logits.transpose(-1, -2),
138 | target=target,
139 | reduction="mean",
140 | )
141 | self.log_training_step({"train/loss": loss})
142 |
143 | return loss
144 |
145 | @torch._dynamo.disable
146 | def log_validation_step(
147 | self,
148 | num_samples,
149 | input_text_tokenized,
150 | generated_texts_tokenized,
151 | noised_texts_tokenized,
152 | sampling_timesteps,
153 | losses,
154 | ):
155 | # input_text: B
156 | # generated_texts: T B
157 | # noised_texts: T B
158 | # sampling_timesteps: T
159 |
160 | input_text = self.tokenizer.batch_decode(input_text_tokenized)
161 | generated_texts = [
162 | self.tokenizer.batch_decode(t) for t in generated_texts_tokenized
163 | ]
164 | noised_texts = [self.tokenizer.batch_decode(t) for t in noised_texts_tokenized]
165 |
166 | self.log_dict(losses)
167 |
168 | num_samples = min(num_samples, len(input_text))
169 |
170 | if isinstance(self.logger, WandbLogger):
171 | for i_t, t in enumerate(sampling_timesteps):
172 | self.logger.log_table(
173 | f"validation-texts/{t}",
174 | columns=["input_text", "generated_text", "generated_text_inputs"],
175 | data=[
176 | [
177 | input_text[i],
178 | generated_texts[i_t][i],
179 | noised_texts[i_t][i],
180 | ]
181 | for i in range(num_samples)
182 | ],
183 | )
184 |
185 | def _get_sampling_timesteps(self, num_sampling_steps):
186 | return torch.linspace(
187 | len(self.scheduler) - 1,
188 | len(self.scheduler) // num_sampling_steps,
189 | num_sampling_steps,
190 | device=self.device,
191 | dtype=torch.long,
192 | )
193 |
194 | @torch._dynamo.disable
195 | def validation_step_without_compile(
196 | self, batch: dict[str, torch.Tensor], batch_idx: int
197 | ):
198 | # x: B L
199 | x = batch["input_ids"]
200 | should_noise = batch.get("should_noise")
201 |
202 | num_samples = 5 # Number of samples to visualize
203 |
204 | sampling_timesteps = self._get_sampling_timesteps(self.val_num_sampling_steps)
205 |
206 | losses = {}
207 |
208 | # Apply forward noising and reverse process
209 | noised_texts_tokenized = []
210 | generated_texts_tokenized = []
211 |
212 | # Run one forward-noising and denoising pass per sampling timestep
213 | for t in sampling_timesteps:
214 | t = t.repeat(x.shape[0])
215 | assert t.shape == x.shape[:1], t.shape
216 |
217 | # B L
218 | noised_x = self.forward_noising(
219 | x, t.unsqueeze(1), should_noise=should_noise
220 | )
221 |
222 | # Only calculate loss on tokens that were masked
223 | # B L
224 | target = x.clone()
225 | # -100 marks positions that F.cross_entropy ignores
226 | target[noised_x != self.mask_token_id] = -100
227 |
228 | # B L V
229 | logits = self(noised_x, t)
230 |
231 | # Get samples for each token
232 | # B L
233 | # samples = torch.distributions.Categorical(logits=logits).sample()
234 | samples = torch.argmax(logits, dim=-1)
235 |
236 | # Unmask the masked tokens
237 | # B L
238 | generated_tokens = noised_x.clone()
239 | generated_tokens[noised_x == self.mask_token_id] = samples[
240 | noised_x == self.mask_token_id
241 | ]
242 |
243 | generated_texts_tokenized.append(generated_tokens)
244 | noised_texts_tokenized.append(noised_x)
245 |
246 | losses[f"validation-losses/loss_{t[0]}"] = F.cross_entropy(
247 | input=logits.transpose(-1, -2), target=target, reduction="mean"
248 | )
249 | losses["validation/loss_mean"] = torch.mean(torch.tensor(list(losses.values())))
250 |
251 | self.log_validation_step(
252 | num_samples=num_samples,
253 | input_text_tokenized=x,
254 | generated_texts_tokenized=generated_texts_tokenized,
255 | noised_texts_tokenized=noised_texts_tokenized,
256 | sampling_timesteps=sampling_timesteps,
257 | losses=losses,
258 | )
259 |
260 | return losses["validation/loss_mean"]
261 |
262 | def validation_step(self, batch: dict[str, torch.Tensor], batch_idx: int):
263 | return self.validation_step_without_compile(batch, batch_idx)
264 |
265 | def sample(
266 | self,
267 | num_sampling_steps: int,
268 | num_samples: int | None = None,
269 | sequence_length: int | None = None,
270 | x: torch.Tensor | None = None,
271 | stochasticity: float = 0.0,
272 | yield_intermediate: bool = False,
273 | yield_logits: bool = False,
274 | temperature: float = 1.0,
275 | cfg_scale: float = 1.0,
276 | ):
277 | assert (
278 | num_samples is not None and sequence_length is not None
279 | ) or x is not None, "Must pass either (num_samples and sequence_length) or x"
280 |
281 | assert not (
282 | yield_intermediate and yield_logits
283 | ), "Can't yield both logits and intermediate results"
284 |
285 | # B L
286 | if x is None:
287 | # Start fully masked
288 | x = torch.full(
289 | [num_samples, sequence_length],
290 | fill_value=self.tokenizer.mask_token_id,
291 | dtype=torch.long,
292 | device=self.device,
293 | )
294 | should_noise = None
295 | else:
296 | should_noise = x == self.mask_token_id
297 |
298 | # Create the integer timesteps and step sizes for the given num_sampling_steps
299 | # S
300 | sampling_timesteps = self._get_sampling_timesteps(num_sampling_steps)
301 | relative_ts = self.scheduler[sampling_timesteps]
302 | relative_dts = get_timestep_step_sizes(relative_ts)
303 |
304 | for t, relative_t, relative_dt in zip(
305 | sampling_timesteps, relative_ts, relative_dts
306 | ):
307 | is_last_step = t == sampling_timesteps[-1]
308 | if yield_intermediate:
309 | yield t, x
310 |
311 | # B
312 | t = t.repeat(x.shape[0])
313 | assert t.shape == x.shape[:1], t.shape
314 |
315 | # B L V
316 | logits = self(x, t)
317 |
318 | if cfg_scale != 1.0:
319 | assert should_noise is not None
320 | x_uncond = x.clone()
321 | x_uncond[~should_noise] = self.mask_token_id
322 |
323 | # Classifier-free guidance
324 | # Run model unconditionally (conditioning fully masked)
325 | logits_uncond = self(x_uncond, t)
326 |
327 | # Mix the logits according to cfg_scale
328 | logits = logits_uncond + cfg_scale * (logits - logits_uncond)
329 |
330 | if yield_logits:
331 | yield t, logits
332 |
333 | # B L
334 | samples = torch.distributions.Categorical(
335 | logits=logits / temperature
336 | ).sample()
337 |
338 | # B L
339 | # Chance to unmask proportional to
340 | # - step size: higher step size means higher chance
341 | # - timestep: lower timestep means higher chance (so in the end the chance is 100%)
342 | unmask_threshold = relative_dt / relative_t
343 |
344 | # With remasking, the unmasking probability is changed
345 | if stochasticity != 0:
346 | unmask_threshold *= 1 + stochasticity * (1 - relative_t)
347 |
348 | was_masked = x == self.mask_token_id
349 |
350 | # Unmask
351 | will_unmask = (
352 | torch.rand(
353 | x.shape[:2],
354 | device=unmask_threshold.device,
355 | dtype=unmask_threshold.dtype,
356 | )
357 | < unmask_threshold
358 | )
359 | # Only unmask the tokens that were masked
360 | will_unmask &= was_masked
361 |
362 | # Remask when stochasticity is non-zero
363 | if stochasticity != 0 and not is_last_step:
364 | remask_threshold = relative_dt * stochasticity
365 | will_remask = (
366 | torch.rand(
367 | x.shape[:2],
368 | device=unmask_threshold.device,
369 | dtype=unmask_threshold.dtype,
370 | )
371 | < remask_threshold
372 | )
373 | # Only remask the tokens that were unmasked
374 | will_remask &= ~was_masked
375 |
376 | # Only remask tokens that aren't constant
377 | if should_noise is not None:
378 | will_remask &= should_noise
379 |
380 | x[will_remask] = self.mask_token_id
381 |
382 | # B L
383 | x[will_unmask] = samples[will_unmask]
384 |
385 | if yield_intermediate:
386 | yield torch.zeros_like(t), x
387 | else:
388 | return x
389 |
390 | def configure_optimizers(self):
391 | return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
392 |
--------------------------------------------------------------------------------
/src/discrete_flow_matching_pytorch/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class Unsqueeze(nn.Module):
6 | def __init__(self, dim: int):
7 | super().__init__()
8 | self.dim = dim
9 |
10 | def forward(self, x):
11 | return x.unsqueeze(self.dim)
12 |
13 |
14 | class Reshape(nn.Module):
15 | def __init__(self, shape: list[int]):
16 | super().__init__()
17 | self.shape = shape
18 |
19 | def forward(self, x):
20 | return x.reshape(self.shape)
21 |
22 |
23 | class Transpose(nn.Module):
24 | def forward(self, x):
25 | return x.transpose(-1, -2)
26 |
27 |
28 | class ConvNet(nn.Module):
29 | def __init__(
30 | self,
31 | vocab_size: int,
32 | hidden_dim: int,
33 | num_timesteps: int,
34 | num_layers: int,
35 | kernel_size: int = 31,
36 | ):
37 | super().__init__()
38 |
39 | # x: B, L
40 |
41 | # Input embedding, B L -> B L C
42 | self.input_projection = nn.Embedding(vocab_size, hidden_dim)
43 |
44 | # Embed timestep to B, 1, C
45 | self.embed_timestep = nn.Sequential(
46 | nn.Embedding(num_timesteps, hidden_dim),
47 | Unsqueeze(1),
48 | )
49 |
50 | self.blocks = nn.ModuleList()
51 | self.timestep_embedding_norms = nn.ModuleList()
52 |
53 | for _ in range(num_layers):
54 | self.blocks.append(
55 | nn.Sequential(
56 | Transpose(),
57 | nn.Conv1d(
58 | hidden_dim, hidden_dim, kernel_size=kernel_size, padding="same"
59 | ),
60 | Transpose(),
61 | nn.LayerNorm([hidden_dim]),
62 | nn.GELU(),
63 | nn.Linear(hidden_dim, hidden_dim),
64 | nn.LayerNorm([hidden_dim]),
65 | nn.GELU(),
66 | ),
67 | )
68 | self.timestep_embedding_norms.append(nn.LayerNorm([hidden_dim]))
69 |
70 | # Output projection, B L C -> B L V
71 | self.output_projection = nn.Linear(hidden_dim, vocab_size)
72 |
73 | def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
74 | # x: B, L (token ids), t: B
75 | x = self.input_projection(x) # BLC
76 |
77 | for block, timestep_embedding_norm in zip(
78 | self.blocks, self.timestep_embedding_norms
79 | ):
80 | x = x + block(x + timestep_embedding_norm(self.embed_timestep(t))) # BLC
81 |
82 | x = self.output_projection(x) # BLV
83 |
84 | return x
85 |
--------------------------------------------------------------------------------
/src/discrete_flow_matching_pytorch/train.py:
--------------------------------------------------------------------------------
1 | from contextlib import nullcontext
2 | from time import time
3 |
4 | import lightning.pytorch as pl
5 | import torch
6 | from jsonargparse import CLI
7 | from lightning.pytorch.callbacks import ModelCheckpoint
8 | from lightning.pytorch.loggers import WandbLogger
9 | from pyinstrument import Profiler
10 | from pyinstrument.renderers.speedscope import SpeedscopeRenderer
11 | from structlog import get_logger
12 | from torch.utils.data import DataLoader
13 |
14 | from discrete_flow_matching_pytorch.data import (
15 | get_default_tokenizer,
16 | load_dataset_by_name,
17 | )
18 | from discrete_flow_matching_pytorch.flops import FlopCounterCallback
19 | from discrete_flow_matching_pytorch.lightning import DiscreteFlowMatchingNet
20 |
21 | logger = get_logger()
22 |
23 |
24 | def get_run_name(
25 | dataset: str,
26 | train_batch_size: int,
27 | hidden_dim: int,
28 | num_layers: int,
29 | scheduler_type: str,
30 | ):
31 | return f"{dataset}-bs={train_batch_size}-h={hidden_dim}-l={num_layers}-s={scheduler_type}"
32 |
33 |
34 | def main(
35 | compile: bool = True,
36 | wandb: bool = True,
37 | max_steps: int = -1,
38 | profile: bool = False,
39 | train_step_flops: float | None = None,
40 | ckpt_path: str = "",
41 | val_interval: int = 500,
42 | checkpoint_interval: int = 1_000,
43 | dataset: str = "tiny_stories",
44 | train_batch_size: int = 256,
45 | shuffle_train: bool = True, # Needs to be false for IterableDataset
46 | train_workers: int = 2,
47 | val_split_name: str = "validation", # Some datasets don't have validation, but we can still use train
48 | hidden_dim: int = 768,
49 | num_layers: int = 6,
50 | scheduler_type: str = "square",
51 | learning_rate: float = 1e-2,
52 | ):
53 | torch.set_float32_matmul_precision("high")
54 |
55 | # Load tokenizer
56 | logger.info("Loading tokenizer")
57 | tokenizer = get_default_tokenizer()
58 |
59 | logger.info("Loading dataset", dataset=dataset)
60 | train_data = load_dataset_by_name(
61 | dataset=dataset, tokenizer=tokenizer, split="train"
62 | )
63 | val_data = load_dataset_by_name(
64 | dataset=dataset, tokenizer=tokenizer, split=val_split_name
65 | )
66 |
67 | # Dataloader
68 | logger.info("Creating data loaders")
69 | train_loader = DataLoader(
70 | train_data,
71 | batch_size=train_batch_size,
72 | shuffle=shuffle_train,
73 | num_workers=train_workers,
74 | prefetch_factor=2,
75 | )
76 | val_loader = DataLoader(
77 | val_data, batch_size=64, shuffle=False, num_workers=1, prefetch_factor=2
78 | )
79 |
80 | # Create model
81 | logger.info("Creating model")
82 | model = DiscreteFlowMatchingNet(
83 | vocab_size=len(tokenizer), # .vocab_size excludes the new tokens
84 | hidden_dim=hidden_dim,
85 | num_timesteps=1024,
86 | num_layers=num_layers,
87 | tokenizer=tokenizer,
88 | scheduler_type=scheduler_type,
89 | learning_rate=learning_rate,
90 | ).to(dtype=torch.bfloat16)
91 |
92 | if compile:
93 | torch._dynamo.config.cache_size_limit = 512
94 | torch._dynamo.config.capture_scalar_outputs = True
95 | torch._dynamo.config.capture_dynamic_output_shape_ops = True
96 | torch._inductor.config.fx_graph_cache = True
97 | model = torch.compile(model)
98 |
99 | # Train model
100 | if train_step_flops is not None:
101 | logger.info("Using train step flops", train_step_flops=train_step_flops)
102 | else:
103 | logger.info("Using dynamic train step flops")
104 |
105 | trainer = pl.Trainer(
106 | max_epochs=-1,
107 | max_steps=max_steps,
108 | accelerator="gpu",
109 | devices=1 if torch.cuda.is_available() else 0,
110 | limit_val_batches=1,
111 | callbacks=[
112 | FlopCounterCallback(train_step_flops=train_step_flops),
113 | ModelCheckpoint(every_n_train_steps=checkpoint_interval),
114 | ],
115 | logger=WandbLogger(
116 | project="discrete-flow-matching",
117 | name=get_run_name(
118 | dataset=dataset,
119 | train_batch_size=train_batch_size,
120 | hidden_dim=hidden_dim,
121 | num_layers=num_layers,
122 | scheduler_type=scheduler_type,
123 | ),
124 | )
125 | if wandb
126 | else None,
127 | precision="bf16",
128 | gradient_clip_algorithm="norm",
129 | gradient_clip_val=0.1,
130 | check_val_every_n_epoch=None,
131 | val_check_interval=val_interval,
132 | log_every_n_steps=25,
133 | num_sanity_val_steps=0,
134 | )
135 |
136 | fit_context = nullcontext() if not profile else Profiler(async_mode="disabled")
137 | with fit_context:
138 | # Run validation before training to get the initial loss
139 | trainer.validate(model=model, dataloaders=val_loader)
140 |
141 | trainer.fit(
142 | model=model,
143 | train_dataloaders=train_loader,
144 | val_dataloaders=val_loader,
145 | ckpt_path=ckpt_path if ckpt_path else None,
146 | )
147 |
148 | if profile:
149 | with open(f"speedscope-{time()}.json", "w") as f:
150 | f.write(fit_context.output(SpeedscopeRenderer()))
151 |
152 |
153 | if __name__ == "__main__":
154 | CLI(main)
155 |
--------------------------------------------------------------------------------