├── README.md
└── finetune_fm.py

/README.md:
--------------------------------------------------------------------------------
# ORW-CFM-W2: Online Reward-Weighted Conditional Flow Matching with Wasserstein-2 Regularization

This repository contains the implementation of "Online Reward-Weighted Fine-Tuning of Flow Matching with Wasserstein Regularization," a method for fine-tuning flow-based generative models using reinforcement learning.

## Overview

ORW-CFM-W2 is a novel reinforcement learning approach for fine-tuning continuous flow-based generative models to align with arbitrary user-defined reward functions. Unlike previous methods that require filtered datasets or gradients of rewards, our method enables optimization with arbitrary reward functions while preventing policy collapse through Wasserstein-2 distance regularization.

## Method

Our approach integrates reinforcement learning into the flow matching framework through three key components:

1. **Online Reward-Weighting**: Guides the model to prioritize high-reward regions of the data manifold
2. **Wasserstein-2 Regularization**: Prevents policy collapse and maintains diversity
3. **Tractable W2 Distance Bound**: Enables efficient computation of the W2 distance in flow matching models

The loss function is defined as:

$$ L_{\text{ORW-CFM-W2}} = \mathbb{E}\left[ w(x_1) \|v_{\theta_{\text{ft}}}(t, x) - u_t(x|x_1)\|^2 + \alpha \|v_{\theta_{\text{ft}}}(t, x) - v_{\theta_{\text{ref}}}(t, x)\|^2 \right] $$

Where:
- $w(x_1) \propto r(x_1)$ is the weighting function proportional to the reward
- $v_{\theta_{\text{ft}}}$ is the fine-tuned model's vector field
- $v_{\theta_{\text{ref}}}$ is the reference (pre-trained) model's vector field
- $u_t(x|x_1)$ is the true conditional vector field
- $\alpha$ is the regularization coefficient that controls the trade-off between reward and diversity
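
In code, the objective reduces to a reward-weighted flow matching regression plus a penalty toward the reference vector field. A minimal sketch (mirroring `ORWCFMTrainer.compute_loss` in `finetune_fm.py`, assuming image-shaped samples `x1` and per-sample `weights`):

```python
import torch

def orw_cfm_w2_loss(fm, net_model, ref_model, x1, weights, alpha):
    """Sketch of the ORW-CFM-W2 objective; mirrors ORWCFMTrainer.compute_loss below."""
    x0 = torch.randn_like(x1)                                     # noise endpoint of the probability path
    t, xt, ut = fm.sample_location_and_conditional_flow(x0, x1)   # torchcfm conditional flow sample
    vt = net_model(t, xt)                                         # fine-tuned vector field v_theta_ft
    vt_ref = ref_model(t, xt).detach()                            # frozen reference field v_theta_ref
    fm_term = ((vt - ut) ** 2).mean(dim=(1, 2, 3))                # per-sample flow matching error
    w2_term = ((vt - vt_ref) ** 2).mean(dim=(1, 2, 3))            # tractable W2 regularizer
    return (weights * fm_term + alpha * w2_term).mean()
```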

## Implementation

The core implementation is in the `ORWCFMTrainer` class, which handles:

1. Initialization of models (network model, last policy, reference model)
2. Sampling from the current policy
3. Computing rewards for samples
4. Computing the loss with both FM and W2 components
5. Updating the model parameters
6. Periodically updating the sampling policy

## Usage

### Basic Usage

```python
import torch

from torchcfm.models.unet.unet import UNetModelWrapper
from finetune_fm import ORWCFMTrainer

# Define configuration
config = {
    'learning_rate': 2e-4,
    'warmup_steps': 5000,
    'w2_coefficient': 1.0,   # alpha for W2 regularization; we encourage alpha >= 1.0
    'temperature': 0.5,      # tau for reward weighting
    'grad_clip': 1.0,
    'batch_size': 128,
    'text_prompts': ["An image of dog", "Not an image of dog"],
    'use_wandb': True,
    'wandb_project': 'flow-matching',
    'run_name': 'orw-cfm-w2',
    'savedir': './results',
    'ref_path': './pretrained/model.pt'  # Path to pre-trained model
}

# Initialize model
model = UNetModelWrapper(
    dim=(3, 32, 32),
    num_res_blocks=2,
    num_channels=128,
    channel_mult=[1, 2, 2, 2],
    num_heads=4,
    num_head_channels=64,
    attention_resolutions="16",
    dropout=0.1
)

# Initialize device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize trainer
trainer = ORWCFMTrainer(model, config, device)

# Load pre-trained model
trainer.load_pretrained(config['ref_path'])

# Train model
trainer.train(
    num_epochs=1000,
    steps_per_epoch=100
)

# Save checkpoint
trainer.save_checkpoint('./checkpoints/orw_cfm_w2.pt')
```

### Key Parameters

- **w2_coefficient (alpha)**: Controls the strength of the W2 regularization. Higher values prioritize staying close to the reference model, leading to more diverse outputs; lower values prioritize reward maximization.
- **temperature (tau)**: Controls the sharpness of the reward weighting; this implementation uses $w(x_1) \propto \exp(\tau \, r(x_1))$. Higher values focus training more aggressively on high-reward regions (see the sketch below).
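
For intuition, the following sketch (with made-up reward values) shows how the per-sample weights used in `ORWCFMTrainer.training_step` concentrate on high-reward samples as the temperature grows:

```python
import torch

rewards = torch.tensor([0.1, 0.5, 0.9])        # hypothetical rewards for three samples
for tau in (0.5, 2.0):
    weights = torch.exp(tau * rewards)          # same weighting rule as training_step
    print(tau, (weights / weights.sum()).tolist())
# tau=0.5 keeps the normalized weights nearly uniform; tau=2.0 shifts most of the mass
# to the highest-reward sample, i.e. a sharper focus on high-reward regions.
```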

## Theoretical Guarantees

Our method provides the following theoretical guarantees:

1. **Convergence Behavior**: The data distribution after $N$ epochs evolves according to:

   $$q^N_{\theta}(x_1) \propto w(x_1)\, q^{N-1}_{\theta}(x_1) \exp(-\beta D^{N-1}(x_1))$$

   where $D^{N-1}(x_1)$ measures the discrepancy between the current and reference models.

2. **Limiting Behavior**: Without regularization ($\alpha = 0$), the model converges to a delta distribution centered at the maximum-reward point.

3. **Reward-Diversity Trade-off**: W2 regularization enables a controllable trade-off between reward maximization and diversity preservation.

## Citation

If you find this code useful for your research, please consider citing our paper:

```bibtex
@inproceedings{
  fan2025online,
  title={Online Reward-Weighted Fine-Tuning of Flow Matching with Wasserstein Regularization},
  author={Jiajun Fan and Shuaike Shen and Chaoran Cheng and Yuxin Chen and Chumeng Liang and Ge Liu},
  booktitle={The Thirteenth International Conference on Learning Representations},
  year={2025},
  url={https://openreview.net/forum?id=2IoFFexvuw}
}
```

## Dependencies

- PyTorch
- TorchCFM
- wandb (optional, for logging)
- tqdm

## Pre-trained Models

For experiments on CIFAR-10 or MNIST, we recommend using pre-trained flow matching models from the torch-cfm repository. You can train a model using their example script:

```bash
git clone https://github.com/atong01/conditional-flow-matching
cd conditional-flow-matching
pip install -e .
python examples/images/cifar10/train_cifar10.py
```
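
Note that `ORWCFMTrainer.load_pretrained` expects a checkpoint dictionary containing a `net_model` state dict (plus optional `optim` and `sched` entries). If your pre-trained weights are stored under a different key, a small remapping step such as the following hypothetical sketch can adapt them (the `ema_model` key and file paths are placeholders for whatever your checkpoint actually uses):

```python
import torch

# Hypothetical remapping sketch: adjust key names and paths to match your checkpoint.
ckpt = torch.load('./pretrained/raw_checkpoint.pt', map_location='cpu')
state_dict = ckpt['ema_model'] if 'ema_model' in ckpt else ckpt   # fall back to a plain state dict
torch.save({'net_model': state_dict}, './pretrained/model.pt')    # key expected by load_pretrained
```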

## License

[MIT License](https://mit-license.org/)


--------------------------------------------------------------------------------
/finetune_fm.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
import wandb
from tqdm import tqdm
import copy
import os
from typing import Tuple

from torchcfm.conditional_flow_matching import ExactOptimalTransportConditionalFlowMatcher
from torchcfm.models.unet.unet import UNetModelWrapper
from utils_cifar import generate_samples
from clip_reward import get_text_image_clip_score


class ORWCFMTrainer:
    def __init__(
        self,
        model: nn.Module,
        config: dict,
        device: torch.device
    ):
        self.device = device
        self.config = config

        # Initialize models
        self.net_model = model.to(device)
        self.last_policy = copy.deepcopy(model).to(device)
        self.ref_model = copy.deepcopy(model).to(device)

        # Initialize Flow Matcher
        self.fm = ExactOptimalTransportConditionalFlowMatcher(sigma=0.0)

        # Initialize optimizer and scheduler
        self.optimizer = Adam(self.net_model.parameters(), lr=config['learning_rate'])
        self.scheduler = LambdaLR(
            self.optimizer,
            lr_lambda=lambda step: min(step, config['warmup_steps']) / config['warmup_steps']
        )

        # Training parameters
        self.alpha = config['w2_coefficient']
        self.beta = config['temperature']
        self.grad_clip = config['grad_clip']
        self.use_wandb = config.get('use_wandb', False)
        self.parallel = config.get('parallel', False)
        self.savedir = config.get('savedir', './results')

        # Initialize wandb if needed
        if self.use_wandb:
            wandb.init(
                project=config['wandb_project'],
                name=config['run_name']
            )

    def load_pretrained(self, path: str):
        """Load a pretrained checkpoint into all three models."""
        checkpoint = torch.load(path, map_location=self.device)
        self.net_model.load_state_dict(checkpoint['net_model'])
        self.last_policy.load_state_dict(checkpoint['net_model'])
        self.ref_model.load_state_dict(checkpoint['net_model'])

        if 'optim' in checkpoint and 'sched' in checkpoint:
            self.optimizer.load_state_dict(checkpoint['optim'])
            self.scheduler.load_state_dict(checkpoint['sched'])

        # Set models to appropriate modes
        self.ref_model.eval()
        self.last_policy.eval()

    @torch.no_grad()
    def sample_batch(self, ep: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample a batch using the last policy and compute rewards."""
        self.last_policy.eval()

        # Generate samples using last policy
        _, samples = generate_samples(
            self.last_policy,
            self.parallel,
            self.savedir,
            ep,
            net_="last_policy",
            save_img=False,
            use_wandb=self.use_wandb,
            log_image_interval=50,
            return_x0=True
        )

        # Compute rewards using CLIP
        image_prob, _ = get_text_image_clip_score(
            image=samples,
            text=self.config['text_prompts'],
            return_logit=True
        )
        rewards = image_prob[:, 0] - image_prob[:, 1]  # you can use other rewards

        return samples, rewards

    def compute_loss(
        self,
        x1: torch.Tensor,
        weights: torch.Tensor
    ) -> Tuple[torch.Tensor, dict]:
        """Compute the ORW-CFM-W2 loss."""
        # Generate noise
        x0 = torch.randn_like(x1)

        # Get flow matching components
        t, xt, ut = self.fm.sample_location_and_conditional_flow(x0, x1)

        # Compute vector fields
        vt = self.net_model(t, xt)
        vt_ref = self.ref_model(t, xt).detach()

        # Compute losses
        fm_loss = ((vt - ut) ** 2).mean(dim=(1, 2, 3))
        w2_loss = ((vt - vt_ref) ** 2).mean(dim=(1, 2, 3))

        # Combine losses
        total_loss = torch.mean(weights * fm_loss + self.alpha * w2_loss)

        metrics = {
            'fm_loss': fm_loss.mean().item(),
            'w2_loss': w2_loss.mean().item(),
            'total_loss': total_loss.item()
        }

        return total_loss, metrics

    def training_step(self, ep: int) -> dict:
        """Execute a single training step."""
        # Sample using last policy
        samples, rewards = self.sample_batch(ep)

        # Compute reward weights w(x1) = exp(beta * r(x1))
        weights = torch.exp(self.beta * rewards).to(self.device)

        # Compute and optimize loss
        self.optimizer.zero_grad()
        loss, metrics = self.compute_loss(samples, weights)
        loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(
            self.net_model.parameters(),
            self.grad_clip
        )

        self.optimizer.step()
        self.scheduler.step()

        # Update metrics
        metrics.update({
            'reward': rewards.mean().item(),
            'max_reward': rewards.max().item()
        })

        return metrics

    def update_last_policy(self):
        """Update the last policy with the current model weights."""
        self.last_policy.load_state_dict(self.net_model.state_dict())
        self.last_policy.eval()
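
    # Note: last_policy is refreshed only once per epoch (see train below), so epoch N
    # draws its training samples from the epoch N-1 policy. This online, iterative
    # scheme is the one characterized in the README's "Theoretical Guarantees" section.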

    def train(self, num_epochs: int, steps_per_epoch: int):
        """Training loop"""
        for ep in range(num_epochs):
            running_metrics = []

            # Training steps
            pbar = tqdm(range(steps_per_epoch), desc=f'Epoch {ep}')
            for _ in pbar:
                metrics = self.training_step(ep)
                running_metrics.append(metrics)

                # Update progress bar
                pbar.set_postfix({
                    'loss': metrics['total_loss'],
                    'reward': metrics['reward']
                })

            # Update last policy
            self.update_last_policy()

            # Compute epoch metrics
            epoch_metrics = {
                k: sum(d[k] for d in running_metrics) / len(running_metrics)
                for k in running_metrics[0].keys()
            }

            # Log metrics
            if self.use_wandb:
                wandb.log(
                    {f'train/{k}': v for k, v in epoch_metrics.items()},
                    step=ep
                )

    def save_checkpoint(self, path: str):
        """Save training checkpoint"""
        checkpoint = {
            'net_model': self.net_model.state_dict(),
            'last_policy': self.last_policy.state_dict(),
            'optim': self.optimizer.state_dict(),
            'sched': self.scheduler.state_dict(),
            'config': self.config
        }
        torch.save(checkpoint, path)


def main():
    config = {
        'learning_rate': 2e-4,
        'warmup_steps': 5000,
        'w2_coefficient': 1.0,
        'temperature': 0.5,
        'grad_clip': 1.0,
        'batch_size': 128,
        'text_prompts': ["An image of dog", "Not an image of dog"],
        'use_wandb': True,
        'wandb_project': 'cifar10-flow-matching',
        'run_name': 'orw-cfm-w2',
        'parallel': False,
        'savedir': './results',
        'ref_path': './pretrained/fm_cifar_400000.pt'
    }

    # Initialize model
    model = UNetModelWrapper(
        dim=(3, 32, 32),
        num_res_blocks=2,
        num_channels=128,
        channel_mult=[1, 2, 2, 2],
        num_heads=4,
        num_head_channels=64,
        attention_resolutions="16",
        dropout=0.1
    )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    trainer = ORWCFMTrainer(model, config, device)
    trainer.load_pretrained(config['ref_path'])

    trainer.train(
        num_epochs=1000,
        steps_per_epoch=int(1e4)
    )


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------