├── .gitignore ├── .python-version ├── LICENSE ├── README.md ├── configs └── conv-8.yaml ├── media ├── step0.png ├── step128.png └── step64.png ├── notebooks └── sample.ipynb ├── pyproject.toml ├── src └── discrete_flow_matching_pytorch │ ├── __init__.py │ ├── data.py │ ├── flops.py │ ├── lightning.py │ ├── model.py │ └── train.py └── uv.lock /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 |
164 | /discrete-flow-matching
165 | /lightning_logs
166 | /wandb
167 |
--------------------------------------------------------------------------------
/.python-version:
--------------------------------------------------------------------------------
1 | 3.12
2 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Robin Kahlow
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Discrete Flow Matching implemented in PyTorch
2 |
3 | Implementation of Discrete Flow Matching [[1]](https://arxiv.org/abs/2402.04997)[[2]](https://arxiv.org/abs/2407.15595), a generative model that applies flow matching to discrete data such as text. The code is implemented in PyTorch.
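
To make the idea concrete, here is a minimal, self-contained sketch of the training objective and the sampling loop that are summarized in the "Summary" section below. It uses a hypothetical toy model and a linear schedule, and it omits the timestep conditioning of the real model; see [src/discrete_flow_matching_pytorch](src/discrete_flow_matching_pytorch) for the actual implementation.

```python
import torch
import torch.nn.functional as F

# Toy setup (not this repository's API): the last vocabulary id is the mask
# token, and the "model" maps token ids (B, L) to logits (B, L, V).
V, MASK, B, L = 10, 9, 4, 8
model = torch.nn.Sequential(torch.nn.Embedding(V, 32), torch.nn.Linear(32, V))

def training_step(x0: torch.Tensor) -> torch.Tensor:
    """x0: (B, L) clean token ids. Returns the masked cross-entropy loss."""
    t = torch.rand(x0.shape[0], 1)  # per-sequence masking probability
    xt = x0.masked_fill(torch.rand(x0.shape) < t, MASK)  # forward noising
    logits = model(xt)  # predict the original tokens
    target = x0.masked_fill(xt != MASK, -100)  # loss only on masked positions
    return F.cross_entropy(logits.transpose(-1, -2), target)

@torch.no_grad()
def sample(num_steps: int = 32) -> torch.Tensor:
    xt = torch.full((B, L), MASK)  # start fully masked
    for t in torch.linspace(1.0, 1.0 / num_steps, num_steps):
        x0_pred = torch.distributions.Categorical(logits=model(xt)).sample()
        # Unmask each still-masked position with probability dt / t, so that
        # by the final step (where dt == t) every position is unmasked.
        unmask = (torch.rand(B, L) < (1.0 / num_steps) / t) & (xt == MASK)
        xt[unmask] = x0_pred[unmask]
    return xt
```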
4 |
5 | | Step 0 of 128 (input) | Step 64 of 128 | Step 128 of 128 (output) |
6 | | -------------------------- | ---------------------------- | ------------------------------ |
7 | | ![Step 0](media/step0.png) | ![Step 64](media/step64.png) | ![Step 128](media/step128.png) |
8 |
9 | ## How to run
10 |
11 | ### Environment setup
12 |
13 | 1. Install [uv](https://github.com/astral-sh/uv) for package management, e.g. `pip install uv`
14 | 2. Make sure Python 3.12 is installed: `uv python install 3.12`
15 | 3. Install the dependencies: `uv sync --group jupyter`
16 |
17 | Run `python -m discrete_flow_matching_pytorch.train --config configs/conv-8.yaml` to start training a text generation model; training metrics are logged to wandb.
18 |
19 | The [sample notebook](notebooks/sample.ipynb) demonstrates the sampling process.
20 |
21 | **Note**: Instead of using uv, it is also possible to install the dependencies in [pyproject.toml](pyproject.toml) with pip.
22 |
23 | ## Summary of discrete flow matching compared to continuous flow matching
24 |
25 | - During training, we mask out text tokens with a probability determined by the timestep
26 | - The model is trained with cross-entropy loss to predict the original tokens at the masked positions
27 | - During sampling, we start from fully masked text and gradually unmask it with tokens sampled from the model
28 |
29 | ## References
30 |
31 | - [[1] Generative Flows on Discrete State-Spaces: Enabling Multimodal Flows with Applications to Protein Co-Design](https://arxiv.org/abs/2402.04997) ([YouTube presentation](https://www.youtube.com/watch?v=yzc29vhM2Aw)): Combines discrete and continuous flow matching. Originally introduced Discrete Flow Matching. Appendix F was very useful for the implementation
32 | - [[2] Discrete Flow Matching](https://arxiv.org/abs/2407.15595): Builds on Multiflow's Discrete Flow Matching
--------------------------------------------------------------------------------
/configs/conv-8.yaml:
--------------------------------------------------------------------------------
1 | dataset: github_code_python_mit
2 | shuffle_train: false
3 | val_split_name: train
4 | train_workers: 4
5 | hidden_dim: 1024
6 | num_layers: 8
7 | train_step_flops: 62895031320576
8 | learning_rate: 1e-3
9 |
--------------------------------------------------------------------------------
/media/step0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RobinKa/discrete-flow-matching-pytorch/590f2db6c8c46e1b0138ed6a40bd603e4d2c77a8/media/step0.png
--------------------------------------------------------------------------------
/media/step128.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RobinKa/discrete-flow-matching-pytorch/590f2db6c8c46e1b0138ed6a40bd603e4d2c77a8/media/step128.png
--------------------------------------------------------------------------------
/media/step64.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RobinKa/discrete-flow-matching-pytorch/590f2db6c8c46e1b0138ed6a40bd603e4d2c77a8/media/step64.png
--------------------------------------------------------------------------------
/notebooks/sample.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count": 2,
6 |    "metadata": {},
7 |    "outputs": [
8 |     {
9 |      "name": "stdout",
10 |      "output_type": "stream",
11 |      "text": [
12 |       "The autoreload extension is already loaded. To reload it, use:\n",
13 |       "  %reload_ext autoreload\n"
14 |      ]
15 |     }
16 |    ],
17 |    "source": [
18 |     "%load_ext autoreload\n",
19 |     "%autoreload 2"
20 |    ]
21 |   },
22 |   {
23 |    "cell_type": "code",
24 |    "execution_count": 3,
25 |    "metadata": {},
26 |    "outputs": [],
27 |    "source": [
28 |     "import ipywidgets as widgets\n",
29 |     "import torch\n",
30 |     "\n",
31 |     "from discrete_flow_matching_pytorch.model import DiscreteFlowMatchingNet"
32 |    ]
33 |   },
34 |   {
35 |    "cell_type": "code",
36 |    "execution_count": 4,
37 |    "metadata": {},
38 |    "outputs": [
39 |     {
40 |      "data": {
41 |       "text/plain": [
42 |        "DiscreteFlowMatchingNet(\n",
43 |        "  (input_projection): Embedding(50259, 768)\n",
44 |        "  (embed_timestep): Sequential(\n",
45 |        "    (0): Embedding(1024, 768)\n",
46 |        "    (1): Unsqueeze()\n",
47 |        "  )\n",
48 |        "  (blocks): ModuleList(\n",
49 |        "    (0-5): 6 x Sequential(\n",
50 |        "      (0): Transpose()\n",
51 |        "      (1): Conv1d(768, 768, kernel_size=(31,), stride=(1,), padding=same)\n",
52 |        "      (2): Transpose()\n",
53 |        "      (3): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
54 |        "      (4): GELU(approximate='none')\n",
55 |        "      (5): Linear(in_features=768, out_features=768, bias=True)\n",
56 |        "      (6): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
57 |        "      (7): GELU(approximate='none')\n",
58 |        "    )\n",
59 |        "  )\n",
60 |        "  (timestep_embedding_norms): ModuleList(\n",
61 |        "    (0-5): 6 x LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
62 |        "  )\n",
63 |        "  (output_projection): Linear(in_features=768, out_features=50259, bias=True)\n",
64 |        ")"
65 |       ]
66 |      },
67 |      "execution_count": 4,
68 |      "metadata": {},
69 |      "output_type": "execute_result"
70 |     }
71 |    ],
72 |    "source": [
73 |     "model = DiscreteFlowMatchingNet.load_from_checkpoint(\n",
74 |     "    \"../flow-matching-tiny-stories/1pkkmee9/checkpoints/epoch=0-step=3000.ckpt\"\n",
75 |     ")\n",
76 |     "model.freeze()\n",
77 |     "model.eval()\n",
78 |     "model.to(dtype=torch.bfloat16, device=\"cuda:0\")\n",
79 |     "model"
80 |    ]
81 |   },
82 |   {
83 |    "cell_type": "code",
84 |    "execution_count": 5,
85 |    "metadata": {},
86 |    "outputs": [],
87 |    "source": [
88 |     "step_tokens: list[list[str]] = []\n",
89 |     "step_texts: list[str] = []\n",
90 |     "with torch.inference_mode():\n",
91 |     "    for t, samples in model.sample(\n",
92 |     "        num_samples=1,\n",
93 |     "        sequence_length=128,\n",
94 |     "        num_sampling_steps=128,\n",
95 |     "        stochasticity=5.0,\n",
96 |     "        yield_intermediate=True,\n",
97 |     "    ):\n",
98 |     "        step_tokens.append(model.tokenizer.batch_decode(samples[0]))\n",
99 |     "        step_texts.append(model.tokenizer.decode(samples[0]))"
100 |    ]
101 |   },
102 |   {
103 |    "cell_type": "code",
104 |    "execution_count": 6,
105 |    "metadata": {},
106 |    "outputs": [
107 |     {
108 |      "data": {
109 |       "application/vnd.jupyter.widget-view+json": {
110 |        "model_id": "30ad5d91ccda466da693e68c74289aaa",
111 |        "version_major": 2,
112 |        "version_minor": 0
113 |       },
114 |       "text/plain": [
115 |        "interactive(children=(IntSlider(value=64, description='step', max=128), Output()), _dom_classes=('widget-inter…"
116 |       ]
117 |      },
118 |      "metadata": {},
119 |      "output_type": "display_data"
120 |     }
121 |    ],
122 |    "source": [
123 |     "def get_step_text(step):\n",
124 |     "    rows = []\n",
125 |     "    column_size = 16\n",
126 |     "    for i in range(0, len(step_tokens[step]), column_size):\n",
127 |     "        rows.append(step_tokens[step][i : i + column_size])\n",
128 |     "\n",
129 |     "    # Create an HTML table with one cell per token\n",
130 |     "    html_table = '<table>'\n",
131 |     "    for row in rows:\n",
132 |     "        html_table += \"<tr>\"\n",
133 |     "        for token in row:\n",
134 |     "            style = \"border: 1px solid black; width: 50px; text-align: center;\"\n",
135 |     "            if token == \"[MASK]\":\n",
136 |     "                style += \"background-color: #cccccc;\"\n",
137 |     "            else:\n",
138 |     "                style += \"background-color: #eeeeee;\"\n",
139 |     "            html_table += f'<td style=\"{style}\">{token}</td>'\n",
140 |     "        html_table += \"</tr>\"\n",
141 |     "    html_table += \"</table>\"\n",
142 |     "\n",
143 |     "    return widgets.HTML(html_table)\n",
144 |     "\n",
145 |     "\n",
146 |     "interact_widget = widgets.interact(get_step_text, step=(0, len(step_tokens) - 1, 1))\n"
147 |    ]
148 |   }
149 |  ],
150 |  "metadata": {
151 |   "kernelspec": {
152 |    "display_name": ".venv",
153 |    "language": "python",
154 |    "name": "python3"
155 |   },
156 |   "language_info": {
157 |    "codemirror_mode": {
158 |     "name": "ipython",
159 |     "version": 3
160 |    },
161 |    "file_extension": ".py",
162 |    "mimetype": "text/x-python",
163 |    "name": "python",
164 |    "nbconvert_exporter": "python",
165 |    "pygments_lexer": "ipython3",
166 |    "version": "3.12.7"
167 |   }
168 |  },
169 |  "nbformat": 4,
170 |  "nbformat_minor": 2
171 | }
172 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "discrete-flow-matching-pytorch"
3 | version = "0.1.0"
4 | description = "Implementation of Discrete Flow Matching in PyTorch"
5 | readme = "README.md"
6 | requires-python = ">=3.12"
7 | dependencies = [
8 |     "datasets>=3.1.0",
9 |     "jsonargparse>=4.34.0",
10 |     "lightning>=2.4.0",
11 |     "more-itertools>=10.5.0",
12 |     "pyinstrument>=5.0.0",
13 |     "structlog>=24.4.0",
14 |     "torch>=2.5.1",
15 |     "transformers>=4.46.2",
16 |     "wandb>=0.18.6",
17 | ]
18 |
19 | [dependency-groups]
20 | jupyter = [
21 |     "ipykernel>=6.29.5",
22 |     "ipywidgets>=8.1.5",
23 |     "matplotlib>=3.9.2",
24 | ]
25 |
26 | [build-system]
27 | requires = ["hatchling"]
28 | build-backend = "hatchling.build"
29 |
--------------------------------------------------------------------------------
/src/discrete_flow_matching_pytorch/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RobinKa/discrete-flow-matching-pytorch/590f2db6c8c46e1b0138ed6a40bd603e4d2c77a8/src/discrete_flow_matching_pytorch/__init__.py
--------------------------------------------------------------------------------
/src/discrete_flow_matching_pytorch/data.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import datasets
4 | import torch
5 | from jsonargparse import CLI
6 | from more_itertools import chunked
7 | from tqdm import tqdm
8 | from transformers import AutoTokenizer
9 |
10 |
11 | def get_default_tokenizer():
12 |     tokenizer = AutoTokenizer.from_pretrained("gpt2")
13 |     tokenizer.add_special_tokens({"pad_token": "[PAD]", "mask_token": "[MASK]"})
14 |     return tokenizer
15 |
16 |
17 | def load_tiny_stories(tokenizer, split: str, max_length: int = 128):
18 |     def tokenize_function(examples):
19 |         return dict(
20 |             input_ids=tokenizer(
21 |                 examples["text"],
22 |                 truncation=True,
23 |                 padding="max_length",
24 |                 max_length=max_length,
25 |             )["input_ids"]
26 |         )
27 |
28 |     # Load dataset
29 |     def load_split(split):
30 |         dataset = datasets.load_dataset("roneneldan/TinyStories", split=split)
31 |         dataset = dataset.map(tokenize_function, batched=True, num_proc=os.cpu_count())
32 |         dataset.set_format(type="torch", columns=["input_ids"])
33 |         return dataset
34 |
35 |     return load_split(split)
36 |
37 |
38 | def load_squad(
39 |     tokenizer, split, max_length_question: int = 32, max_length_answer: int = 8
40 | ):
41 |     def tokenize_function(examples):
42 |         question_tokens = tokenizer(
43 |             examples["question"],
44 |             truncation=True,
45 |             padding="max_length",
46 |             max_length=max_length_question,
47 |         )
48 |
49 |         answer_tokens = tokenizer(
50 |
[row["text"][0] for row in examples["answers"]], 51 | truncation=True, 52 | padding="max_length", 53 | max_length=max_length_answer, 54 | ) 55 | 56 | question_tokens = torch.tensor(question_tokens["input_ids"], dtype=torch.long) 57 | answer_tokens = torch.tensor(answer_tokens["input_ids"], dtype=torch.long) 58 | 59 | input_ids = torch.cat([question_tokens, answer_tokens], dim=-1) 60 | should_noise = torch.cat( 61 | [ 62 | torch.zeros_like(question_tokens, dtype=torch.bool), 63 | torch.ones_like(answer_tokens, dtype=torch.bool), 64 | ], 65 | dim=-1, 66 | ) 67 | 68 | return dict(input_ids=input_ids, should_noise=should_noise) 69 | 70 | def load_split(split): 71 | dataset = datasets.load_dataset("rajpurkar/squad", split=split) 72 | dataset = dataset.map(tokenize_function, batched=True, num_proc=os.cpu_count()) 73 | dataset.set_format(type="torch", columns=["input_ids", "should_noise"]) 74 | return dataset 75 | 76 | return load_split(split) 77 | 78 | 79 | def load_github_code( 80 | tokenizer, 81 | split: str, 82 | languages: list[str] | None, 83 | licenses: list[str] | None, 84 | max_length: int = 128, 85 | ): 86 | assert split in ["train", "validation"], split 87 | 88 | # github code does not have a validation split, but we can use different seeds 89 | if split == "validation": 90 | shuffle_seed = 0 91 | split = "train" 92 | else: 93 | shuffle_seed = 1 94 | 95 | def tokenize_function(examples): 96 | # List to store the chunks for all examples in a batch 97 | all_chunks = {"input_ids": []} 98 | 99 | # Tokenize each code example and split into chunks of max_length 100 | for code in examples["code"]: 101 | # Tokenize the entire code snippet without truncation 102 | tokens = tokenizer(code, truncation=False, padding=False)["input_ids"] 103 | 104 | # Split tokens into chunks of max_length 105 | for chunk in chunked(tokens, max_length): 106 | # Pad the chunk to max_length if needed 107 | if len(chunk) < max_length: 108 | chunk += [tokenizer.pad_token_id] * (max_length - len(chunk)) 109 | 110 | # Append each chunk to the list under "input_ids" 111 | all_chunks["input_ids"].append(chunk) 112 | 113 | return all_chunks 114 | 115 | # Load dataset 116 | def load_split(split): 117 | dataset: datasets.IterableDataset = datasets.load_dataset( 118 | "codeparrot/github-code", 119 | split=split, 120 | streaming=True, 121 | languages=languages, 122 | licenses=licenses, 123 | filter_languages=languages is not None, 124 | filter_licenses=licenses is not None, 125 | ) 126 | dataset = dataset.select_columns(["code"]) 127 | dataset = dataset.map(tokenize_function, batched=True, remove_columns=["code"]) 128 | dataset = dataset.with_format(type="torch") 129 | dataset = dataset.shuffle(seed=shuffle_seed) 130 | return dataset 131 | 132 | return load_split(split) 133 | 134 | 135 | def load_dataset_by_name(dataset: str, tokenizer, split: str): 136 | match dataset: 137 | case "squad": 138 | return load_squad(tokenizer, split) 139 | case "tiny_stories": 140 | return load_tiny_stories(tokenizer, split) 141 | case "github_code": 142 | return load_github_code(tokenizer, split, languages=None, licenses=None) 143 | case "github_code_dockerfile_mit": 144 | return load_github_code( 145 | tokenizer, split, languages=["Dockerfile"], licenses=["mit"] 146 | ) 147 | case "github_code_python_mit": 148 | return load_github_code( 149 | tokenizer, split, languages=["Python"], licenses=["mit"] 150 | ) 151 | case _: 152 | raise ValueError(f"Unknown dataset {dataset}") 153 | 154 | 155 | def main( 156 | dataset: str = "squad", 157 | split: str 
= "train", 158 | rows_to_print: int = 1, 159 | benchmark: bool = False, 160 | ): 161 | tokenizer = get_default_tokenizer() 162 | dataset = load_dataset_by_name(dataset=dataset, tokenizer=tokenizer, split=split) 163 | for i, row in tqdm(enumerate(dataset)): 164 | should_print = i < rows_to_print 165 | if should_print: 166 | print(row) 167 | if not should_print and not benchmark: 168 | break 169 | 170 | 171 | if __name__ == "__main__": 172 | CLI(main) 173 | -------------------------------------------------------------------------------- /src/discrete_flow_matching_pytorch/flops.py: -------------------------------------------------------------------------------- 1 | from time import perf_counter 2 | 3 | import lightning.pytorch as pl 4 | from structlog import get_logger 5 | from torch.utils.flop_counter import FlopCounterMode 6 | 7 | logger = get_logger() 8 | 9 | 10 | class FlopCounterCallback(pl.Callback): 11 | def __init__(self, train_step_flops: float | None = None): 12 | # Constants 13 | self.device_flops_per_second = 142.5e12 # RTX 3090 142.5 TFLOPS bf16 14 | self.train_step_flops = train_step_flops 15 | 16 | # Needed for calculating durations 17 | self.t_start = None 18 | self.previous_t_train_batch_end = None 19 | self.train_batch_t_start = None 20 | 21 | # Dynamic flop counter 22 | self.flop_counter = None 23 | 24 | # State 25 | self.trained_flops = 0 26 | self.trained_optimal_flops = 0 27 | self.trained_duration = 0 28 | self.duration = 0 29 | 30 | def load_state_dict(self, state_dict): 31 | self.trained_flops = state_dict.get("trained_flops", 0) 32 | self.trained_optimal_flops = state_dict.get("trained_optimal_flops", 0) 33 | self.trained_duration = state_dict.get("trained_duration", 0) 34 | self.duration = state_dict.get("duration", 0) 35 | 36 | def state_dict(self): 37 | return dict( 38 | trained_flops=self.trained_flops, 39 | trained_optimal_flops=self.trained_optimal_flops, 40 | trained_duration=self.trained_duration, 41 | duration=self.duration, 42 | ) 43 | 44 | def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: 45 | if self.train_step_flops is None: 46 | self.flop_counter = FlopCounterMode(depth=None, display=False) 47 | else: 48 | self.flop_counter = None 49 | 50 | self.t_start = perf_counter() 51 | 52 | def on_train_batch_start(self, *args, **kwargs) -> None: 53 | self.train_batch_t_start = perf_counter() 54 | 55 | if self.flop_counter is not None: 56 | self.flop_counter.__enter__() 57 | 58 | def on_train_batch_end(self, *args, **kwargs) -> None: 59 | if self.flop_counter is not None: 60 | self.flop_counter.__exit__(None, None, None) 61 | train_step_flops = self.flop_counter.get_total_flops() 62 | if self.train_step_flops is None: 63 | logger.info( 64 | "Estimated train step flops", train_step_flops=train_step_flops 65 | ) 66 | self.train_step_flops = train_step_flops 67 | t = perf_counter() 68 | 69 | # Train step flops 70 | self.trained_flops += self.train_step_flops 71 | 72 | # Optimal train step flops 73 | train_step_duration = t - self.train_batch_t_start 74 | train_step_optimal_flops = train_step_duration * self.device_flops_per_second 75 | self.trained_optimal_flops += train_step_optimal_flops 76 | 77 | # Accumulate total runtime 78 | if self.previous_t_train_batch_end is not None: 79 | end_to_end_time = t - self.previous_t_train_batch_end 80 | self.duration += end_to_end_time 81 | end_to_start_time = ( 82 | self.train_batch_t_start - self.previous_t_train_batch_end 83 | ) 84 | else: 85 | self.duration += t - self.t_start 86 | 
end_to_end_time = None 87 | end_to_start_time = None 88 | 89 | self.previous_t_train_batch_end = t 90 | 91 | # Optimal total flops 92 | optimal_flops = self.duration * self.device_flops_per_second 93 | 94 | # MFU 95 | trained_mfu = self.trained_flops / self.trained_optimal_flops 96 | mfu = self.trained_flops / optimal_flops 97 | 98 | # Trained duration fraction 99 | self.trained_duration += train_step_duration 100 | trained_duration_fraction = self.trained_duration / self.duration 101 | 102 | self.log_dict( 103 | { 104 | "flops/train_step_pflops": self.train_step_flops / 1e15, 105 | "flops/train_step_optimal_pflops": train_step_optimal_flops / 1e15, 106 | "flops/trained_pflops": self.trained_flops / 1e15, 107 | "flops/trained_optimal_pflops": self.trained_optimal_flops / 1e15, 108 | "flops/trained_mfu": trained_mfu, 109 | "flops/trained_duration_fraction": trained_duration_fraction, 110 | "flops/train_duration": self.trained_duration, 111 | "flops/duration": self.duration, 112 | "flops/mfu": mfu, 113 | "flops/train_start_to_end_time": train_step_duration, 114 | **( 115 | { 116 | "flops/train_end_to_start_time": end_to_start_time, 117 | "flops/train_end_to_end_time": end_to_end_time, 118 | } 119 | if end_to_start_time is not None and end_to_end_time is not None 120 | else {} 121 | ), 122 | } 123 | ) 124 | 125 | def on_validation_batch_end(self, *args, **kwargs) -> None: 126 | if self.train_step_flops is not None: 127 | assert self.trained_flops is not None 128 | assert self.trained_optimal_flops is not None 129 | self.log_dict( 130 | { 131 | "flops/train_step_pflops": self.train_step_flops / 1e15, 132 | "flops/trained_pflops": self.trained_flops / 1e15, 133 | "flops/trained_optimal_pflops": self.trained_optimal_flops / 1e15, 134 | "flops/train_duration": self.trained_duration, 135 | "flops/duration": self.duration, 136 | } 137 | ) 138 | -------------------------------------------------------------------------------- /src/discrete_flow_matching_pytorch/lightning.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | 3 | import lightning.pytorch as pl 4 | import torch 5 | import torch._dynamo.cache_size 6 | import torch.nn.functional as F 7 | from lightning.pytorch.loggers import WandbLogger 8 | from transformers import PreTrainedTokenizerBase 9 | 10 | from discrete_flow_matching_pytorch.model import ConvNet 11 | 12 | 13 | def get_timestep_step_sizes(timesteps: torch.Tensor) -> torch.Tensor: 14 | return -torch.diff( 15 | timesteps, 16 | append=torch.zeros([1], device=timesteps.device, dtype=timesteps.dtype), 17 | ) 18 | 19 | 20 | class DiscreteFlowMatchingNet(pl.LightningModule): 21 | def __init__( 22 | self, 23 | vocab_size: int, 24 | hidden_dim: int, 25 | num_timesteps: int, 26 | num_layers: int, 27 | tokenizer: PreTrainedTokenizerBase, 28 | val_num_sampling_steps: int = 8, 29 | scheduler_type: Literal["linear", "square"] = "square", 30 | learning_rate: float = 1e-2, 31 | ): 32 | super().__init__() 33 | 34 | self.num_layers = num_layers 35 | self.tokenizer = tokenizer 36 | self.pad_token_id = self.tokenizer.pad_token_id 37 | self.mask_token_id = self.tokenizer.mask_token_id 38 | self.val_num_sampling_steps = val_num_sampling_steps 39 | self.learning_rate = learning_rate 40 | 41 | self.scheduler = torch.linspace( 42 | 1 / num_timesteps, 1, steps=num_timesteps, dtype=torch.float32 43 | ) # Probability path scheduler 44 | 45 | match scheduler_type: 46 | case "linear": 47 | pass 48 | case "square": 49 | # Put more weight on higher 
(=more noisy) timesteps. 50 | # Examples: 51 | # 0 -> 0 (no noise) 52 | # 0.5 -> 0.75 (50% noise moved to 75% noise) 53 | # 1 -> 1 (all noise) 54 | self.scheduler = 1 - torch.square(1 - self.scheduler) 55 | case _: 56 | raise ValueError(f"Invalid scheduler type: {scheduler_type}") 57 | 58 | self.model = ConvNet( 59 | vocab_size=vocab_size, 60 | hidden_dim=hidden_dim, 61 | num_timesteps=num_timesteps, 62 | num_layers=num_layers, 63 | ) 64 | 65 | self.save_hyperparameters() 66 | 67 | def on_fit_start(self) -> None: 68 | self.scheduler = self.scheduler.to(self.device) 69 | 70 | print("Setting learning rate to", self.learning_rate) 71 | opt = self.optimizers(False) 72 | assert isinstance(opt, torch.optim.Optimizer) 73 | opt.param_groups[0]["lr"] = self.learning_rate 74 | 75 | def on_validation_model_eval(self) -> None: 76 | self.scheduler = self.scheduler.to(self.device) 77 | return super().on_validation_model_eval() 78 | 79 | def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor: 80 | # x: B, L, V 81 | # t: B 82 | # model: B, L, V 83 | return self.model(x, t) 84 | 85 | def forward_noising( 86 | self, x: torch.Tensor, t: torch.Tensor, should_noise: torch.Tensor | None 87 | ) -> torch.Tensor: 88 | """Mask x (BL) depending on time step t (BL).""" 89 | 90 | # t is the masking probability. t=0%: dont mask anything, t=100%: mask everything 91 | mask_prob = self.scheduler[t].expand(-1, x.shape[1]) 92 | will_mask = torch.bernoulli(mask_prob).to(dtype=torch.bool) 93 | 94 | # Don't mask padding tokens 95 | will_mask &= x != self.pad_token_id 96 | 97 | # Don't mask tokens that should not be noised 98 | if should_noise is not None: 99 | will_mask &= should_noise 100 | 101 | noised_x = x.clone() 102 | noised_x[will_mask] = self.mask_token_id 103 | 104 | return noised_x 105 | 106 | @torch._dynamo.disable 107 | def log_training_step(self, log_dict): 108 | self.log_dict( 109 | { 110 | **log_dict, 111 | "train/learning_rate": self.trainer.optimizers[0].param_groups[0]["lr"], 112 | } 113 | ) 114 | 115 | def training_step(self, batch, batch_idx: int): 116 | # B L 117 | x: torch.Tensor = batch["input_ids"] 118 | should_noise: torch.Tensor | None = batch.get("should_noise") 119 | 120 | # t: B 121 | t = torch.randint(0, len(self.scheduler), [x.size(0)], device=x.device) 122 | 123 | # noised_x: B L 124 | noised_x = self.forward_noising( 125 | x=x, t=t.unsqueeze(1), should_noise=should_noise 126 | ) 127 | 128 | # Unmasking logits: B L V 129 | logits = self(noised_x, t) # .to(torch.float32) 130 | 131 | target = x.clone() 132 | # Only calculate loss on tokens that were masked 133 | target[noised_x != self.mask_token_id] = -100 134 | 135 | loss = F.cross_entropy( 136 | # CE expects input BVL, target BL 137 | input=logits.transpose(-1, -2), 138 | target=target, 139 | reduction="mean", 140 | ) 141 | self.log_training_step({"train/loss": loss}) 142 | 143 | return loss 144 | 145 | @torch._dynamo.disable 146 | def log_validation_step( 147 | self, 148 | num_samples, 149 | input_text_tokenized, 150 | generated_texts_tokenized, 151 | noised_texts_tokenized, 152 | sampling_timesteps, 153 | losses, 154 | ): 155 | # input_text: B 156 | # generated_texts: T B 157 | # noised_texts: T B 158 | # sampling_timesteps: T 159 | 160 | input_text = self.tokenizer.batch_decode(input_text_tokenized) 161 | generated_texts = [ 162 | self.tokenizer.batch_decode(t) for t in generated_texts_tokenized 163 | ] 164 | noised_texts = [self.tokenizer.batch_decode(t) for t in noised_texts_tokenized] 165 | 166 | self.log_dict(losses) 167 | 
168 | num_samples = min(num_samples, len(input_text)) 169 | 170 | if isinstance(self.logger, WandbLogger): 171 | for i_t, t in enumerate(sampling_timesteps): 172 | self.logger.log_table( 173 | f"validation-texts/{t}", 174 | columns=["input_text", "generated_text", "generated_text_inputs"], 175 | data=[ 176 | [ 177 | input_text[i], 178 | generated_texts[i_t][i], 179 | noised_texts[i_t][i], 180 | ] 181 | for i in range(num_samples) 182 | ], 183 | ) 184 | 185 | def _get_sampling_timesteps(self, num_sampling_steps): 186 | return torch.linspace( 187 | len(self.scheduler) - 1, 188 | len(self.scheduler) // num_sampling_steps, 189 | num_sampling_steps, 190 | device=self.device, 191 | dtype=torch.long, 192 | ) 193 | 194 | @torch._dynamo.disable 195 | def validation_step_without_compile( 196 | self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int 197 | ): 198 | # x: B L 199 | x = batch["input_ids"] 200 | should_noise = batch.get("should_noise") 201 | 202 | num_samples = 5 # Number of samples to visualize 203 | 204 | sampling_timesteps = self._get_sampling_timesteps(self.val_num_sampling_steps) 205 | 206 | losses = {} 207 | 208 | # Apply forward noising and reverse process 209 | noised_texts_tokenized = [] 210 | generated_texts_tokenized = [] 211 | 212 | # for t, step_size in zip(sampling_timesteps, step_sizes, strict=True): 213 | for t in sampling_timesteps: 214 | t = t.repeat(x.shape[0]) 215 | assert t.shape == x.shape[:1], t.shape 216 | 217 | # B L 218 | noised_x = self.forward_noising( 219 | x, t.unsqueeze(1), should_noise=should_noise 220 | ) 221 | 222 | # Only calculate loss on tokens that were masked 223 | # B L 224 | target = x.clone() 225 | # Only calculate loss on tokens that were masked 226 | target[noised_x != self.mask_token_id] = -100 227 | 228 | # B L V 229 | logits = self(noised_x, t) 230 | 231 | # Get samples for each token 232 | # B L 233 | # samples = torch.distributions.Categorical(logits=logits).sample() 234 | samples = torch.argmax(logits, dim=-1) 235 | 236 | # Unmask the masked tokens 237 | # B L 238 | generated_tokens = noised_x.clone() 239 | generated_tokens[noised_x == self.mask_token_id] = samples[ 240 | noised_x == self.mask_token_id 241 | ] 242 | 243 | generated_texts_tokenized.append(generated_tokens) 244 | noised_texts_tokenized.append(noised_x) 245 | 246 | losses[f"validation-losses/loss_{t[0]}"] = F.cross_entropy( 247 | input=logits.transpose(-1, -2), target=target, reduction="mean" 248 | ) 249 | losses["validation/loss_mean"] = torch.mean(torch.tensor(list(losses.values()))) 250 | 251 | self.log_validation_step( 252 | num_samples=num_samples, 253 | input_text_tokenized=x, 254 | generated_texts_tokenized=generated_texts_tokenized, 255 | noised_texts_tokenized=noised_texts_tokenized, 256 | sampling_timesteps=sampling_timesteps, 257 | losses=losses, 258 | ) 259 | 260 | return losses["validation/loss_mean"] 261 | 262 | def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int): 263 | return self.validation_step_without_compile(batch, batch_idx) 264 | 265 | def sample( 266 | self, 267 | num_sampling_steps: int, 268 | num_samples: int | None = None, 269 | sequence_length: int | None = None, 270 | x: torch.Tensor | None = None, 271 | stochasticity: float = 0.0, 272 | yield_intermediate: bool = False, 273 | yield_logits: bool = False, 274 | temperature: float = 1.0, 275 | cfg_scale: float = 1.0, 276 | ): 277 | assert ( 278 | num_samples is not None and sequence_length is not None 279 | ) or x is not None, "Must pass either (num_samples and 
sequence_length) or x" 280 | 281 | assert not ( 282 | yield_intermediate and yield_logits 283 | ), "Can't yield both logits and intermediate results" 284 | 285 | # B L 286 | if x is None: 287 | # Start fully masked 288 | x = torch.full( 289 | [num_samples, sequence_length], 290 | fill_value=self.tokenizer.mask_token_id, 291 | dtype=torch.long, 292 | device=self.device, 293 | ) 294 | should_noise = None 295 | else: 296 | should_noise = x == self.mask_token_id 297 | 298 | # Create the integer timesteps and step sizes for the given num_sampling_steps 299 | # S 300 | sampling_timesteps = self._get_sampling_timesteps(num_sampling_steps) 301 | relative_ts = self.scheduler[sampling_timesteps] 302 | relative_dts = get_timestep_step_sizes(relative_ts) 303 | 304 | for t, relative_t, relative_dt in zip( 305 | sampling_timesteps, relative_ts, relative_dts 306 | ): 307 | is_last_step = t == sampling_timesteps[-1] 308 | if yield_intermediate: 309 | yield t, x 310 | 311 | # B 312 | t = t.repeat(x.shape[0]) 313 | assert t.shape == x.shape[:1], t.shape 314 | 315 | # B L V 316 | logits = self(x, t) 317 | 318 | if cfg_scale != 1.0: 319 | assert should_noise is not None 320 | x_uncond = x.clone() 321 | x_uncond[~should_noise] = self.mask_token_id 322 | 323 | # Classifier-free guidance 324 | # Run model unconditionally (conditioning fully masked) 325 | logits_uncond = self(x_uncond, t) 326 | 327 | # Mix the logits according to cfg_scale 328 | logits = logits_uncond + cfg_scale * (logits - logits_uncond) 329 | 330 | if yield_logits: 331 | yield t, logits 332 | 333 | # B L 334 | samples = torch.distributions.Categorical( 335 | logits=logits / temperature 336 | ).sample() 337 | 338 | # B L 339 | # Chance to unmask proportional to 340 | # - step size: higher step size means higher chance 341 | # - timestep: lower timestep means higher chance (so in the end the chance is 100%) 342 | unmask_threshold = relative_dt / relative_t 343 | 344 | # With remasking, the unmasking probability is changed 345 | if stochasticity != 0: 346 | unmask_threshold *= 1 + stochasticity * (1 - relative_t) 347 | 348 | was_masked = x == self.mask_token_id 349 | 350 | # Unmask 351 | will_unmask = ( 352 | torch.rand( 353 | x.shape[:2], 354 | device=unmask_threshold.device, 355 | dtype=unmask_threshold.dtype, 356 | ) 357 | < unmask_threshold 358 | ) 359 | # Only unmask the tokens that were masked 360 | will_unmask &= was_masked 361 | 362 | # Remask when stochasticity is non-zero 363 | if stochasticity != 0 and not is_last_step: 364 | remask_threshold = relative_dt * stochasticity 365 | will_remask = ( 366 | torch.rand( 367 | x.shape[:2], 368 | device=unmask_threshold.device, 369 | dtype=unmask_threshold.dtype, 370 | ) 371 | < remask_threshold 372 | ) 373 | # Only remask the tokens that were unmasked 374 | will_remask &= ~was_masked 375 | 376 | # Only remask tokens that aren't constant 377 | if should_noise is not None: 378 | will_remask &= should_noise 379 | 380 | x[will_remask] = self.mask_token_id 381 | 382 | # B L 383 | x[will_unmask] = samples[will_unmask] 384 | 385 | if yield_intermediate: 386 | yield torch.zeros_like(t), x 387 | else: 388 | return x 389 | 390 | def configure_optimizers(self): 391 | return torch.optim.AdamW(self.parameters(), lr=self.learning_rate) 392 | -------------------------------------------------------------------------------- /src/discrete_flow_matching_pytorch/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class 
Unsqueeze(nn.Module): 6 | def __init__(self, dim: int): 7 | super().__init__() 8 | self.dim = dim 9 | 10 | def forward(self, x): 11 | return x.unsqueeze(self.dim) 12 | 13 | 14 | class Reshape(nn.Module): 15 | def __init__(self, shape: list[int]): 16 | super().__init__() 17 | self.shape = shape 18 | 19 | def forward(self, x): 20 | return x.reshape(self.shape) 21 | 22 | 23 | class Transpose(nn.Module): 24 | def forward(self, x): 25 | return x.transpose(-1, -2) 26 | 27 | 28 | class ConvNet(nn.Module): 29 | def __init__( 30 | self, 31 | vocab_size: int, 32 | hidden_dim: int, 33 | num_timesteps: int, 34 | num_layers: int, 35 | kernel_size: int = 31, 36 | ): 37 | super().__init__() 38 | 39 | # x: B, L 40 | 41 | # Input embedding, B L -> B L C 42 | self.input_projection = nn.Embedding(vocab_size, hidden_dim) 43 | 44 | # Embed timestep to B, 1, C 45 | self.embed_timestep = nn.Sequential( 46 | nn.Embedding(num_timesteps, hidden_dim), 47 | Unsqueeze(1), 48 | ) 49 | 50 | self.blocks = nn.ModuleList() 51 | self.timestep_embedding_norms = nn.ModuleList() 52 | 53 | for _ in range(num_layers): 54 | self.blocks.append( 55 | nn.Sequential( 56 | Transpose(), 57 | nn.Conv1d( 58 | hidden_dim, hidden_dim, kernel_size=kernel_size, padding="same" 59 | ), 60 | Transpose(), 61 | nn.LayerNorm([hidden_dim]), 62 | nn.GELU(), 63 | nn.Linear(hidden_dim, hidden_dim), 64 | nn.LayerNorm([hidden_dim]), 65 | nn.GELU(), 66 | ), 67 | ) 68 | self.timestep_embedding_norms.append(nn.LayerNorm([hidden_dim])) 69 | 70 | # Output projection, B L C -> B L V 71 | self.output_projection = nn.Linear(hidden_dim, vocab_size) 72 | 73 | def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor: 74 | # x: B, L, V, t: B 75 | x = self.input_projection(x) # BLC 76 | 77 | for block, timestep_embedding_norm in zip( 78 | self.blocks, self.timestep_embedding_norms 79 | ): 80 | x = x + block(x + timestep_embedding_norm(self.embed_timestep(t))) # BLC 81 | 82 | x = self.output_projection(x) # BLV 83 | 84 | return x 85 | -------------------------------------------------------------------------------- /src/discrete_flow_matching_pytorch/train.py: -------------------------------------------------------------------------------- 1 | from contextlib import nullcontext 2 | from time import time 3 | 4 | import lightning.pytorch as pl 5 | import torch 6 | from jsonargparse import CLI 7 | from lightning.pytorch.callbacks import ModelCheckpoint 8 | from lightning.pytorch.loggers import WandbLogger 9 | from pyinstrument import Profiler 10 | from pyinstrument.renderers.speedscope import SpeedscopeRenderer 11 | from structlog import get_logger 12 | from torch.utils.data import DataLoader 13 | 14 | from discrete_flow_matching_pytorch.data import ( 15 | get_default_tokenizer, 16 | load_dataset_by_name, 17 | ) 18 | from discrete_flow_matching_pytorch.flops import FlopCounterCallback 19 | from discrete_flow_matching_pytorch.lightning import DiscreteFlowMatchingNet 20 | 21 | logger = get_logger() 22 | 23 | 24 | def get_run_name( 25 | dataset: str, 26 | train_batch_size: int, 27 | hidden_dim: int, 28 | num_layers: int, 29 | scheduler_type: str, 30 | ): 31 | return f"{dataset}-bs={train_batch_size}-h={hidden_dim}-l={num_layers}-s={scheduler_type}" 32 | 33 | 34 | def main( 35 | compile: bool = True, 36 | wandb: bool = True, 37 | max_steps: int = -1, 38 | profile: bool = False, 39 | train_step_flops: float | None = None, 40 | ckpt_path: str = "", 41 | val_interval: int = 500, 42 | checkpoint_interval: int = 1_000, 43 | dataset: str = "tiny_stories", 44 | 
train_batch_size: int = 256, 45 | shuffle_train: bool = True, # Needs to be false for IterableDataset 46 | train_workers: int = 2, 47 | val_split_name: str = "validation", # Some datasets don't have validation, but we can still use train 48 | hidden_dim: int = 768, 49 | num_layers: int = 6, 50 | scheduler_type: str = "square", 51 | learning_rate: float = 1e-2, 52 | ): 53 | torch.set_float32_matmul_precision("high") 54 | 55 | # Load tokenizer 56 | logger.info("Loading tokenizer") 57 | tokenizer = get_default_tokenizer() 58 | 59 | logger.info("Loading dataset", dataset=dataset) 60 | train_data = load_dataset_by_name( 61 | dataset=dataset, tokenizer=tokenizer, split="train" 62 | ) 63 | val_data = load_dataset_by_name( 64 | dataset=dataset, tokenizer=tokenizer, split=val_split_name 65 | ) 66 | 67 | # Dataloader 68 | logger.info("Creating data loaders") 69 | train_loader = DataLoader( 70 | train_data, 71 | batch_size=train_batch_size, 72 | shuffle=shuffle_train, 73 | num_workers=train_workers, 74 | prefetch_factor=2, 75 | ) 76 | val_loader = DataLoader( 77 | val_data, batch_size=64, shuffle=False, num_workers=1, prefetch_factor=2 78 | ) 79 | 80 | # Create model 81 | logger.info("Creating model") 82 | model = DiscreteFlowMatchingNet( 83 | vocab_size=len(tokenizer), # .vocab_size excludes the new tokens 84 | hidden_dim=hidden_dim, 85 | num_timesteps=1024, 86 | num_layers=num_layers, 87 | tokenizer=tokenizer, 88 | scheduler_type=scheduler_type, 89 | learning_rate=learning_rate, 90 | ).to(dtype=torch.bfloat16) 91 | 92 | if compile: 93 | torch._dynamo.config.cache_size_limit = 512 94 | torch._dynamo.config.capture_scalar_outputs = True 95 | torch._dynamo.config.capture_dynamic_output_shape_ops = True 96 | torch._inductor.config.fx_graph_cache = True 97 | model = torch.compile(model) 98 | 99 | # Train model 100 | if train_step_flops is not None: 101 | logger.info("Using train step flops", train_step_flops=train_step_flops) 102 | else: 103 | logger.info("Using dynamic train step flops") 104 | 105 | trainer = pl.Trainer( 106 | max_epochs=-1, 107 | max_steps=max_steps, 108 | accelerator="gpu", 109 | devices=1 if torch.cuda.is_available() else 0, 110 | limit_val_batches=1, 111 | callbacks=[ 112 | FlopCounterCallback(train_step_flops=train_step_flops), 113 | ModelCheckpoint(every_n_train_steps=checkpoint_interval), 114 | ], 115 | logger=WandbLogger( 116 | project="discrete-flow-matching", 117 | name=get_run_name( 118 | dataset=dataset, 119 | train_batch_size=train_batch_size, 120 | hidden_dim=hidden_dim, 121 | num_layers=num_layers, 122 | scheduler_type=scheduler_type, 123 | ), 124 | ) 125 | if wandb 126 | else None, 127 | precision="bf16", 128 | gradient_clip_algorithm="norm", 129 | gradient_clip_val=0.1, 130 | check_val_every_n_epoch=None, 131 | val_check_interval=val_interval, 132 | log_every_n_steps=25, 133 | num_sanity_val_steps=0, 134 | ) 135 | 136 | fit_context = nullcontext() if not profile else Profiler(async_mode="disabled") 137 | with fit_context: 138 | # Run validation before training to get the initial loss 139 | trainer.validate(model=model, dataloaders=val_loader) 140 | 141 | trainer.fit( 142 | model=model, 143 | train_dataloaders=train_loader, 144 | val_dataloaders=val_loader, 145 | ckpt_path=ckpt_path if ckpt_path else None, 146 | ) 147 | 148 | if profile: 149 | with open(f"speedscope-{time()}.json", "w") as f: 150 | f.write(fit_context.output(SpeedscopeRenderer())) 151 | 152 | 153 | if __name__ == "__main__": 154 | CLI(main) 155 | 
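Since `CLI(main)` builds the command-line interface from `main`'s signature via jsonargparse, every keyword argument above doubles as a flag. Two usage sketches follow; the `--config` form is the one the README documents, while the direct flag overrides are inferred from the signature:

# Train with the provided config:
#   python -m discrete_flow_matching_pytorch.train --config configs/conv-8.yaml
# Override individual options directly, e.g. a smaller model on TinyStories without wandb:
#   python -m discrete_flow_matching_pytorch.train --dataset tiny_stories --hidden_dim 256 --num_layers 4 --wandb false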
--------------------------------------------------------------------------------
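A closing note on the sampling schedule in `DiscreteFlowMatchingNet.sample` (lightning.py): with `stochasticity=0`, the per-step unmasking probability `relative_dt / relative_t` is what guarantees a fully unmasked sequence at the end. A small worked example with assumed toy numbers (a linear scheduler and 4 sampling steps, not one of the repository's configs):

# relative_ts:  noise level at each sampling step
# relative_dts: step sizes, as computed by get_timestep_step_sizes (negated diff with 0 appended)
relative_ts = [1.0, 0.75, 0.5, 0.25]
relative_dts = [0.25, 0.25, 0.25, 0.25]
for t, dt in zip(relative_ts, relative_dts):
    print(f"t={t:.2f}: P(unmask a masked token) = dt/t = {dt / t:.3f}")
# Prints 0.250, 0.333, 0.500, 1.000: the probability rises to 1 at the final
# step, so every remaining [MASK] token is replaced with a sampled token.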