├── .gitignore
├── assets
│   └── demo.gif
├── LICENSE
├── demo.py
├── extra
│   └── train_alternative.py
├── train.py
└── README.md

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
distilbert-diffusion-TinyStories
__pycache__

--------------------------------------------------------------------------------
/assets/demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gumran/language-diffusion/HEAD/assets/demo.gif

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2025 Alim Gumran

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch
from rich.live import Live  # rich will help visualize diffusion sampling
from rich.console import Console

device = 'cuda'
model_name = "distilbert-diffusion-TinyStories"
model = AutoModelForMaskedLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)

seq_len = 512
num_steps = 512  # can increase for better quality
times = torch.linspace(1, 0, num_steps + 1, device=device)  # linear reverse-process time steps

# initialize the fully masked sequence
x = torch.full((1, seq_len), tokenizer.mask_token_id, dtype=torch.int64, device=device, requires_grad=False)
mask = torch.ones((1, seq_len), dtype=torch.bool, device=device, requires_grad=False)
attention_mask = torch.ones((1, seq_len), dtype=torch.long, device=device, requires_grad=False)  # attend to all tokens

# sampling process based on Algorithm 4 from https://arxiv.org/abs/2502.09992
model.eval()
console = Console()
with torch.no_grad():
    with Live("", refresh_per_second=10, console=console) as live:
        for t, s in zip(times[:-1], times[1:]):
            logits = model(x, attention_mask=attention_mask).logits
            x[mask] = logits[mask].argmax(-1)  # greedily predict the masked tokens
            decoded = tokenizer.batch_decode(x, skip_special_tokens=True)[0]
            live.update(decoded)

            remask_probs = torch.rand((1, seq_len), device=device) < s / t
            mask = mask & remask_probs
            x[mask] = tokenizer.mask_token_id  # remask each of the predicted tokens with probability s/t

--------------------------------------------------------------------------------
/extra/train_alternative.py:
--------------------------------------------------------------------------------
import math
from transformers import AutoTokenizer, AutoModelForMaskedLM, get_cosine_schedule_with_warmup
import datasets
import torch
from tqdm import tqdm

device = "cuda"
model_name = "distilbert-base-cased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMaskedLM.from_pretrained(model_name).to(device)

# training args
num_epochs = 3
batch_size = 80
max_length = 512
gradient_accumulation_steps = 2
log_steps = 50

# load and tokenize the dataset
dataset_name = "roneneldan/TinyStories"
dataset = datasets.load_dataset(dataset_name, split="train")
def tok_fn(examples):
    return tokenizer(examples["text"], max_length=max_length,
                     padding="max_length", truncation=True, add_special_tokens=False)
tok_dataset = dataset.map(tok_fn, batched=True, remove_columns=["text"])
tok_dataset = tok_dataset.with_format("torch")
dataloader = torch.utils.data.DataLoader(tok_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

# define the optimizer and learning rate scheduler
lr = 1e-4
weight_decay = 0.01
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
total_steps = num_epochs * math.ceil(len(dataloader) / gradient_accumulation_steps)
warmup_ratio = 0.05  # 5% warmup
warmup_steps = int(warmup_ratio * total_steps)
lr_scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)

model.train()
for epoch in range(num_epochs):
    loss_cumsum = 0
    pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
    for step, inputs in enumerate(pbar):
        input_ids = inputs['input_ids'].to(device)  # move the batch onto the same device as the model

        # sample timesteps and mask the sequences
        t = torch.rand(batch_size, 1, device=device).clamp_min(1e-4).expand(batch_size, max_length)
        mask = torch.bernoulli(t).bool()
        corrupted = input_ids.masked_fill(mask, tokenizer.mask_token_id)
        labels = input_ids.masked_fill(~mask, -100)  # ground truth (ignore all unmasked tokens)

        outputs = model(input_ids=corrupted)  # attend to and predict padding tokens too!
        logits = outputs.logits
        per_tok_loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)),
            labels.view(-1), reduction="none", ignore_index=-100).view(batch_size, max_length)
        loss = (per_tok_loss / t).mean()  # weight by time step
        (loss / gradient_accumulation_steps).backward()
        if (step + 1) % gradient_accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()
            lr_scheduler.step()
        loss_cumsum += loss.item()
        if (step + 1) % log_steps == 0:
            pbar.set_postfix({"Loss": f"{loss_cumsum / log_steps:.4f}"})
            loss_cumsum = 0

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import math
from transformers import AutoTokenizer, AutoModelForMaskedLM, get_cosine_schedule_with_warmup
import datasets
import torch
from accelerate import Accelerator
from tqdm import tqdm

model_name = "distilbert-base-cased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMaskedLM.from_pretrained(model_name)

num_epochs = 3
batch_size = 80
seq_len = 512
gradient_accumulation_steps = 1
log_steps = 50
mixed_precision = "fp16"
lr = 1e-4
weight_decay = 0.01
warmup_ratio = 0.05  # 5% warmup steps

dataset_name = "roneneldan/TinyStories"
dataset = datasets.load_dataset(dataset_name, split="train")
def tok_fn(examples):
    return tokenizer(examples["text"], max_length=seq_len,
                     padding="max_length", truncation=True, add_special_tokens=False)
tok_dataset = dataset.map(tok_fn, batched=True, remove_columns=["text"])
tok_dataset = tok_dataset.with_format("torch")
dataloader = torch.utils.data.DataLoader(tok_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
total_steps = num_epochs * math.ceil(len(dataloader) / gradient_accumulation_steps)
lr_scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=int(warmup_ratio * total_steps), num_training_steps=total_steps)

accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps, mixed_precision=mixed_precision)
model, optimizer, dataloader, lr_scheduler = accelerator.prepare(model, optimizer, dataloader, lr_scheduler)

# training loop adapted from Algorithms 1 & 2 from https://arxiv.org/abs/2502.09992
model.train()
for epoch in range(num_epochs):
    loss_cumsum = 0
    pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}",
                disable=not accelerator.is_main_process)
    for step, inputs in enumerate(pbar):
        input_ids = inputs['input_ids']
        attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long,
                                    device=accelerator.device)  # attend to all tokens
        # sample t and randomly mask each token with probability t
        t = torch.rand(batch_size, 1, device=accelerator.device).clamp_min(1e-4).expand(batch_size, seq_len)
        mask = torch.bernoulli(t).bool()
        corrupted = input_ids.masked_fill(mask, tokenizer.mask_token_id)
        labels = input_ids.masked_fill(~mask, -100)  # compute loss only on masked tokens
        with accelerator.accumulate(model):
            outputs = model(input_ids=corrupted, attention_mask=attention_mask)
            logits = outputs.logits
            per_tok_loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)),
                labels.view(-1), reduction="none", ignore_index=-100).view(batch_size, seq_len)
            loss = (per_tok_loss / t).mean()  # weight by time step
            accelerator.backward(loss)
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()
            lr_scheduler.step()
        loss_cumsum += accelerator.gather(loss.detach()).mean().item()
        if accelerator.is_main_process and (step + 1) % log_steps == 0:
            pbar.set_postfix({"Loss": f"{loss_cumsum / log_steps:.4f}"})
            loss_cumsum = 0
accelerator.end_training()

save_directory = "distilbert-diffusion-TinyStories"
accelerator.wait_for_everyone()  # make sure all processes have finished before saving
if accelerator.is_main_process:
    accelerator.unwrap_model(model).save_pretrained(save_directory)
    tokenizer.save_pretrained(save_directory)

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Language Diffusion in <80 Lines of Code
A quick implementation of diffusion language models using `transformers`.

![diffusion](assets/demo.gif)

Much of this work is adapted from the paper [Large Language Diffusion Models](https://arxiv.org/pdf/2502.09992) by Nie et al. (2025). I've tried to keep the code clean and concise, so the training script currently has fewer than 80 lines of code.

## Setup
I recommend using `uv` to install packages (you can also just use `pip`):
```
pip install uv
uv pip install torch transformers datasets accelerate tqdm rich
```

## Run
- Run `accelerate launch train.py` to finetune [DistilBERT](https://huggingface.co/distilbert/distilbert-base-cased) on the [TinyStories](https://huggingface.co/datasets/roneneldan/TinyStories) dataset.
- Change the training arguments as required by your compute constraints.
- I have also uploaded the trained diffusion model to [Hugging Face](https://huggingface.co/gumran/distilbert-diffusion-TinyStories).
- Run `python demo.py` to use a trained model to generate short stories similar to those in the dataset.
- See below for details on how the scripts work.

## How it works
### Model
The model is DistilBERT, an encoder-only transformer pretrained for masked language modeling, which makes it well suited for our purposes. You can swap in any other language model - even a "decoder-only" transformer like GPT - as long as the attention mask is all ones rather than causal.

### Training
The training script is adapted from Algorithms 1 and 2 of the Nie et al. paper:
1. A sequence `x` is sampled from the training corpus;
2. A time step `t` is sampled uniformly between 0 and 1;
3. Each token in `x` is masked with probability `t`;
4. The model is trained to predict the masked tokens via maximum likelihood.

Importantly, the padding tokens can also be attended to, masked, and modeled.
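For reference, one training step boils down to roughly the following. This is a minimal single-device sketch without `accelerate` or gradient accumulation; the helper name `diffusion_training_step` is just for illustration, and `model`, `tokenizer`, `optimizer` and a batch of `input_ids` are assumed to be set up as in `train.py`:

```
import torch

def diffusion_training_step(model, tokenizer, input_ids, optimizer):
    batch_size, seq_len = input_ids.shape
    # sample one t per sequence and mask each token with probability t
    t = torch.rand(batch_size, 1, device=input_ids.device).clamp_min(1e-4).expand(batch_size, seq_len)
    mask = torch.bernoulli(t).bool()
    corrupted = input_ids.masked_fill(mask, tokenizer.mask_token_id)
    labels = input_ids.masked_fill(~mask, -100)  # loss is computed only on masked positions

    logits = model(input_ids=corrupted).logits
    per_tok_loss = torch.nn.functional.cross_entropy(
        logits.view(-1, logits.size(-1)), labels.view(-1),
        reduction="none", ignore_index=-100).view(batch_size, seq_len)
    loss = (per_tok_loss / t).mean()  # 1/t weighting, as in Algorithm 1 of Nie et al.

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss.item()
```

In `train.py` the same logic is wrapped with `accelerate` to get mixed precision and multi-GPU support.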
### Inference
The `demo.py` script is based on Algorithm 4:
1. We start with a fully masked sequence `x`;
2. For `t` going from 1 to 0 linearly in `T` steps:
   - Predict the masked tokens;
   - Remask each of the predicted tokens with probability `s/t`, where `s` is the next value of `t`.

We have a fully unmasked sequence at the end.

Note that Nie et al. also describe a "lowest confidence" sampling process, but it is deterministic and unsuitable for unconditional generation. For more details, I recommend reading the paper and its references on language diffusion, such as [Simple and Effective Masked Diffusion Language Models](https://arxiv.org/abs/2406.07524) by Sahoo et al. (2024).

## Notes
- Diffusion language models strongly remind me of the novella ["Story of Your Life"](https://en.wikipedia.org/wiki/Story_of_Your_Life) by Ted Chiang, which the movie [_Arrival_](https://en.wikipedia.org/wiki/Arrival_(film)) is based on.
- Of course, these models cannot tell the future, but they do process and communicate language non-sequentially. However, the language they are trained to produce is itself sequential in nature, unlike in Chiang's story. Perhaps there is a way to train them on non-sequential representations of human language - if any good systems for that exist.
- For a deeper understanding, one might build the language model architecture from the ground up in PyTorch without `transformers`, but that was not the point of this project.
- Still, I might do something like that in the future. `extra/train_alternative.py` is supposed to handle single-GPU training without `accelerate`, but I haven't tested it yet. Dependencies could be trimmed further, and this might grow into something resembling a package.
- Scaling up the model and the pretraining data, and then training on an instruction dataset via conditional maximum likelihood, should achieve results similar to those of Nie et al. Interestingly, there should also be ways to align language diffusion models via RLHF, since [it has been done](https://arxiv.org/abs/2302.08242) in the image domain.

## Contributing
I welcome any contributions to this repository. As mentioned above, I may want to reduce the reliance on dependencies and/or explore instruction tuning and alignment.
--------------------------------------------------------------------------------