├── .gitignore ├── params.json ├── src ├── decoder │ ├── configs │ │ └── decoder_config.yaml │ ├── decoder_dataset.py │ ├── models.py │ ├── utils │ │ └── evaluation.py │ └── train.py ├── encoder │ ├── collator.py │ ├── configs │ │ └── base_lang_config.yaml │ ├── models.py │ ├── utils │ │ ├── helper.py │ │ └── monitor.py │ └── train.py └── common │ ├── logging.py │ ├── schedulers.py │ ├── config.py │ └── datasets │ ├── fineweb_edu.py │ └── utils │ └── sentence_splitting.py ├── main_encoder.py ├── pyproject.toml ├── main_decoder.py └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | .idea/ 3 | __pycache__ 4 | logs/ 5 | .env 6 | wandb 7 | outputs/ 8 | monitor_logs/ 9 | -------------------------------------------------------------------------------- /params.json: -------------------------------------------------------------------------------- 1 | { 2 | "$schema": "https://params.com/params.json", 3 | "docs": { 4 | "main": "main_encoder.md", 5 | "sidebar": { 6 | "Introduction": "README.md" 7 | } 8 | } 9 | } -------------------------------------------------------------------------------- /src/decoder/configs/decoder_config.yaml: -------------------------------------------------------------------------------- 1 | decoder: 2 | hidden_dim: 1536 # 2x the base model embed_dim 3 | num_layers: 4 4 | num_heads: 8 5 | dropout: 0.1 6 | max_length: 128 7 | 8 | training: 9 | batch_size: 32 10 | learning_rate: 1e-4 11 | num_epochs: 10 12 | warmup_steps: 1000 13 | weight_decay: 0.01 14 | gradient_clip: 1.0 15 | 16 | evaluation: 17 | eval_steps: 100 18 | save_steps: 1000 19 | num_samples: 5 # Number of samples to show during evaluation -------------------------------------------------------------------------------- /main_encoder.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer 2 | 3 | from src.common.config import LANGJEPAConfig 4 | from src.encoder.train import train 5 | 6 | if __name__ == "__main__": 7 | # Load and validate config 8 | config = LANGJEPAConfig.from_yaml("src/encoder/configs/base_lang_config.yaml") 9 | 10 | # Initialize tokenizer 11 | tokenizer = AutoTokenizer.from_pretrained(config.data.tokenizer_path) 12 | # Add padding token if it doesn't exist 13 | if tokenizer.pad_token is None: 14 | # Use EOS token as padding token 15 | tokenizer.pad_token = tokenizer.eos_token 16 | # Or add a new [PAD] token: 17 | # tokenizer.add_special_tokens({'pad_token': '[PAD]'}) 18 | 19 | config.data.tokenizer = tokenizer 20 | 21 | # Train with validated config 22 | train(config) 23 | -------------------------------------------------------------------------------- /src/encoder/collator.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from torch import Tensor 4 | from transformers import PreTrainedTokenizer 5 | 6 | from src.common.datasets.fineweb_edu import DatasetOutput 7 | 8 | 9 | @dataclass 10 | class Batch: 11 | context_ids: Tensor # Tokenized context sequences [batch_size, seq_len] 12 | padding_masks: Tensor # Padding masks (1 for real tokens, 0 for padding) 13 | context_texts: list[str] # Original context text before tokenization 14 | target_texts: list[str] # Target sentences to predict 15 | 16 | 17 | class Collator: 18 | def __init__(self, tokenizer: PreTrainedTokenizer, max_length: int): 19 | self.tokenizer = tokenizer 20 | self.max_length = max_length 21 | 22 | def 
__call__(self, batch: list[DatasetOutput]) -> Batch: 23 | contexts = [item.context for item in batch] 24 | tokens = self.tokenizer( 25 | contexts, 26 | padding=True, 27 | truncation=True, 28 | max_length=self.max_length, 29 | return_tensors="pt", 30 | ) 31 | 32 | return Batch( 33 | context_ids=tokens["input_ids"], 34 | padding_masks=tokens["attention_mask"], 35 | context_texts=contexts, 36 | target_texts=[item.target for item in batch], 37 | ) 38 | -------------------------------------------------------------------------------- /src/encoder/configs/base_lang_config.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | # Dataset configuration 3 | train_file: "sample-10BT" # FineWeb-Edu dataset file 4 | batch_size: 32 5 | num_workers: 4 6 | tokenizer_path: "roberta-base" # Base tokenizer to use 7 | # tokenizer_path: "bert-base-uncased" 8 | # tokenizer_path: "gpt2" 9 | limit: 100 # Number of training samples to use 10 | min_length: 100 # Minimum text length to consider 11 | min_sentences: 2 # Minimum number of sentences (needed for context/target split) 12 | 13 | model: 14 | max_length: 128 # Should be smaller, around 128-256 15 | pred_dim: 384 16 | embed_dim: 768 # Standard size for base models 17 | num_layers: 12 18 | num_heads: 12 19 | mlp_ratio: 4.0 20 | dropout: 0.1 21 | 22 | optimization: 23 | # Training parameters 24 | epochs: 5 25 | lr: 0.001 # Peak learning rate 26 | warmup: 1 # Warmup epochs 27 | weight_decay: 0.04 # Initial weight decay 28 | final_weight_decay: 0.4 # Final weight decay 29 | final_lr: 0.000001 # Final learning rate after decay 30 | 31 | logging: 32 | # Logging configuration 33 | log_dir: "logs/lang_jepa" # Directory for saving logs and checkpoints 34 | log_freq: 50 # How often to log training metrics (iterations) 35 | checkpoint_freq: 1 # How often to save checkpoints (epochs) 36 | num_examples: 3 # Number of examples to monitor during training 37 | log_to_wandb: true # Whether to log to Weights & Biases 38 | 39 | meta: 40 | # Mixed precision and checkpointing 41 | use_bfloat16: false # Whether to use bfloat16 mixed precision 42 | load_checkpoint: false # Whether to load from checkpoint 43 | checkpoint_path: null # Path to checkpoint if loading 44 | use_gradient_checkpointing: false 45 | -------------------------------------------------------------------------------- /src/common/logging.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | class CSVLogger: 5 | def __init__(self, fname, *argv): 6 | """ 7 | I create a CSV logger that writes header columns defined by `argv`. 8 | 9 | Args: 10 | fname (str): Path to the CSV file. 11 | *argv: Tuples of the form (format_str, column_name). 12 | For example: ('%d', 'epoch'), ('%.5f', 'loss'). 13 | 14 | I open the file and write a header row from column names. 15 | """ 16 | self.fname = fname 17 | self.types = [] 18 | os.makedirs(os.path.dirname(fname), exist_ok=True) 19 | with open(self.fname, "w") as f: 20 | for i, v in enumerate(argv, 1): 21 | self.types.append(v[0]) 22 | if i < len(argv): 23 | print(v[1], end=",", file=f) 24 | else: 25 | print(v[1], end="\n", file=f) 26 | 27 | def log(self, *argv): 28 | """ 29 | I log a row into the CSV file. 30 | Each argument corresponds to the column formatting defined in __init__. 
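
        Example (illustrative, reusing the format strings shown in __init__):
            logger = CSVLogger("logs/train.csv", ("%d", "epoch"), ("%.5f", "loss"))
            logger.log(1, 0.42319)  # appends the row "1,0.42319"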
31 | """ 32 | with open(self.fname, "a") as f: 33 | for i, (t, val) in enumerate(zip(self.types, argv, strict=False), 1): 34 | end = "," if i < len(argv) else "\n" 35 | print(t % val, end=end, file=f) 36 | 37 | 38 | class AverageMeter: 39 | """ 40 | I keep track of an average value (e.g., loss) over time. 41 | I store current value, average, sum, count, max, and min. 42 | """ 43 | 44 | def __init__(self): 45 | self.reset() 46 | 47 | def reset(self): 48 | """ 49 | I reset all stored statistics. 50 | """ 51 | self.val = 0.0 52 | self.avg = 0.0 53 | self.max = float("-inf") 54 | self.min = float("inf") 55 | self.sum = 0.0 56 | self.count = 0 57 | 58 | def update(self, val, n=1): 59 | """ 60 | I update the meter with a new value. 61 | 62 | Args: 63 | val (float): new value to record 64 | n (int): how many instances this value represents (e.g., batch size) 65 | """ 66 | self.val = val 67 | self.sum += val * n 68 | self.count += n 69 | self.avg = self.sum / self.count 70 | self.max = max(self.max, val) 71 | self.min = min(self.min, val) 72 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "lang-jepa-v2" 3 | version = "0.1.0" 4 | description = "" 5 | authors = ["Jeremy Berman "] 6 | readme = "README.md" 7 | 8 | [tool.poetry.dependencies] 9 | python = "^3.10" 10 | ruff = "^0.8.3" 11 | wtpsplit = "^2.1.2" 12 | torch = "^2.5.1" 13 | devtools = "^0.12.2" 14 | datasets = "^3.2.0" 15 | wandb = "^0.19.1" 16 | python-dotenv = "^1.0.1" 17 | rich = "^13.9.4" 18 | nltk = "^3.9.1" 19 | rouge-score = "^0.1.2" 20 | pydantic = "^2.10.3" 21 | 22 | 23 | [build-system] 24 | requires = ["poetry-core"] 25 | build-backend = "poetry.core.masonry.api" 26 | 27 | 28 | [tool.mypy] 29 | plugins = 'pydantic.mypy' 30 | strict = true 31 | check_untyped_defs = true 32 | disallow_untyped_calls = true 33 | disallow_incomplete_defs = true 34 | disallow_untyped_defs = true 35 | disallow_untyped_decorators = false 36 | disallow_subclassing_any = false 37 | ignore_missing_imports = true 38 | follow_imports = 'skip' 39 | exclude = [ 40 | "venv", 41 | ".venv", 42 | "alembic", 43 | "app/dbs/sql_gen/gen/db.py" 44 | ] 45 | 46 | [tool.ruff] 47 | target-version = "py312" 48 | exclude = [ 49 | "alembic", 50 | "app/dbs/sql_gen/gen/db.py", 51 | ".bzr", 52 | ".direnv", 53 | ".eggs", 54 | ".git", 55 | ".git-rewrite", 56 | ".hg", 57 | ".ipynb_checkpoints", 58 | ".mypy_cache", 59 | ".nox", 60 | ".pants.d", 61 | ".pyenv", 62 | ".pytest_cache", 63 | ".pytype", 64 | ".ruff_cache", 65 | ".svn", 66 | ".tox", 67 | ".venv", 68 | ".vscode", 69 | "__pypackages__", 70 | "_build", 71 | "buck-out", 72 | "build", 73 | "dist", 74 | "node_modules", 75 | "site-packages", 76 | "venv", 77 | ] 78 | 79 | [tool.ruff.lint] 80 | unfixable = [ 81 | "F401", # unused imports 82 | ] 83 | select = [ 84 | "E", # pycodestyle errors 85 | "W", # pycodestyle warnings 86 | "F", # pyflakes 87 | "I", # isort 88 | "B", # flake8-bugbear 89 | "C4", # flake8-comprehensions 90 | "UP", # pyupgrade 91 | "ARG001", # unused arguments in functions 92 | ] 93 | ignore = [ 94 | "E501", # line too long, handled by black 95 | "B008", # do not perform function calls in argument defaults 96 | "W191", # indentation contains tabs 97 | "B904", # Allow raising exceptions without from e, for HTTPException 98 | ] 99 | 100 | [tool.ruff.lint.pyupgrade] 101 | # Preserve types, even if a file imports `from __future__ import annotations`. 
102 | keep-runtime-typing = true -------------------------------------------------------------------------------- /src/common/schedulers.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | 4 | class WarmupCosineSchedule: 5 | """ 6 | I implement a warmup + cosine decay learning rate scheduler. 7 | During the warmup period, I increase LR from start_lr to ref_lr. 8 | Then I apply a cosine decay from ref_lr down to final_lr. 9 | """ 10 | 11 | def __init__(self, optimizer, warmup_steps, start_lr, ref_lr, T_max, final_lr=0.0): 12 | self.optimizer = optimizer 13 | self.start_lr = start_lr 14 | self.ref_lr = ref_lr 15 | self.final_lr = final_lr 16 | self.warmup_steps = warmup_steps 17 | self.T_max = T_max - warmup_steps 18 | self._step = 0 19 | 20 | def step(self): 21 | """ 22 | I advance the scheduler by one step. 23 | Returns the new learning rate. 24 | """ 25 | self._step += 1 26 | if self._step < self.warmup_steps: 27 | progress = float(self._step) / float(max(1, self.warmup_steps)) 28 | new_lr = self.start_lr + progress * (self.ref_lr - self.start_lr) 29 | else: 30 | # after warmup 31 | progress = float(self._step - self.warmup_steps) / float(max(1, self.T_max)) 32 | new_lr = max( 33 | self.final_lr, 34 | self.final_lr 35 | + (self.ref_lr - self.final_lr) 36 | * 0.5 37 | * (1.0 + math.cos(math.pi * progress)), 38 | ) 39 | 40 | for group in self.optimizer.param_groups: 41 | group["lr"] = new_lr 42 | return new_lr 43 | 44 | 45 | class CosineWDSchedule: 46 | """ 47 | I implement a cosine decay for weight decay. 48 | """ 49 | 50 | def __init__(self, optimizer, ref_wd, T_max, final_wd=0.0): 51 | self.optimizer = optimizer 52 | self.ref_wd = ref_wd 53 | self.final_wd = final_wd 54 | self.T_max = T_max 55 | self._step = 0 56 | 57 | def step(self): 58 | """ 59 | I advance the weight decay schedule by one step. 60 | Returns the new weight decay. 
61 | """ 62 | self._step += 1 63 | progress = self._step / self.T_max 64 | new_wd = self.final_wd + (self.ref_wd - self.final_wd) * 0.5 * ( 65 | 1.0 + math.cos(math.pi * progress) 66 | ) 67 | 68 | # handle direction (in case final_wd > ref_wd or final_wd < ref_wd) 69 | if self.final_wd <= self.ref_wd: 70 | new_wd = max(self.final_wd, new_wd) 71 | else: 72 | new_wd = min(self.final_wd, new_wd) 73 | 74 | for group in self.optimizer.param_groups: 75 | if ("WD_exclude" not in group) or not group["WD_exclude"]: 76 | group["weight_decay"] = new_wd 77 | return new_wd 78 | -------------------------------------------------------------------------------- /main_decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoTokenizer 3 | 4 | from src.common.config import LANGJEPAConfig 5 | from src.decoder.decoder_dataset import ( 6 | create_eval_loader, 7 | create_train_loader, 8 | split_train_eval, 9 | ) 10 | from src.decoder.models import ConceptDecoder, DecoderConfig 11 | from src.decoder.train import DecoderTrainer, DecoderTrainingConfig 12 | from src.encoder.models import TextTransformer 13 | 14 | 15 | def main(): 16 | # Load base config 17 | config = LANGJEPAConfig.from_yaml("src/encoder/configs/base_lang_config.yaml") 18 | 19 | # Initialize tokenizer 20 | tokenizer = AutoTokenizer.from_pretrained(config.data.tokenizer_path) 21 | config.data.tokenizer = tokenizer 22 | 23 | # Load pretrained encoder 24 | encoder = TextTransformer(config) 25 | checkpoint = torch.load( 26 | "logs/lang_jepa_exp1/checkpoint-epoch5.pth", 27 | weights_only=True, # Fix for the warning 28 | ) 29 | encoder.load_state_dict(checkpoint["encoder"]) 30 | 31 | # Initialize decoder with proper config 32 | decoder_config = DecoderConfig.from_tokenizer( 33 | tokenizer=tokenizer, 34 | embed_dim=config.model.embed_dim, 35 | num_layers=4, 36 | num_heads=8, 37 | dropout=0.1, 38 | max_length=config.model.max_length, 39 | ) 40 | decoder = ConceptDecoder( 41 | config=decoder_config, 42 | tokenizer=tokenizer, # Pass tokenizer here 43 | ) 44 | 45 | # Load all texts 46 | from src.common.datasets.fineweb_edu import TextDataset 47 | 48 | dataset = TextDataset( 49 | train_file=config.data.train_file, 50 | limit=config.data.limit, 51 | min_length=config.data.min_length, 52 | ) 53 | 54 | # Split into train/eval 55 | train_texts, eval_texts = split_train_eval(dataset.samples, eval_ratio=0.1) 56 | 57 | # Create data loaders 58 | train_loader = create_train_loader(config, texts=train_texts) 59 | eval_loader = create_eval_loader(config, texts=eval_texts) 60 | 61 | # Training config 62 | training_config = DecoderTrainingConfig( 63 | batch_size=32, 64 | learning_rate=1e-4, 65 | num_epochs=10, 66 | warmup_steps=1000, 67 | grad_clip=1.0, 68 | weight_decay=0.01, 69 | eval_steps=100, 70 | save_steps=1000, 71 | output_dir="outputs/decoder", 72 | ) 73 | 74 | # Initialize trainer directly with tokenizer 75 | trainer = DecoderTrainer( 76 | config=training_config, 77 | encoder=encoder, 78 | decoder=decoder, 79 | train_loader=train_loader, 80 | eval_loader=eval_loader, 81 | ) 82 | 83 | # Train 84 | trainer.train() 85 | 86 | 87 | if __name__ == "__main__": 88 | main() 89 | -------------------------------------------------------------------------------- /src/encoder/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch import Tensor 5 | from transformers 
import AutoConfig, AutoModel 6 | 7 | from src.common.config import LANGJEPAConfig 8 | 9 | 10 | class TextTransformer(nn.Module): 11 | """Text encoder based on pre-trained transformer models.""" 12 | 13 | def __init__(self, config: LANGJEPAConfig): 14 | super().__init__() 15 | # Load base model config and update with our settings 16 | model_config = AutoConfig.from_pretrained(config.data.tokenizer_path) 17 | model_config.update( 18 | { 19 | "hidden_size": config.model.embed_dim, 20 | "num_hidden_layers": config.model.num_layers, 21 | "num_attention_heads": config.model.num_heads, 22 | "intermediate_size": int( 23 | config.model.embed_dim * config.model.mlp_ratio 24 | ), 25 | "hidden_dropout_prob": config.model.dropout, 26 | "attention_probs_dropout_prob": config.model.dropout, 27 | "vocab_size": len(config.data.tokenizer), 28 | "gradient_checkpointing": config.meta.use_gradient_checkpointing, 29 | } 30 | ) 31 | self.encoder = AutoModel.from_config(model_config) 32 | 33 | # After creating encoder, enable gradient checkpointing 34 | if ( 35 | hasattr(self.encoder, "gradient_checkpointing_enable") 36 | and config.meta.use_gradient_checkpointing 37 | ): 38 | self.encoder.gradient_checkpointing_enable() 39 | 40 | def forward( 41 | self, input_ids: Tensor, attention_mask: Tensor | None = None 42 | ) -> Tensor: 43 | """Get contextual embeddings for input tokens.""" 44 | outputs = self.encoder(input_ids, attention_mask, return_dict=True) 45 | return outputs.last_hidden_state 46 | 47 | 48 | class TextPredictor(nn.Module): 49 | """Predicts next sentence embeddings from context embeddings.""" 50 | 51 | def __init__(self, input_dim: int, pred_dim: int, num_heads: int = 8): 52 | super().__init__() 53 | # Attention to aggregate context sequence 54 | self.context_attention = nn.MultiheadAttention( 55 | embed_dim=input_dim, num_heads=num_heads, dropout=0.1, batch_first=True 56 | ) 57 | self.query = nn.Parameter(torch.randn(1, 1, input_dim) * 0.02) 58 | 59 | # Project to prediction space 60 | self.projection = nn.Sequential( 61 | nn.Linear(input_dim, pred_dim), nn.LayerNorm(pred_dim) 62 | ) 63 | 64 | def forward( 65 | self, context_feats: Tensor, attention_mask: Tensor | None = None 66 | ) -> Tensor: 67 | """Generate predictions from context.""" 68 | # Prepare attention inputs 69 | query = self.query.expand(context_feats.size(0), -1, -1) 70 | key_padding_mask = ( 71 | ~attention_mask.bool() if attention_mask is not None else None 72 | ) 73 | 74 | # Get context representation 75 | context, _ = self.context_attention( 76 | query=query, 77 | key=context_feats, 78 | value=context_feats, 79 | key_padding_mask=key_padding_mask, 80 | ) 81 | 82 | # Project and normalize 83 | predictions = self.projection(context.squeeze(1)) 84 | return F.normalize(predictions, p=2, dim=-1) 85 | 86 | def project_targets(self, features: Tensor) -> Tensor: 87 | """Project target features to prediction space.""" 88 | predictions = self.projection(features) 89 | return F.normalize(predictions, p=2, dim=-1) 90 | -------------------------------------------------------------------------------- /src/common/config.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, ConfigDict, Field 2 | from transformers import PreTrainedTokenizer 3 | 4 | 5 | class DataConfig(BaseModel): 6 | """Configuration for data loading and processing.""" 7 | 8 | train_file: str = Field(description="Dataset file to use for training") 9 | batch_size: int = Field(gt=0, description="Training batch size") 
10 | num_workers: int = Field(ge=0, description="Number of data loader workers") 11 | tokenizer_path: str = Field(description="Path or name of the pretrained tokenizer") 12 | limit: int = Field(gt=0, description="Limit on number of training samples") 13 | min_length: int = Field(gt=0, description="Minimum text length to consider") 14 | min_sentences: int = Field( 15 | gt=1, default=2, description="Minimum number of sentences required" 16 | ) 17 | tokenizer: PreTrainedTokenizer | None = Field( 18 | default=None, description="Loaded tokenizer instance" 19 | ) 20 | 21 | model_config = ConfigDict(arbitrary_types_allowed=True) 22 | 23 | 24 | class ModelConfig(BaseModel): 25 | """Configuration for model architecture.""" 26 | 27 | max_length: int = Field(gt=0, description="Maximum sequence length") 28 | pred_dim: int = Field(gt=0, description="Prediction dimension") 29 | embed_dim: int = Field(gt=0, description="Embedding dimension") 30 | num_layers: int = Field(gt=0, description="Number of transformer layers") 31 | num_heads: int = Field(gt=0, description="Number of attention heads") 32 | mlp_ratio: float = Field(gt=0.0, description="MLP hidden dimension ratio") 33 | dropout: float = Field(ge=0.0, lt=1.0, description="Dropout rate") 34 | 35 | 36 | class OptimizationConfig(BaseModel): 37 | """Configuration for training optimization.""" 38 | 39 | epochs: int = Field(gt=0, description="Number of training epochs") 40 | lr: float = Field(gt=0.0, description="Learning rate") 41 | warmup: int = Field(ge=0, description="Number of warmup epochs") 42 | weight_decay: float = Field(ge=0.0, description="Weight decay") 43 | final_weight_decay: float = Field(ge=0.0, description="Final weight decay") 44 | final_lr: float = Field(ge=0.0, description="Final learning rate") 45 | 46 | 47 | class LoggingConfig(BaseModel): 48 | """Configuration for logging and checkpoints.""" 49 | 50 | log_dir: str = Field(description="Directory for logs") 51 | log_freq: int = Field( 52 | default=50, gt=0, description="Logging frequency in iterations" 53 | ) 54 | checkpoint_freq: int = Field( 55 | default=1, gt=0, description="Checkpoint saving frequency in epochs" 56 | ) 57 | 58 | 59 | class MetaConfig(BaseModel): 60 | """Meta configuration for training.""" 61 | 62 | use_bfloat16: bool = Field( 63 | default=False, description="Whether to use bfloat16 precision" 64 | ) 65 | load_checkpoint: bool = Field( 66 | default=False, description="Whether to load from checkpoint" 67 | ) 68 | checkpoint_path: str | None = Field( 69 | default=None, description="Path to checkpoint file" 70 | ) 71 | use_gradient_checkpointing: bool = Field( 72 | description="Whether to use gradient checkpointing" 73 | ) 74 | 75 | 76 | class LANGJEPAConfig(BaseModel): 77 | """Main configuration class combining all sub-configs.""" 78 | 79 | data: DataConfig 80 | model: ModelConfig 81 | optimization: OptimizationConfig 82 | logging: LoggingConfig 83 | meta: MetaConfig 84 | 85 | @classmethod 86 | def from_yaml(cls, yaml_path: str) -> "LANGJEPAConfig": 87 | """Load config from YAML file.""" 88 | import yaml 89 | 90 | with open(yaml_path) as f: 91 | config_dict = yaml.safe_load(f) 92 | return cls(**config_dict) 93 | 94 | def to_yaml(self, yaml_path: str) -> None: 95 | """Save config to YAML file.""" 96 | import yaml 97 | 98 | config_dict = self.model_dump() 99 | # Remove tokenizer since it can't be serialized 100 | if "tokenizer" in config_dict["data"]: 101 | del config_dict["data"]["tokenizer"] 102 | with open(yaml_path, "w") as f: 103 | yaml.dump(config_dict, f) 104 | 
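
# Illustrative usage sketch (added): mirrors what main_encoder.py does with this
# config class: load the encoder YAML, attach a tokenizer, and round-trip the
# config. Paths follow the repo layout; nothing below is required by the module.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    cfg = LANGJEPAConfig.from_yaml("src/encoder/configs/base_lang_config.yaml")
    tokenizer = AutoTokenizer.from_pretrained(cfg.data.tokenizer_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token  # reuse EOS as padding, as in main_encoder.py
    cfg.data.tokenizer = tokenizer

    # to_yaml() drops the non-serializable tokenizer before dumping.
    cfg.to_yaml("config_snapshot.yaml")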
-------------------------------------------------------------------------------- /src/encoder/utils/helper.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import torch 5 | from torch.cuda.amp import GradScaler 6 | from torch.optim import AdamW 7 | 8 | from src.common.schedulers import CosineWDSchedule, WarmupCosineSchedule 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | def init_optimizer( 14 | encoder: torch.nn.Module, 15 | predictor: torch.nn.Module, 16 | lr: float, 17 | weight_decay: float, 18 | warmup: int, 19 | total_epochs: int, 20 | steps_per_epoch: int, 21 | final_wd: float = 0.0, 22 | final_lr: float = 0.0, 23 | use_bfloat16: bool = False, 24 | ): 25 | """ 26 | I initialize optimizer, schedulers, and scaler. 27 | """ 28 | param_groups = [ 29 | { 30 | "params": ( 31 | p 32 | for n, p in encoder.named_parameters() 33 | if ("bias" not in n) and (len(p.shape) != 1) 34 | ) 35 | }, 36 | { 37 | "params": ( 38 | p 39 | for n, p in predictor.named_parameters() 40 | if ("bias" not in n) and (len(p.shape) != 1) 41 | ) 42 | }, 43 | { 44 | "params": ( 45 | p 46 | for n, p in encoder.named_parameters() 47 | if ("bias" in n) or (len(p.shape) == 1) 48 | ), 49 | "WD_exclude": True, 50 | "weight_decay": 0.0, 51 | }, 52 | { 53 | "params": ( 54 | p 55 | for n, p in predictor.named_parameters() 56 | if ("bias" in n) or (len(p.shape) == 1) 57 | ), 58 | "WD_exclude": True, 59 | "weight_decay": 0.0, 60 | }, 61 | ] 62 | 63 | optimizer = AdamW(param_groups, lr=lr, weight_decay=weight_decay) 64 | total_steps = steps_per_epoch * total_epochs 65 | 66 | scheduler = WarmupCosineSchedule( 67 | optimizer=optimizer, 68 | warmup_steps=int(warmup * steps_per_epoch), 69 | start_lr=lr * 0.1, 70 | ref_lr=lr, 71 | final_lr=final_lr, 72 | T_max=total_steps, 73 | ) 74 | 75 | wd_scheduler = CosineWDSchedule( 76 | optimizer=optimizer, ref_wd=weight_decay, final_wd=final_wd, T_max=total_steps 77 | ) 78 | 79 | scaler = GradScaler() if use_bfloat16 and torch.cuda.is_available() else None 80 | 81 | return optimizer, scaler, scheduler, wd_scheduler 82 | 83 | 84 | def load_checkpoint( 85 | checkpoint_path: str, 86 | encoder: torch.nn.Module, 87 | predictor: torch.nn.Module, 88 | optimizer: torch.optim.Optimizer, 89 | scaler: GradScaler, 90 | device: torch.device, 91 | ) -> int: 92 | """ 93 | I load a checkpoint if available. 
94 | """ 95 | if checkpoint_path is None or not os.path.isfile(checkpoint_path): 96 | logger.info("No checkpoint found, starting from scratch.") 97 | return 0 98 | 99 | try: 100 | checkpoint = torch.load(checkpoint_path, map_location=device) 101 | encoder.load_state_dict(checkpoint["encoder"]) 102 | predictor.load_state_dict(checkpoint["predictor"]) 103 | optimizer.load_state_dict(checkpoint["opt"]) 104 | 105 | if ( 106 | scaler is not None 107 | and "scaler" in checkpoint 108 | and checkpoint["scaler"] is not None 109 | ): 110 | scaler.load_state_dict(checkpoint["scaler"]) 111 | 112 | start_epoch = checkpoint["epoch"] 113 | logger.info(f"Loaded checkpoint from {checkpoint_path} (epoch {start_epoch}).") 114 | return start_epoch 115 | except Exception as e: 116 | logger.error(f"Failed to load checkpoint from {checkpoint_path}: {e}") 117 | return 0 118 | 119 | 120 | def save_checkpoint( 121 | checkpoint_path: str, 122 | encoder: torch.nn.Module, 123 | predictor: torch.nn.Module, 124 | optimizer: torch.optim.Optimizer, 125 | scaler: GradScaler, 126 | epoch: int, 127 | loss: float, 128 | ): 129 | """ 130 | I save a checkpoint. 131 | """ 132 | state = { 133 | "encoder": encoder.state_dict(), 134 | "predictor": predictor.state_dict(), 135 | "opt": optimizer.state_dict(), 136 | "scaler": scaler.state_dict() if scaler is not None else None, 137 | "epoch": epoch, 138 | "loss": loss, 139 | } 140 | 141 | try: 142 | torch.save(state, checkpoint_path) 143 | logger.info(f"Checkpoint saved at {checkpoint_path}") 144 | except Exception as e: 145 | logger.error(f"Failed to save checkpoint at {checkpoint_path}: {e}") 146 | -------------------------------------------------------------------------------- /src/decoder/decoder_dataset.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | from torch.utils.data import DataLoader, Dataset 5 | from transformers import PreTrainedTokenizer 6 | 7 | from src.common.config import LANGJEPAConfig 8 | 9 | 10 | @dataclass 11 | class DecoderBatch: 12 | """Holds batched data for decoder training.""" 13 | 14 | input_ids: torch.Tensor 15 | attention_mask: torch.Tensor 16 | input_texts: list[str] # Original texts for evaluation 17 | 18 | 19 | class DecoderDataset(Dataset): 20 | """Dataset for training the concept decoder.""" 21 | 22 | def __init__( 23 | self, texts: list[str], tokenizer: PreTrainedTokenizer, max_length: int = 128 24 | ): 25 | self.texts = texts 26 | self.tokenizer = tokenizer 27 | self.max_length = max_length 28 | 29 | def __len__(self) -> int: 30 | return len(self.texts) 31 | 32 | def __getitem__(self, idx: int) -> dict[str, str]: 33 | return {"text": self.texts[idx]} 34 | 35 | def collate_fn(self, batch: list[dict[str, str]]) -> DecoderBatch: 36 | """Collate batch of texts into tensors.""" 37 | texts = [item["text"] for item in batch] 38 | 39 | # Tokenize 40 | encodings = self.tokenizer( 41 | texts, 42 | padding=True, 43 | truncation=True, 44 | max_length=self.max_length, 45 | return_tensors="pt", 46 | ) 47 | 48 | return DecoderBatch( 49 | input_ids=encodings["input_ids"], 50 | attention_mask=encodings["attention_mask"], 51 | input_texts=texts, 52 | ) 53 | 54 | 55 | def create_train_loader( 56 | config: LANGJEPAConfig, texts: list[str] | None = None 57 | ) -> DataLoader: 58 | """Create training data loader.""" 59 | # If texts not provided, load from config's dataset 60 | if texts is None: 61 | from src.common.datasets.fineweb_edu import TextDataset 62 | 63 | dataset = 
TextDataset( 64 | train_file=config.data.train_file, 65 | limit=config.data.limit, 66 | min_length=config.data.min_length, 67 | ) 68 | texts = dataset.samples 69 | 70 | # Create decoder dataset 71 | decoder_dataset = DecoderDataset( 72 | texts=texts, tokenizer=config.data.tokenizer, max_length=config.model.max_length 73 | ) 74 | 75 | # Create loader 76 | return DataLoader( 77 | decoder_dataset, 78 | batch_size=config.data.batch_size, 79 | shuffle=True, 80 | num_workers=config.data.num_workers, 81 | collate_fn=decoder_dataset.collate_fn, 82 | pin_memory=True, 83 | ) 84 | 85 | 86 | def create_eval_loader( 87 | config: "LANGJEPAConfig", texts: list[str] | None = None, eval_size: int = 1000 88 | ) -> DataLoader: 89 | """Create evaluation data loader.""" 90 | if texts is None: 91 | # Load validation split if available, otherwise use subset of training 92 | try: 93 | from src.common.datasets.fineweb_edu import TextDataset 94 | 95 | eval_dataset = TextDataset( 96 | train_file=config.data.train_file, 97 | limit=eval_size, 98 | min_length=config.data.min_length, 99 | split="validation", # Assuming this is added to TextDataset 100 | ) 101 | texts = eval_dataset.samples 102 | except: 103 | # If no validation split, use subset of training data 104 | from src.common.datasets.fineweb_edu import TextDataset 105 | 106 | dataset = TextDataset( 107 | train_file=config.data.train_file, 108 | limit=eval_size, 109 | min_length=config.data.min_length, 110 | ) 111 | texts = dataset.samples[:eval_size] 112 | 113 | # Create decoder dataset 114 | decoder_dataset = DecoderDataset( 115 | texts=texts, tokenizer=config.data.tokenizer, max_length=config.model.max_length 116 | ) 117 | 118 | # Create loader without shuffling for consistent evaluation 119 | return DataLoader( 120 | decoder_dataset, 121 | batch_size=config.data.batch_size, 122 | shuffle=False, 123 | num_workers=config.data.num_workers, 124 | collate_fn=decoder_dataset.collate_fn, 125 | pin_memory=True, 126 | ) 127 | 128 | 129 | # Utility function to split data for training and evaluation 130 | def split_train_eval( 131 | texts: list[str], eval_ratio: float = 0.1, shuffle: bool = True, seed: int = 42 132 | ) -> tuple[list[str], list[str]]: 133 | """Split texts into training and evaluation sets.""" 134 | if shuffle: 135 | import random 136 | 137 | random.seed(seed) 138 | texts = texts.copy() 139 | random.shuffle(texts) 140 | 141 | split_idx = int(len(texts) * (1 - eval_ratio)) 142 | return texts[:split_idx], texts[split_idx:] 143 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LANG-JEPA: Learning to Think in Latent Space 2 | 3 | LANG-JEPA is an experimental language model architecture that operates in "concept space" rather than "token space." Building on Meta AI's JEPA framework, it predicts semantic features of future text rather than raw tokens, focusing on conceptual understanding and semantic relationships. 
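
Concretely, the training signal is a similarity match between a predicted and an actual next-sentence embedding. The sketch below is schematic: module names follow `src/encoder/models.py`, but how the target sentence is pooled and whether its gradient is stopped are assumptions here, not a copy of `src/encoder/train.py`.

```python
import torch
import torch.nn.functional as F


def jepa_loss(encoder, predictor, ctx_ids, ctx_mask, tgt_ids, tgt_mask):
    """Schematic LANG-JEPA objective: predict the next sentence's embedding."""
    ctx_feats = encoder(ctx_ids, ctx_mask)    # [B, L, D] contextual token features
    pred = predictor(ctx_feats, ctx_mask)     # [B, pred_dim], L2-normalized

    with torch.no_grad():                     # assumed: no gradient through the target path
        tgt_feats = encoder(tgt_ids, tgt_mask).mean(dim=1)  # pooled sentence feature
        target = predictor.project_targets(tgt_feats)       # [B, pred_dim], L2-normalized

    return 1.0 - F.cosine_similarity(pred, target, dim=-1).mean()
```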
4 | 5 | Previous JEPA implementations include: 6 | - [I-JEPA](https://ai.meta.com/blog/yann-lecun-ai-model-i-jepa/) for images: Predicts feature representations of masked image regions 7 | - [V-JEPA](https://ai.meta.com/blog/v-jepa-yann-lecun-ai-model-video-joint-embedding-predictive-architecture/) for videos: Predicts future visual features without pixel reconstruction 8 | 9 | LANG-JEPA applies this approach to text, training models to predict feature-level representations of future text segments rather than specific tokens. The goal is to develop models that reason at a conceptual level, like humans. 10 | 11 | ## How It Works 12 | 13 | LANG-JEPA learns by predicting the semantic features of upcoming text. Given a sequence of text, it: 14 | 1. Encodes both the context and the next sentence into a semantic latent space 15 | 2. Learns to predict the latent representation of the next sentence from the context 16 | 3. Uses cosine similarity in the latent space as a training signal 17 | 18 | The system consists of two core components: 19 | 20 | ### 1. LANG-JEPA Encoder 21 | - A transformer-based model that transforms text into semantic embeddings 22 | - Projects input text into a high-dimensional latent space 23 | - Learns to capture semantic relationships between sentences 24 | 25 | ### 2. Concept Decoder 26 | - Converts learned feature embeddings back into human-readable text 27 | - Enables evaluation and interpretation of the model's semantic understanding 28 | - Trained separately after the encoder 29 | 30 | ## Architecture 31 | 32 | ### Encoder Architecture: 33 | - Text Encoder: Transforms input into semantic embeddings 34 | - Context Processing: Processes context sequence with self-attention 35 | - Feature Prediction: Uses attention to predict next sentence embeddings 36 | - Loss: Cosine similarity between predicted and actual next sentence embeddings 37 | 38 | ### Decoder Architecture: 39 | - Projects LANG-JEPA embeddings to decoder space 40 | - Generates text via transformer decoder 41 | - Trains with teacher forcing and cross-entropy loss 42 | - Evaluates using reconstruction metrics 43 | 44 | ## File Structure 45 | 46 | ``` 47 | ./ 48 | ├── src 49 | │ ├── common 50 | │ │ ├── datasets 51 | │ │ │ ├── utils 52 | │ │ │ │ └── sentence_splitting.py # Sentence splitting utilities 53 | │ │ │ └── fineweb_edu.py # FineWeb-Edu dataset wrapper 54 | │ │ ├── config.py # Configuration classes (Pydantic-based) 55 | │ │ ├── logging.py # Logging utilities (CSV logging, meters) 56 | │ │ └── schedulers.py # Learning rate and weight decay schedulers 57 | │ │ 58 | │ ├── decoder 59 | │ │ ├── configs 60 | │ │ │ └── decoder_config.yaml # YAML config for the decoder model 61 | │ │ ├── utils 62 | │ │ │ └── evaluation.py # Metrics (BLEU, ROUGE, etc.) 
63 | │ │ ├── decoder_dataset.py # Dataset utilities for decoder 64 | │ │ ├── models.py # Concept decoder model 65 | │ │ └── train.py # Decoder training loop 66 | │ │ 67 | │ └── encoder 68 | │ ├── configs 69 | │ │ └── base_lang_config.yaml # YAML config for the encoder model 70 | │ ├── utils 71 | │ │ ├── helper.py # Initialization utilities 72 | │ │ └── monitor.py # Training monitoring and logging 73 | │ ├── collator.py # Dataset collation for training 74 | │ ├── models.py # LANG-JEPA encoder and predictor 75 | │ └── train.py # Encoder training loop 76 | │ 77 | ├── main_decoder.py # Decoder training entry point 78 | ├── main_encoder.py # Encoder training entry point 79 | ├── pyproject.toml # Dependencies and configuration 80 | └── README.md # This readme 81 | ``` 82 | 83 | ## Configuration 84 | ### Encoder Configuration 85 | Defined in `src/encoder/configs/base_lang_config.yaml`. 86 | 87 | Controls: 88 | - Model architecture (layers, heads, dimensions) 89 | - Data loading and sequence length 90 | - Optimization parameters (learning rate, epochs, warmup) 91 | - Logging settings 92 | 93 | ### Decoder Configuration 94 | Defined in `src/decoder/configs/decoder_config.yaml`. 95 | 96 | Controls: 97 | - Decoder model architecture 98 | - Training hyperparameters 99 | - Evaluation settings 100 | 101 | ## Training Process 102 | 103 | 1. **LANG-JEPA Training:** 104 | ``` 105 | Text → Split into Context/Target → Encode → Predict Next Features → Update Model 106 | ``` 107 | 108 | 2. **Decoder Training:** 109 | ``` 110 | Concept → Project → Generate Text → Compare with Original → Update Decoder 111 | ``` 112 | 113 | 3. **Evaluation:** 114 | - Feature similarity in latent space 115 | - BLEU and ROUGE scores for generated text 116 | - Perplexity for language model quality 117 | - Semantic similarity metrics 118 | 119 | ## Getting Started 120 | 121 | 1. Install dependencies: 122 | ```bash 123 | poetry shell 124 | poetry install 125 | ``` 126 | 127 | 2. Train LANG-JEPA encoder: 128 | ```bash 129 | python main_encoder.py 130 | ``` 131 | 132 | 3. Train decoder (optional, for text generation): 133 | ```bash 134 | python main_decoder.py 135 | ``` 136 | 137 | ## Model Details 138 | 139 | ### Encoder Architecture 140 | - Built on top of any transformer model (RoBERTa, GPT2, etc.) 
141 | - Customized for semantic feature prediction 142 | - Outputs normalized embeddings in latent space 143 | 144 | ### Training Objectives 145 | - Primary: Next sentence feature prediction 146 | - Loss: Cosine similarity in normalized latent space 147 | - Regularization: Weight decay with cosine schedule 148 | 149 | ### Key Features 150 | - Works directly in semantic space 151 | - No token-level predictions 152 | - Focus on semantic relationships 153 | - Efficient training with cosine similarity -------------------------------------------------------------------------------- /src/decoder/models.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | import torch.nn as nn 5 | from transformers import PreTrainedTokenizer 6 | 7 | 8 | @dataclass 9 | class DecoderConfig: 10 | """Configuration for the concept decoder.""" 11 | 12 | embed_dim: int # Dimension of concept space 13 | hidden_dim: int # Internal dimension of decoder 14 | vocab_size: int # Set from tokenizer 15 | pad_token_id: int # Set from tokenizer 16 | bos_token_id: int # Set from tokenizer 17 | eos_token_id: int # Set from tokenizer 18 | num_layers: int = 4 19 | num_heads: int = 8 20 | dropout: float = 0.1 21 | max_length: int = 128 22 | 23 | @classmethod 24 | def from_tokenizer( 25 | cls, 26 | tokenizer: PreTrainedTokenizer, 27 | embed_dim: int, 28 | hidden_dim: int | None = None, 29 | **kwargs, 30 | ) -> "DecoderConfig": 31 | """Create config from tokenizer and embedding dimension.""" 32 | return cls( 33 | embed_dim=embed_dim, 34 | hidden_dim=hidden_dim or embed_dim * 2, 35 | vocab_size=len(tokenizer), 36 | pad_token_id=tokenizer.pad_token_id, 37 | bos_token_id=tokenizer.bos_token_id, 38 | eos_token_id=tokenizer.eos_token_id, 39 | **kwargs, # This will override defaults if provided 40 | ) 41 | 42 | 43 | class ConceptDecoder(nn.Module): 44 | """Decoder for converting concept embeddings back to text.""" 45 | 46 | def __init__(self, config: DecoderConfig, tokenizer: PreTrainedTokenizer): 47 | super().__init__() 48 | self.config = config 49 | self.tokenizer = tokenizer 50 | 51 | # Project from concept space to decoder dimension 52 | self.concept_proj = nn.Linear(config.embed_dim, config.hidden_dim) 53 | 54 | # Embeddings 55 | self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_dim) 56 | self.pos_embedding = nn.Parameter( 57 | torch.zeros(1, config.max_length, config.hidden_dim) 58 | ) 59 | 60 | # Decoder 61 | layer = nn.TransformerDecoderLayer( 62 | d_model=config.hidden_dim, 63 | nhead=config.num_heads, 64 | dim_feedforward=config.hidden_dim * 4, 65 | dropout=config.dropout, 66 | batch_first=True, 67 | ) 68 | self.decoder = nn.TransformerDecoder(layer, num_layers=config.num_layers) 69 | 70 | # Output projection 71 | self.out_proj = nn.Linear(config.hidden_dim, config.vocab_size) 72 | 73 | self._init_weights() 74 | 75 | def _init_weights(self) -> None: 76 | """Initialize weights with truncated normal distribution.""" 77 | nn.init.trunc_normal_(self.pos_embedding, std=0.02) 78 | 79 | def forward( 80 | self, 81 | concepts: torch.Tensor, # [batch_size, embed_dim] 82 | target_ids: torch.Tensor | None = None, # [batch_size, seq_len] 83 | ) -> torch.Tensor: 84 | batch_size = concepts.shape[0] 85 | device = concepts.device 86 | 87 | # Reshape concepts if needed 88 | if len(concepts.shape) > 2: 89 | concepts = concepts.reshape(batch_size, -1) # Flatten to [B, D] 90 | 91 | # Project concept to decoder space 92 | memory = 
self.concept_proj(concepts)  # [B, H]
 93 | 
 94 |         if self.training and target_ids is not None:
 95 |             # Get sequence length for teacher forcing (excluding last token)
 96 |             seq_length = target_ids.size(1) - 1
 97 | 
 98 |             # Expand memory to match sequence length
 99 |             memory = memory.unsqueeze(1).expand(-1, seq_length, -1)  # [B, L-1, H]
100 | 
101 |             # Teacher forcing (exclude last token from input)
102 |             tgt_emb = self.token_embedding(target_ids[:, :-1])  # [B, L-1, H]
103 |             tgt_emb = tgt_emb + self.pos_embedding[:, :seq_length]
104 | 
105 |             # Create causal mask
106 |             tgt_mask = nn.Transformer.generate_square_subsequent_mask(
107 |                 seq_length, device=device
108 |             )
109 | 
110 |             # Decode
111 |             hidden = self.decoder(tgt_emb, memory, tgt_mask=tgt_mask)
112 |             logits = self.out_proj(hidden)  # [B, L-1, V]
113 | 
114 |             return logits
115 | 
116 |         else:
117 |             # Autoregressive generation
118 |             curr_ids = torch.full(
119 |                 (batch_size, 1),
120 |                 self.config.bos_token_id,
121 |                 device=device,
122 |                 dtype=torch.long,
123 |             )
124 | 
125 |             logits = []
126 |             for _ in range(self.config.max_length - 1):
127 |                 # Embed current sequence
128 |                 tgt_emb = self.token_embedding(curr_ids)
129 |                 tgt_emb = tgt_emb + self.pos_embedding[:, : curr_ids.size(1)]
130 | 
131 |                 # Decode one step (memory is [B, H]: use the projected concept as a length-1 memory sequence)
132 |                 tgt_mask = nn.Transformer.generate_square_subsequent_mask(
133 |                     curr_ids.size(1), device=device
134 |                 )
135 | 
136 |                 hidden = self.decoder(
137 |                     tgt_emb, memory.unsqueeze(1), tgt_mask=tgt_mask
138 |                 )
139 |                 step_logits = self.out_proj(hidden[:, -1:])  # [B, 1, V]
140 |                 logits.append(step_logits)
141 | 
142 |                 # Greedily pick the next token
143 |                 next_token = step_logits.argmax(dim=-1)  # [B, 1]
144 |                 curr_ids = torch.cat([curr_ids, next_token], dim=1)
145 | 
146 |                 # Stop if we see end token
147 |                 if (next_token == self.config.eos_token_id).any():
148 |                     break
149 | 
150 |             logits = torch.cat(logits, dim=1)
151 | 
152 |         return logits
153 | 
154 |     @torch.no_grad()
155 |     def generate(
156 |         self,
157 |         concepts: torch.Tensor,
158 |         tokenizer: PreTrainedTokenizer,
159 |         max_length: int | None = None,
160 |     ) -> list[str]:
161 |         """
162 |         Generate text from concepts.
163 | 164 | Args: 165 | concepts: Concept embeddings to decode 166 | tokenizer: Tokenizer for decoding 167 | max_length: Optional override for maximum length 168 | 169 | Returns: 170 | List of generated strings 171 | """ 172 | self.eval() 173 | logits = self.forward(concepts) 174 | sequences = logits.argmax(dim=-1) 175 | 176 | # Decode to text 177 | texts = tokenizer.batch_decode(sequences, skip_special_tokens=True) 178 | return texts 179 | -------------------------------------------------------------------------------- /src/decoder/utils/evaluation.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from nltk.translate.bleu_score import corpus_bleu, sentence_bleu 6 | from rouge_score import rouge_scorer 7 | from torch import Tensor 8 | from transformers import PreTrainedTokenizer 9 | 10 | 11 | @dataclass 12 | class DecoderMetrics: 13 | """Holds evaluation metrics for concept decoder.""" 14 | 15 | bleu: float 16 | rouge: dict[str, float] 17 | perplexity: float 18 | concept_cosine_sim: float 19 | diversity: float 20 | 21 | 22 | class ConceptMetrics: 23 | """Evaluates concept decoder performance.""" 24 | 25 | def __init__( 26 | self, 27 | tokenizer: PreTrainedTokenizer, 28 | device: torch.device, 29 | ): 30 | self.tokenizer = tokenizer 31 | self.device = device 32 | self.rouge = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"]) 33 | 34 | @torch.no_grad() 35 | def compute_metrics( 36 | self, 37 | encoder: torch.nn.Module, 38 | decoder: torch.nn.Module, 39 | original_texts: list[str], 40 | generated_texts: list[str], 41 | concept_embeddings: Tensor | None = None, 42 | ) -> DecoderMetrics: 43 | """ 44 | Compute all metrics for generated text and concepts. 
45 | 46 | Args: 47 | encoder: LANG-JEPA encoder model 48 | decoder: Concept decoder model 49 | original_texts: Ground truth texts 50 | generated_texts: Generated texts from decoder 51 | concept_embeddings: Optional pre-computed concept embeddings 52 | 53 | Returns: 54 | DecoderMetrics containing all computed metrics 55 | """ 56 | # Compute BLEU score 57 | refs = [[t.split()] for t in original_texts] # Split into words 58 | hyps = [t.split() for t in generated_texts] 59 | bleu = corpus_bleu(refs, hyps) 60 | 61 | # Compute ROUGE scores 62 | rouge_scores = { 63 | name: sum( 64 | self.rouge.score(orig, gen)[name].fmeasure 65 | for orig, gen in zip(original_texts, generated_texts, strict=False) 66 | ) 67 | / len(original_texts) 68 | for name in ["rouge1", "rouge2", "rougeL"] 69 | } 70 | 71 | # Compute perplexity 72 | perplexity = self._compute_perplexity(decoder, original_texts) 73 | 74 | # Compute concept similarity if embeddings not provided 75 | if concept_embeddings is None: 76 | orig_concepts = encoder( 77 | self.tokenizer( 78 | original_texts, return_tensors="pt", padding=True, truncation=True 79 | ).to(self.device) 80 | ) 81 | else: 82 | orig_concepts = concept_embeddings 83 | 84 | gen_concepts = encoder( 85 | self.tokenizer( 86 | generated_texts, return_tensors="pt", padding=True, truncation=True 87 | ).to(self.device) 88 | ) 89 | 90 | concept_sim = ( 91 | F.cosine_similarity(orig_concepts.mean(dim=1), gen_concepts.mean(dim=1)) 92 | .mean() 93 | .item() 94 | ) 95 | 96 | # Compute diversity 97 | diversity = self._compute_diversity(generated_texts) 98 | 99 | return DecoderMetrics( 100 | bleu=bleu, 101 | rouge=rouge_scores, 102 | perplexity=perplexity, 103 | concept_cosine_sim=concept_sim, 104 | diversity=diversity, 105 | ) 106 | 107 | def _compute_perplexity( 108 | self, 109 | decoder: torch.nn.Module, 110 | texts: list[str], 111 | ) -> float: 112 | """Compute perplexity of generated texts.""" 113 | # Tokenize texts 114 | encodings = self.tokenizer( 115 | texts, return_tensors="pt", padding=True, truncation=True 116 | ).to(self.device) 117 | 118 | input_ids = encodings["input_ids"] 119 | 120 | # Forward pass through decoder 121 | with torch.no_grad(): 122 | outputs = decoder(input_ids=input_ids[:, :-1], labels=input_ids[:, 1:]) 123 | 124 | return torch.exp(outputs.loss).item() 125 | 126 | def _compute_diversity(self, texts: list[str]) -> float: 127 | """Compute lexical diversity of generated texts.""" 128 | if not texts: 129 | return 0.0 130 | 131 | # Split into words and get unique words 132 | all_words = [] 133 | for text in texts: 134 | all_words.extend(text.split()) 135 | 136 | if not all_words: 137 | return 0.0 138 | 139 | unique_words = set(all_words) 140 | return len(unique_words) / len(all_words) 141 | 142 | 143 | class SampleGenerator: 144 | """Generates and displays sample decoder outputs.""" 145 | 146 | def __init__( 147 | self, 148 | encoder: torch.nn.Module, 149 | decoder: torch.nn.Module, 150 | tokenizer: PreTrainedTokenizer, 151 | device: torch.device, 152 | ): 153 | self.encoder = encoder 154 | self.decoder = decoder 155 | self.tokenizer = tokenizer 156 | self.device = device 157 | 158 | @torch.no_grad() 159 | def generate_samples( 160 | self, texts: list[str], num_samples: int = 3 161 | ) -> list[dict[str, str]]: 162 | """Generate samples for visualization.""" 163 | samples = [] 164 | 165 | for text in texts[:num_samples]: 166 | # Encode to concept space 167 | inputs = self.tokenizer( 168 | text, return_tensors="pt", padding=True, truncation=True 169 | ).to(self.device) 
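            # Note (added comment): the encoder returns contextual token embeddings
            # of shape [1, seq_len, embed_dim]; ConceptDecoder.forward flattens any
            # input with more than two dimensions before projecting it.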
170 | 171 | concept = self.encoder(inputs["input_ids"]) 172 | 173 | # Generate from concept 174 | generated = self.decoder.generate(concept, self.tokenizer)[ 175 | 0 176 | ] # Take first (and only) generation 177 | 178 | samples.append( 179 | { 180 | "original": text, 181 | "generated": generated, 182 | "bleu": sentence_bleu([text.split()], generated.split()), 183 | } 184 | ) 185 | 186 | return samples 187 | 188 | 189 | def format_metrics(metrics: DecoderMetrics) -> str: 190 | """Format metrics for printing.""" 191 | return ( 192 | f"BLEU: {metrics.bleu:.4f}\n" 193 | f"ROUGE-1: {metrics.rouge['rouge1']:.4f}\n" 194 | f"ROUGE-2: {metrics.rouge['rouge2']:.4f}\n" 195 | f"ROUGE-L: {metrics.rouge['rougeL']:.4f}\n" 196 | f"Perplexity: {metrics.perplexity:.2f}\n" 197 | f"Concept Similarity: {metrics.concept_cosine_sim:.4f}\n" 198 | f"Diversity: {metrics.diversity:.4f}" 199 | ) 200 | -------------------------------------------------------------------------------- /src/common/datasets/fineweb_edu.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from datasets import load_dataset 4 | from torch.utils.data import Dataset 5 | from tqdm.auto import tqdm 6 | from transformers import PreTrainedTokenizer 7 | 8 | from src.common.datasets.utils.sentence_splitting import ( 9 | SentenceSplitter, 10 | SentenceSplitterConfig, 11 | ) 12 | 13 | 14 | @dataclass 15 | class Sentence: 16 | text: str 17 | start_idx: int 18 | end_idx: int 19 | 20 | 21 | @dataclass 22 | class DatasetOutput: 23 | context: str 24 | target: str 25 | 26 | 27 | class TextDataset(Dataset): 28 | def __init__( 29 | self, 30 | *, 31 | train_file: str, 32 | limit: int | None, 33 | min_length: int, 34 | window_size: int = 25, 35 | min_sentences: int = 2, 36 | tokenizer: PreTrainedTokenizer | None = None, 37 | max_tokens: int | None = None, 38 | cache_dir: str = "~/.cache/huggingface/datasets", 39 | ): 40 | """Enhanced dataset wrapper with precise sentence boundary handling. 
41 | 42 | Args: 43 | train_file: Which dataset file to load 44 | limit: Number of documents to process 45 | min_length: Minimum text length to consider 46 | window_size: Number of sentences to use as context (default: 25) 47 | min_sentences: Minimum sentences required (default: 2) 48 | tokenizer: Optional tokenizer for length checking 49 | max_tokens: Optional maximum tokens per context window 50 | cache_dir: HuggingFace cache directory 51 | """ 52 | self.samples: list[DatasetOutput] = [] 53 | self.stats = { 54 | "total_docs": 0, 55 | "docs_processed": 0, 56 | "docs_rejected_length": 0, 57 | "docs_rejected_sentences": 0, 58 | "context_target_pairs": 0, 59 | "pairs_rejected_length": 0, 60 | } 61 | 62 | # Load dataset 63 | print(f"Loading dataset with {window_size}-sentence sliding window...") 64 | ds = load_dataset( 65 | path="HuggingFaceFW/fineweb-edu", 66 | name=train_file, 67 | split="train", 68 | streaming=True, 69 | cache_dir=cache_dir, 70 | ) 71 | 72 | # Initialize sentence splitter 73 | splitter = SentenceSplitter(SentenceSplitterConfig()) 74 | 75 | # Process documents 76 | pbar = tqdm(total=limit, desc="Processing documents", unit="docs") 77 | 78 | for doc in ds: 79 | self.stats["total_docs"] += 1 80 | text = doc.get("text", "").strip() 81 | 82 | # Check minimum length 83 | if len(text) < min_length: 84 | self.stats["docs_rejected_length"] += 1 85 | continue 86 | 87 | try: 88 | # Split into sentences 89 | sentences = splitter([text])[0] 90 | if len(sentences) < min_sentences: 91 | self.stats["docs_rejected_sentences"] += 1 92 | continue 93 | 94 | # Find sentence boundaries in original text 95 | sentence_objs: list[Sentence] = [] 96 | search_start = 0 97 | 98 | for sent in sentences: 99 | # Find the sentence in the original text 100 | start_idx = text.index(sent, search_start) 101 | end_idx = start_idx + len(sent) 102 | 103 | sentence_objs.append( 104 | Sentence(text=sent, start_idx=start_idx, end_idx=end_idx) 105 | ) 106 | search_start = end_idx 107 | 108 | # Create context-target pairs with sliding window 109 | for i in range(1, len(sentence_objs)): 110 | # Get previous sentences as context (up to window_size) 111 | start_sent_idx = max(0, i - window_size) 112 | 113 | # Get exact text slice from original document 114 | context_start = sentence_objs[start_sent_idx].start_idx 115 | context_end = sentence_objs[i - 1].end_idx 116 | context = text[context_start:context_end] 117 | 118 | # Get target sentence with exact boundaries 119 | target = text[sentence_objs[i].start_idx : sentence_objs[i].end_idx] 120 | 121 | # Check token length if tokenizer provided 122 | if tokenizer and max_tokens: 123 | context_tokens = len(tokenizer.encode(context)) 124 | if context_tokens > max_tokens: 125 | self.stats["pairs_rejected_length"] += 1 126 | continue 127 | 128 | self.samples.append( 129 | DatasetOutput( 130 | context=context, 131 | target=target, 132 | ) 133 | ) 134 | self.stats["context_target_pairs"] += 1 135 | 136 | self.stats["docs_processed"] += 1 137 | pbar.update(1) 138 | 139 | if limit and self.stats["docs_processed"] >= limit: 140 | break 141 | 142 | except Exception as e: 143 | print(f"Error processing document: {e}") 144 | continue 145 | 146 | pbar.close() 147 | 148 | # Print statistics 149 | print("\nDataset Processing Statistics:") 150 | print(f"Total documents seen: {self.stats['total_docs']:,}") 151 | print(f"Documents processed: {self.stats['docs_processed']:,}") 152 | print(f"Documents rejected (length): {self.stats['docs_rejected_length']:,}") 153 | print( 154 | f"Documents 
rejected (sentences): {self.stats['docs_rejected_sentences']:,}" 155 | ) 156 | print(f"Context-target pairs generated: {self.stats['context_target_pairs']:,}") 157 | print(f"Pairs rejected (length): {self.stats['pairs_rejected_length']:,}") 158 | 159 | if not self.samples: 160 | raise RuntimeError( 161 | f"No valid samples found in dataset ({train_file}). " 162 | f"Try adjusting the minimum length ({min_length}) or " 163 | f"minimum sentences ({min_sentences}) requirements." 164 | ) 165 | 166 | def __len__(self) -> int: 167 | return len(self.samples) 168 | 169 | def __getitem__(self, idx: int) -> DatasetOutput: 170 | return self.samples[idx] 171 | 172 | 173 | def worker_init_fn(worker_id: int) -> None: 174 | """Initialize any worker-specific resources.""" 175 | # No need for worker-specific initialization anymore since we process 176 | # everything in __init__ 177 | pass 178 | -------------------------------------------------------------------------------- /src/common/datasets/utils/sentence_splitting.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass, field 3 | 4 | import torch 5 | from devtools import debug 6 | from wtpsplit import SaT 7 | 8 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 9 | 10 | 11 | @dataclass 12 | class SentenceSplitterConfig: 13 | model_name: str = "sat-3l-sm" 14 | sentence_suffix: str = "_sentences" 15 | sentence_threshold: float = 0.01 16 | max_sentence_len: int = 256 17 | min_text_length: int = 10 18 | min_unique_chars: int = 0 19 | fallback_separators: list[str] = field( 20 | default_factory=lambda: [ 21 | "...", 22 | "\n", 23 | "!", 24 | "?", 25 | ";", 26 | ":", 27 | ".", 28 | ",", 29 | "\t", 30 | " ", 31 | ] 32 | ) 33 | device: str = "cuda" 34 | remove_whitespace_before_inference: bool = False 35 | batch_size: int = 256 36 | block_size: int = 256 37 | stride: int = 256 38 | outer_batch_size: int = 1024 39 | verbose: bool = False 40 | pad_last_batch: bool = False 41 | 42 | 43 | class SentenceSplitter: 44 | def __init__(self, config: SentenceSplitterConfig): 45 | self.config = config 46 | device = torch.device(config.device if torch.cuda.is_available() else "cpu") 47 | 48 | try: 49 | self.model = SaT( 50 | self.config.model_name, 51 | from_pretrained_kwargs={"local_files_only": True}, 52 | ) 53 | except Exception: 54 | self.model = SaT(self.config.model_name) 55 | 56 | if "cuda" in config.device and torch.cuda.is_available(): 57 | self.model.half() 58 | 59 | self.model.eval().to(device) 60 | self.device = device 61 | 62 | @torch.inference_mode() 63 | def _resplit_long_sentence(self, sentence: str) -> list[str]: 64 | # If a single sentence is too long, split further by fallback separators 65 | # until no segment exceeds max_sentence_len or fallback is exhausted. 66 | segments = [sentence] 67 | for sep in self.config.fallback_separators: 68 | new_segments = [] 69 | for seg in segments: 70 | if len(seg) > self.config.max_sentence_len: 71 | # Split using the current separator 72 | parts = [s.strip() for s in seg.split(sep) if s.strip()] 73 | # If splitting didn't help (e.g., no sep found), just do a brute force word split. 
74 | if len(parts) == 1 and len(parts[0]) == len(seg): 75 | parts = self._brute_force_split(seg) 76 | new_segments.extend(parts) 77 | else: 78 | new_segments.append(seg) 79 | segments = new_segments 80 | 81 | # Finally, ensure no segment is still too long 82 | final_segments = [] 83 | for seg in segments: 84 | if len(seg) > self.config.max_sentence_len: 85 | final_segments.extend(self._brute_force_split(seg)) 86 | else: 87 | final_segments.append(seg) 88 | 89 | return final_segments 90 | 91 | def _brute_force_split(self, text: str) -> list[str]: 92 | # Split by words if the text is still too large, ignoring separators 93 | words = text.split() 94 | chunks = [] 95 | current_chunk = [] 96 | 97 | for w in words: 98 | # +1 for space 99 | if ( 100 | sum(len(x) + 1 for x in current_chunk) + len(w) 101 | > self.config.max_sentence_len 102 | ): 103 | if current_chunk: 104 | chunks.append(" ".join(current_chunk)) 105 | current_chunk = [w] 106 | else: 107 | current_chunk.append(w) 108 | 109 | if current_chunk: 110 | chunks.append(" ".join(current_chunk)) 111 | 112 | return [c.strip() for c in chunks if c.strip()] 113 | 114 | def _filter_by_unique_chars(self, sentences: list[str]) -> list[str]: 115 | if self.config.min_unique_chars <= 0: 116 | return sentences 117 | 118 | def unique_chars_count(s: str) -> int: 119 | return len(set(s)) 120 | 121 | return [ 122 | s for s in sentences if unique_chars_count(s) > self.config.min_unique_chars 123 | ] 124 | 125 | @torch.inference_mode() 126 | def __call__(self, texts: list[str]) -> list[list[str]]: 127 | # If single string, convert to list 128 | if isinstance(texts, str): 129 | texts = [texts] 130 | 131 | # Split texts using the model 132 | # Filter out too-short texts directly 133 | long_texts = [ 134 | (i, t) for i, t in enumerate(texts) if len(t) > self.config.min_text_length 135 | ] 136 | short_texts = [ 137 | (i, t) for i, t in enumerate(texts) if len(t) <= self.config.min_text_length 138 | ] 139 | 140 | # Extract the actual text for model inference 141 | long_text_strings = [t for _, t in long_texts] 142 | 143 | # Run the model 144 | outputs = self.model.split( 145 | long_text_strings, 146 | threshold=self.config.sentence_threshold, 147 | stride=self.config.stride, 148 | block_size=self.config.block_size, 149 | batch_size=self.config.batch_size, 150 | pad_last_batch=self.config.pad_last_batch, 151 | remove_whitespace_before_inference=self.config.remove_whitespace_before_inference, 152 | outer_batch_size=self.config.outer_batch_size, 153 | verbose=self.config.verbose, 154 | ) 155 | 156 | # Now we have a list of sentence lists for each long text 157 | # Post-process each list: 158 | final_results = [None] * len(texts) 159 | 160 | # Insert short texts (they don't need splitting) 161 | for i, t in short_texts: 162 | final_results[i] = [t.strip()] if t.strip() else [] 163 | 164 | # Process the long texts 165 | for (i, _), sentence_list in zip(long_texts, outputs, strict=False): 166 | # Strip sentences 167 | sentence_list = [s.strip() for s in sentence_list if s.strip()] 168 | 169 | # Resplit any long sentences 170 | resplit_sentences = [] 171 | for sent in sentence_list: 172 | if len(sent) > self.config.max_sentence_len: 173 | resplit_sentences.extend(self._resplit_long_sentence(sent)) 174 | else: 175 | resplit_sentences.append(sent) 176 | 177 | # Filter by unique chars if needed 178 | resplit_sentences = self._filter_by_unique_chars(resplit_sentences) 179 | 180 | final_results[i] = resplit_sentences 181 | 182 | return final_results 183 | 184 | 185 | if 
__name__ == "__main__": 186 | import time 187 | 188 | config = SentenceSplitterConfig() 189 | splitter = SentenceSplitter(config) 190 | sample_texts = [ 191 | "This is a test. It's a simple test, isn't it? Yes, it is!", 192 | "Short", 193 | "A very long sentence that definitely exceeds the maximum length of the sentence and should be split into multiple chunks by the splitter because it is too long to remain one single sentence yes very long isn't it wowowww this is super long ahuahiuhahaiuahu.", 194 | ] 195 | times = [] 196 | for _ in range(10): 197 | start = time.time() 198 | result = splitter(sample_texts) 199 | debug(result) 200 | times.append(time.time() - start) 201 | print(f"Average time: {sum(times) / len(times)}, {times}") 202 | -------------------------------------------------------------------------------- /src/encoder/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from pathlib import Path 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from dotenv import load_dotenv 8 | from torch.utils.data import DataLoader 9 | 10 | import wandb 11 | from src.common.config import LANGJEPAConfig 12 | from src.common.datasets.fineweb_edu import TextDataset, worker_init_fn 13 | from src.common.logging import AverageMeter, CSVLogger 14 | from src.encoder.collator import Batch, Collator 15 | from src.encoder.models import TextPredictor, TextTransformer 16 | from src.encoder.utils.helper import init_optimizer, load_checkpoint, save_checkpoint 17 | from src.encoder.utils.monitor import TrainingMonitor 18 | 19 | 20 | def train(config: LANGJEPAConfig) -> None: 21 | """Main training function for LANG-JEPA next-sentence prediction.""" 22 | 23 | # Initialize wandb 24 | load_dotenv() 25 | wandb.login(key=os.environ["WANDB_API_KEY"]) 26 | wandb.init( 27 | project="lang-jepa", 28 | config=config.model_dump(), 29 | name=f"run_{time.strftime('%Y%m%d_%H%M%S')}", 30 | ) 31 | 32 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 33 | 34 | # Setup logging 35 | os.makedirs(config.logging.log_dir, exist_ok=True) 36 | log_file = os.path.join(config.logging.log_dir, "training.csv") 37 | csv_logger = CSVLogger( 38 | log_file, 39 | ("%d", "epoch"), 40 | ("%d", "itr"), 41 | ("%.5f", "loss"), 42 | ("%.6f", "lr"), 43 | ("%.2f", "time(ms)"), 44 | ) 45 | 46 | # Initialize dataset and dataloader 47 | dataset = TextDataset( 48 | train_file=config.data.train_file, 49 | limit=config.data.limit, 50 | min_length=config.data.min_length, 51 | min_sentences=config.data.min_sentences, 52 | ) 53 | 54 | collator = Collator( 55 | tokenizer=config.data.tokenizer, max_length=config.model.max_length 56 | ) 57 | 58 | dataloader = DataLoader( 59 | dataset=dataset, 60 | batch_size=config.data.batch_size, 61 | shuffle=True, 62 | num_workers=config.data.num_workers, 63 | pin_memory=True, 64 | collate_fn=collator, 65 | worker_init_fn=worker_init_fn, 66 | ) 67 | 68 | # Initialize models 69 | encoder = TextTransformer(config=config).to(device) 70 | predictor = TextPredictor( 71 | input_dim=config.model.embed_dim, 72 | pred_dim=config.model.pred_dim, 73 | ).to(device) 74 | 75 | # Initialize optimizer and schedulers 76 | optimizer, scaler, scheduler, wd_scheduler = init_optimizer( 77 | encoder=encoder, 78 | predictor=predictor, 79 | lr=config.optimization.lr, 80 | weight_decay=config.optimization.weight_decay, 81 | warmup=config.optimization.warmup, 82 | total_epochs=config.optimization.epochs, 83 | steps_per_epoch=len(dataloader), 84 | 
use_bfloat16=config.meta.use_bfloat16, 85 | ) 86 | 87 | # Initialize training monitor 88 | monitor = TrainingMonitor( 89 | tokenizer=config.data.tokenizer, 90 | log_dir=Path(config.logging.log_dir), 91 | log_to_wandb=True, 92 | ) 93 | 94 | # Load checkpoint if specified 95 | start_epoch = 0 96 | if config.meta.load_checkpoint: 97 | start_epoch = load_checkpoint( 98 | checkpoint_path=config.meta.checkpoint_path, 99 | encoder=encoder, 100 | predictor=predictor, 101 | optimizer=optimizer, 102 | scaler=scaler, 103 | device=device, 104 | ) 105 | 106 | # Training loop 107 | loss_meter = AverageMeter() 108 | encoder.train() 109 | predictor.train() 110 | 111 | for epoch in range(start_epoch, config.optimization.epochs): 112 | epoch_start = time.time() 113 | loss_meter.reset() 114 | 115 | for itr, batch in enumerate(dataloader): 116 | batch: Batch 117 | # Move batch to device 118 | context_ids = batch.context_ids.to(device) 119 | context_mask = batch.padding_masks.to(device) 120 | 121 | # Process target sentences 122 | target_tokens = config.data.tokenizer( 123 | batch.target_texts, 124 | padding=True, 125 | truncation=True, 126 | max_length=config.model.max_length, 127 | return_tensors="pt", 128 | ).to(device) 129 | 130 | # Get embeddings for both context and target 131 | with torch.cuda.amp.autocast(enabled=config.meta.use_bfloat16): 132 | # Get target embeddings 133 | with torch.no_grad(): 134 | target_features = encoder( 135 | target_tokens.input_ids, 136 | target_tokens.attention_mask, 137 | ) 138 | # Average token embeddings for sentence representation 139 | target_features = target_features.mean(dim=1) 140 | target_features = predictor.project_targets(target_features) 141 | target_features = F.normalize(target_features, p=2, dim=-1) 142 | 143 | # Get context embeddings and predict 144 | context_features = encoder(context_ids, context_mask) 145 | predicted_features = predictor(context_features, context_mask) 146 | predicted_features = F.normalize(predicted_features, p=2, dim=-1) 147 | 148 | # Compute loss using cosine similarity 149 | loss = ( 150 | 1 - F.cosine_similarity(predicted_features, target_features).mean() 151 | ) 152 | 153 | # Optimize 154 | optimizer.zero_grad() 155 | if scaler is not None: 156 | scaler.scale(loss).backward() 157 | scaler.step(optimizer) 158 | scaler.update() 159 | else: 160 | loss.backward() 161 | optimizer.step() 162 | 163 | # Update schedulers 164 | lr = scheduler.step() 165 | wd_scheduler.step() 166 | 167 | # Logging 168 | loss_val = loss.item() 169 | loss_meter.update(loss_val) 170 | 171 | if itr % config.logging.log_freq == 0: 172 | elapsed = (time.time() - epoch_start) * 1000.0 173 | csv_logger.log(epoch + 1, itr, loss_val, lr, elapsed) 174 | print( 175 | f"[Epoch {epoch+1}/{config.optimization.epochs}, Itr {itr}] " 176 | f"loss: {loss_meter.avg:.4f}, lr: {lr:.2e}" 177 | ) 178 | 179 | # Log metrics 180 | wandb.log( 181 | { 182 | "train/loss": loss_val, 183 | "train/learning_rate": lr, 184 | "train/iteration": itr + epoch * len(dataloader), 185 | "stats/target_features_norm": target_features.norm(dim=1) 186 | .mean() 187 | .item(), 188 | "stats/predicted_features_norm": predicted_features.norm(dim=1) 189 | .mean() 190 | .item(), 191 | "stats/cosine_similarity": F.cosine_similarity( 192 | predicted_features, target_features 193 | ) 194 | .mean() 195 | .item(), 196 | } 197 | ) 198 | 199 | # Monitor training examples 200 | monitor.log_training_examples( 201 | epoch=epoch, 202 | batch_texts=batch.context_texts, 203 | target_texts=batch.target_texts, 204 | 
predicted_features=predicted_features.detach(), 205 | target_features=target_features.detach(), 206 | encoder=encoder, 207 | predictor=predictor, 208 | ) 209 | monitor.log_validation_metrics( 210 | epoch=epoch, 211 | pred_embeddings=predicted_features, 212 | target_embeddings=target_features, 213 | ) 214 | 215 | # End of epoch 216 | if (epoch + 1) % config.logging.checkpoint_freq == 0: 217 | ckpt_path = os.path.join( 218 | config.logging.log_dir, f"checkpoint-epoch{epoch+1}.pth" 219 | ) 220 | save_checkpoint( 221 | ckpt_path, 222 | encoder, 223 | predictor, 224 | optimizer, 225 | scaler, 226 | epoch + 1, 227 | loss_meter.avg, 228 | ) 229 | 230 | wandb.log( 231 | { 232 | "epoch/loss": loss_meter.avg, 233 | "epoch/time": time.time() - epoch_start, 234 | "epoch/number": epoch + 1, 235 | } 236 | ) 237 | 238 | print("Training completed successfully.") 239 | wandb.finish() 240 | -------------------------------------------------------------------------------- /src/decoder/train.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from pathlib import Path 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch.optim import AdamW 7 | from torch.utils.data import DataLoader 8 | from tqdm import tqdm 9 | 10 | from src.common.logging import AverageMeter 11 | from src.decoder.decoder_dataset import DecoderBatch 12 | from src.decoder.models import ConceptDecoder 13 | from src.decoder.utils.evaluation import ConceptMetrics, SampleGenerator 14 | from src.encoder.models import TextTransformer 15 | 16 | 17 | @dataclass 18 | class DecoderTrainingConfig: 19 | """Configuration for decoder training.""" 20 | 21 | batch_size: int 22 | learning_rate: float 23 | num_epochs: int 24 | warmup_steps: int 25 | grad_clip: float 26 | weight_decay: float 27 | eval_steps: int 28 | save_steps: int 29 | output_dir: str 30 | 31 | 32 | class DecoderTrainer: 33 | def __init__( 34 | self, 35 | config: DecoderTrainingConfig, 36 | encoder: TextTransformer, 37 | decoder: ConceptDecoder, 38 | train_loader: DataLoader, 39 | eval_loader: DataLoader | None = None, 40 | device: torch.device | None = None, 41 | ): 42 | self.config = config 43 | self.encoder = encoder 44 | self.decoder = decoder 45 | self.train_loader = train_loader 46 | self.eval_loader = eval_loader 47 | self.device = device or torch.device( 48 | "cuda" if torch.cuda.is_available() else "cpu" 49 | ) 50 | 51 | # Move models to device 52 | self.encoder = self.encoder.to(self.device) 53 | self.decoder = self.decoder.to(self.device) 54 | 55 | # Freeze encoder 56 | for param in self.encoder.parameters(): 57 | param.requires_grad = False 58 | self.encoder.eval() 59 | 60 | # Setup optimizer 61 | self.optimizer = AdamW( 62 | self.decoder.parameters(), 63 | lr=config.learning_rate, 64 | weight_decay=config.weight_decay, 65 | ) 66 | 67 | # Setup metrics 68 | self.metrics = ConceptMetrics(self.decoder.tokenizer, self.device) 69 | self.sample_generator = SampleGenerator( 70 | self.encoder, self.decoder, self.decoder.tokenizer, self.device 71 | ) 72 | 73 | # Create output directory 74 | self.output_dir = Path(config.output_dir) 75 | self.output_dir.mkdir(parents=True, exist_ok=True) 76 | 77 | def _process_batch(self, batch: DecoderBatch) -> tuple[torch.Tensor, torch.Tensor]: 78 | """Process a batch of data.""" 79 | input_ids = batch.input_ids.to(self.device) 80 | attention_mask = batch.attention_mask.to(self.device) 81 | return input_ids, attention_mask 82 | 83 | def train(self) -> None: 84 | 
"""Main training loop.""" 85 | best_loss = float("inf") 86 | global_step = 0 87 | loss_meter = AverageMeter() 88 | 89 | for epoch in range(self.config.num_epochs): 90 | print(f"\nEpoch {epoch + 1}/{self.config.num_epochs}") 91 | self.decoder.train() 92 | 93 | for batch in tqdm(self.train_loader, desc="Training"): 94 | # Process batch 95 | input_ids, attention_mask = self._process_batch(batch) 96 | 97 | # Get concept embeddings from encoder 98 | with torch.no_grad(): 99 | concepts = self.encoder(input_ids, attention_mask) 100 | if len(concepts.shape) > 2: 101 | concepts = concepts.mean(dim=1) # Average sequence dimension 102 | 103 | # Forward pass through decoder 104 | # input_ids[:, :-1] as input, input_ids[:, 1:] as targets 105 | logits = self.decoder(concepts, target_ids=input_ids) # [B, L-1, V] 106 | 107 | # Prepare targets (excluding first token) 108 | targets = input_ids[:, 1:].reshape(-1) # [B*(L-1)] 109 | 110 | # Reshape logits to match target shape 111 | logits = logits.reshape(-1, logits.size(-1)) # [B*(L-1), V] 112 | 113 | # Verify shapes match 114 | assert logits.size(0) == targets.size(0), ( 115 | f"Shape mismatch: logits {logits.shape}, targets {targets.shape}. " 116 | f"Batch size: {input_ids.size(0)}, Sequence length: {input_ids.size(1)}" 117 | ) 118 | 119 | # Compute loss 120 | loss = F.cross_entropy( 121 | logits, 122 | targets, 123 | ignore_index=self.decoder.config.pad_token_id, 124 | ) 125 | 126 | # Update meter 127 | loss_meter.update(loss.item()) 128 | 129 | # Backward pass 130 | self.optimizer.zero_grad() 131 | loss.backward() 132 | if self.config.grad_clip > 0: 133 | torch.nn.utils.clip_grad_norm_( 134 | self.decoder.parameters(), self.config.grad_clip 135 | ) 136 | self.optimizer.step() 137 | 138 | # Increment step 139 | global_step += 1 140 | 141 | # Evaluate if needed 142 | if global_step % self.config.eval_steps == 0: 143 | eval_loss = self.evaluate() 144 | print(f"\nStep {global_step} - Eval loss: {eval_loss:.4f}") 145 | 146 | # Save if best 147 | if eval_loss < best_loss: 148 | best_loss = eval_loss 149 | self.save_checkpoint( 150 | self.output_dir / "best_decoder.pt", 151 | global_step, 152 | best_loss, 153 | ) 154 | 155 | # Generate samples 156 | if self.eval_loader is not None: 157 | eval_batch = next(iter(self.eval_loader)) 158 | samples = self.sample_generator.generate_samples( 159 | eval_batch.input_texts, num_samples=2 160 | ) 161 | self._print_samples(samples) 162 | 163 | # Save checkpoint if needed 164 | if global_step % self.config.save_steps == 0: 165 | self.save_checkpoint( 166 | self.output_dir / f"decoder_step_{global_step}.pt", 167 | global_step, 168 | loss_meter.avg, 169 | ) 170 | 171 | # Log progress 172 | if global_step % 10 == 0: 173 | print(f"\nStep {global_step} - Loss: {loss_meter.avg:.4f}") 174 | 175 | # End of epoch 176 | print(f"Epoch {epoch + 1} finished. 
Avg loss: {loss_meter.avg:.4f}") 177 | loss_meter.reset() 178 | 179 | @torch.no_grad() 180 | def evaluate(self) -> float: 181 | """Evaluate the decoder.""" 182 | if self.eval_loader is None: 183 | return float("inf") 184 | 185 | self.decoder.eval() 186 | total_loss = 0 187 | num_batches = 0 188 | 189 | for batch in self.eval_loader: 190 | # Process batch 191 | input_ids, attention_mask = self._process_batch(batch) 192 | 193 | # Get concepts 194 | concepts = self.encoder(input_ids, attention_mask) 195 | 196 | # Generate 197 | logits = self.decoder(concepts, target_ids=input_ids) 198 | 199 | # Compute loss 200 | loss = F.cross_entropy( 201 | logits.view(-1, logits.size(-1)), 202 | input_ids[:, 1:].reshape(-1), 203 | ignore_index=self.decoder.config.pad_token_id, 204 | ) 205 | 206 | total_loss += loss.item() 207 | num_batches += 1 208 | 209 | avg_loss = total_loss / num_batches 210 | self.decoder.train() 211 | return avg_loss 212 | 213 | def save_checkpoint(self, path: Path, global_step: int, loss: float) -> None: 214 | """Save a checkpoint.""" 215 | path.parent.mkdir(parents=True, exist_ok=True) 216 | torch.save( 217 | { 218 | "step": global_step, 219 | "model_state_dict": self.decoder.state_dict(), 220 | "optimizer_state_dict": self.optimizer.state_dict(), 221 | "loss": loss, 222 | "config": self.decoder.config, 223 | }, 224 | path, 225 | ) 226 | print(f"Saved checkpoint to {path}") 227 | 228 | def load_checkpoint(self, path: Path) -> None: 229 | """Load a checkpoint.""" 230 | if not path.exists(): 231 | raise FileNotFoundError(f"Checkpoint not found at {path}") 232 | 233 | checkpoint = torch.load(path, map_location=self.device) 234 | self.decoder.load_state_dict(checkpoint["model_state_dict"]) 235 | self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) 236 | print(f"Loaded checkpoint from {path}") 237 | 238 | def _print_samples(self, samples: list) -> None: 239 | """Print generated samples.""" 240 | print("\nGenerated Samples:") 241 | print("-" * 50) 242 | for i, sample in enumerate(samples, 1): 243 | print(f"Sample {i}:") 244 | print(f"Original : {sample['original']}") 245 | print(f"Generated: {sample['generated']}") 246 | print(f"BLEU : {sample['bleu']:.4f}") 247 | print() 248 | -------------------------------------------------------------------------------- /src/encoder/utils/monitor.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | from dataclasses import dataclass 4 | from pathlib import Path 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from rich.console import Console 9 | from rich.markup import escape 10 | from rich.panel import Panel 11 | from rich.table import Table 12 | from torch import Tensor 13 | from transformers import PreTrainedTokenizer 14 | 15 | import wandb 16 | 17 | """ 18 | These metrics will help you understand: 19 | 20 | Semantic Similarity (avg_similarity): How close your predictions are to the actual next sentence embeddings 21 | Hit Rate: Whether your model can distinguish the true next sentence from other sentences in the batch 22 | Embedding Norms: If your embeddings are maintaining reasonable magnitudes 23 | Embedding Diversity: If your embeddings are maintaining good separation or collapsing 24 | 25 | Low hit rate but high similarity might indicate your model is making "safe" but overly general predictions. Low diversity might indicate embedding collapse. You can use these insights to tune your architecture or training process. 
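Concretely, as computed in ValidationMetrics.update below: avg_similarity is the mean cosine similarity between each predicted embedding and its target; hit_rate is the fraction of predictions whose most similar target within the batch is the correct one, so chance level is 1/batch_size (about 0.03 for a batch of 32); norm is the mean L2 norm of the predicted embeddings; and diversity is 1 minus the mean pairwise similarity among the predictions, excluding self-similarity.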
26 | """ 27 | 28 | 29 | @dataclass 30 | class MonitoringExample: 31 | """Holds information for a single monitoring example.""" 32 | 33 | context_text: str 34 | target_text: str 35 | predicted_embedding: Tensor 36 | target_embedding: Tensor 37 | similarity_score: float 38 | 39 | 40 | class ValidationMetrics: 41 | """Tracks validation metrics for JEPA model evaluation.""" 42 | 43 | def __init__(self): 44 | self.reset() 45 | 46 | def reset(self): 47 | """Reset all metric counters.""" 48 | self.total_samples = 0 49 | self.metrics = { 50 | "semantic/avg_similarity": 0.0, 51 | "semantic/hit_rate": 0.0, 52 | "embeddings/norm": 0.0, 53 | "embeddings/diversity": 0.0, 54 | } 55 | 56 | def update(self, pred_embeddings: Tensor, target_embeddings: Tensor): 57 | """Update metrics with new batch of embeddings.""" 58 | batch_size = pred_embeddings.shape[0] 59 | 60 | # Compute cosine similarity 61 | sim_scores = F.cosine_similarity(pred_embeddings, target_embeddings) 62 | avg_sim = sim_scores.mean().item() 63 | 64 | # Compute contrastive accuracy (hit rate) 65 | sim_matrix = torch.matmul(pred_embeddings, target_embeddings.T) 66 | correct_matches = sim_matrix.argmax(dim=-1) == torch.arange( 67 | len(pred_embeddings), device=pred_embeddings.device 68 | ) 69 | hit_rate = correct_matches.float().mean().item() 70 | 71 | # Compute embedding norms 72 | norms = torch.norm(pred_embeddings, dim=-1).mean().item() 73 | 74 | # Compute embedding diversity 75 | cosine_sim_matrix = torch.matmul(pred_embeddings, pred_embeddings.T) 76 | mask = ~torch.eye(batch_size, dtype=torch.bool, device=pred_embeddings.device) 77 | diversity = 1 - cosine_sim_matrix[mask].mean().item() 78 | 79 | # Update running averages 80 | weight = batch_size / (self.total_samples + batch_size) 81 | old_weight = 1 - weight 82 | 83 | self.metrics["semantic/avg_similarity"] = ( 84 | old_weight * self.metrics["semantic/avg_similarity"] + weight * avg_sim 85 | ) 86 | self.metrics["semantic/hit_rate"] = ( 87 | old_weight * self.metrics["semantic/hit_rate"] + weight * hit_rate 88 | ) 89 | self.metrics["embeddings/norm"] = ( 90 | old_weight * self.metrics["embeddings/norm"] + weight * norms 91 | ) 92 | self.metrics["embeddings/diversity"] = ( 93 | old_weight * self.metrics["embeddings/diversity"] + weight * diversity 94 | ) 95 | 96 | self.total_samples += batch_size 97 | 98 | def get_metrics(self) -> dict[str, float]: 99 | """Return current metrics.""" 100 | return self.metrics.copy() 101 | 102 | 103 | class TrainingMonitor: 104 | """Monitors and logs training progress for JEPA model.""" 105 | 106 | def __init__( 107 | self, 108 | tokenizer: PreTrainedTokenizer, 109 | log_dir: Path = Path("logs/monitor_logs"), 110 | num_examples: int = 3, 111 | log_every_n_epochs: int = 1, 112 | log_to_wandb: bool = True, 113 | ): 114 | self.tokenizer = tokenizer 115 | self.log_dir = Path(log_dir) 116 | self.num_examples = num_examples 117 | self.log_every_n_epochs = log_every_n_epochs 118 | self.log_to_wandb = log_to_wandb 119 | self.console = Console() 120 | self.validation_metrics = ValidationMetrics() 121 | 122 | # Create log directory 123 | self.log_dir.mkdir(parents=True, exist_ok=True) 124 | 125 | # Set up loggers 126 | self._setup_loggers() 127 | 128 | def _setup_loggers(self): 129 | """Initialize different loggers for console and file output.""" 130 | # Console logger 131 | self.console_logger = logging.getLogger("console") 132 | console_handler = logging.StreamHandler(sys.stdout) 133 | console_formatter = logging.Formatter("%(message)s") 134 | 
console_handler.setFormatter(console_formatter)
135 |         self.console_logger.addHandler(console_handler)
136 |         self.console_logger.setLevel(logging.INFO)
137 | 
138 |         # File logger for training examples
139 |         self.file_logger = logging.getLogger("training_examples")
140 |         file_handler = logging.FileHandler(self.log_dir / "training_examples.log")
141 |         file_formatter = logging.Formatter("%(asctime)s - %(message)s")
142 |         file_handler.setFormatter(file_formatter)
143 |         self.file_logger.addHandler(file_handler)
144 |         self.file_logger.setLevel(logging.INFO)
145 | 
146 |     def log_training_examples(
147 |         self,
148 |         epoch: int,
149 |         batch_texts: list[str],
150 |         target_texts: list[str],
151 |         predicted_features: Tensor,
152 |         target_features: Tensor,
153 |         encoder: torch.nn.Module,
154 |         predictor: torch.nn.Module,
155 |     ) -> None:
156 |         """Log training examples showing next-sentence prediction performance."""
157 |         if epoch % self.log_every_n_epochs != 0:
158 |             return
159 | 
160 |         examples = []
161 |         for idx in range(min(self.num_examples, len(batch_texts))):
162 |             # Calculate cosine similarity for this example
163 |             similarity = F.cosine_similarity(
164 |                 predicted_features[idx : idx + 1], target_features[idx : idx + 1], dim=1
165 |             ).item()
166 | 
167 |             example = MonitoringExample(
168 |                 context_text=batch_texts[idx],
169 |                 target_text=target_texts[idx],
170 |                 predicted_embedding=predicted_features[idx].cpu(),
171 |                 target_embedding=target_features[idx].cpu(),
172 |                 similarity_score=similarity,
173 |             )
174 |             examples.append(example)
175 | 
176 |         self._display_examples(epoch, examples)
177 |         if self.log_to_wandb:
178 |             self._log_to_wandb(epoch, examples)
179 | 
180 |     def _display_examples(self, epoch: int, examples: list[MonitoringExample]) -> None:
181 |         """Display training examples in a formatted table."""
182 |         self.file_logger.info(f"\n=== Training Examples (Epoch {epoch}) ===")
183 | 
184 |         for i, example in enumerate(examples, 1):
185 |             table = Table(
186 |                 show_header=True, header_style="bold magenta", show_lines=True
187 |             )
188 |             table.add_column("Type", style="cyan", width=20)
189 |             table.add_column("Content", style="green")
190 | 
191 |             # Format context text for display
192 |             escaped_context = escape(example.context_text)
193 |             context_chunks = [
194 |                 escaped_context[i : i + 100]
195 |                 for i in range(0, len(escaped_context), 100)
196 |             ]
197 |             table.add_row("Context", "\n".join(context_chunks))
198 | 
199 |             # Format target text
200 |             escaped_target = escape(example.target_text)
201 |             target_chunks = [
202 |                 escaped_target[i : i + 100] for i in range(0, len(escaped_target), 100)
203 |             ]
204 |             table.add_row("Target (Next Sentence)", "\n".join(target_chunks))
205 | 
206 |             # Add similarity score
207 |             table.add_row("Cosine Similarity", f"{example.similarity_score:.4f}")
208 | 
209 |             # Add embedding statistics
210 |             pred_norm = example.predicted_embedding.norm().item()
211 |             target_norm = example.target_embedding.norm().item()
212 |             table.add_row(
213 |                 "Embedding Stats",
214 |                 f"Predicted norm: {pred_norm:.4f}\nTarget norm: {target_norm:.4f}",
215 |             )
216 | 
217 |             # Render the table to the console, then log plain text to file
218 |             panel = Panel(table, title=f"Example {i}", border_style="blue")
219 |             self.console.print(panel)
220 |             self.file_logger.info(f"\nExample {i}:")
221 |             self.file_logger.info(f"Context: {example.context_text}")
222 |             self.file_logger.info(f"Target: {example.target_text}")
223 |             self.file_logger.info(f"Similarity: {example.similarity_score:.4f}")
224 |             self.file_logger.info("-" * 80)
225 | 
226 | def _log_to_wandb(self, epoch: int, examples: list[MonitoringExample]) -> None: 227 | """Log examples and metrics to Weights & Biases.""" 228 | # Log examples 229 | for i, example in enumerate(examples): 230 | wandb.log( 231 | { 232 | f"examples/context_{i}": example.context_text, 233 | f"examples/target_{i}": example.target_text, 234 | f"examples/similarity_{i}": example.similarity_score, 235 | f"examples/pred_norm_{i}": example.predicted_embedding.norm().item(), 236 | f"examples/target_norm_{i}": example.target_embedding.norm().item(), 237 | "epoch": epoch, 238 | } 239 | ) 240 | 241 | # Create similarity histogram for this batch 242 | if i == 0: # Only do this once per batch 243 | wandb.log( 244 | { 245 | "similarity_distribution": wandb.Histogram( 246 | [e.similarity_score for e in examples] 247 | ), 248 | "epoch": epoch, 249 | } 250 | ) 251 | 252 | def log_validation_metrics( 253 | self, 254 | epoch: int, 255 | pred_embeddings: Tensor, 256 | target_embeddings: Tensor, 257 | ) -> None: 258 | """Log validation metrics for the current batch.""" 259 | # Update metrics 260 | self.validation_metrics.update(pred_embeddings, target_embeddings) 261 | 262 | # Get current metrics 263 | metrics = self.validation_metrics.get_metrics() 264 | 265 | # Log to console with rich table 266 | self.console.print("\nValidation Metrics:") 267 | table = Table(show_header=True, header_style="bold magenta") 268 | table.add_column("Metric", style="cyan") 269 | table.add_column("Value", style="green") 270 | 271 | for name, value in metrics.items(): 272 | table.add_row(name, f"{value:.4f}") 273 | 274 | self.console.print(table) 275 | 276 | # Log to wandb if enabled 277 | if self.log_to_wandb: 278 | wandb.log({f"val/{k}": v for k, v in metrics.items()}) 279 | wandb.log({"epoch": epoch}) 280 | 281 | # Log to file 282 | self.file_logger.info(f"\nValidation Metrics (Epoch {epoch}):") 283 | for name, value in metrics.items(): 284 | self.file_logger.info(f"{name}: {value:.4f}") 285 | 286 | def get_current_metrics(self) -> dict[str, float]: 287 | """Return current validation metrics.""" 288 | return self.validation_metrics.get_metrics() 289 | 290 | def reset_validation_metrics(self) -> None: 291 | """Reset validation metrics for new epoch.""" 292 | self.validation_metrics.reset() 293 | --------------------------------------------------------------------------------
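Note (not a repository file): a minimal usage sketch of the ValidationMetrics class from src/encoder/utils/monitor.py, exercised in isolation as a quick sanity check of the metric definitions. It assumes the repository's dependencies (torch, rich, wandb, transformers) are importable, since monitor.py imports them at module level; the batch size of 8 and the 384-dimensional embeddings (matching pred_dim in base_lang_config.yaml) are illustrative choices, and the random tensors stand in for real predictor/target outputs.

import torch
import torch.nn.functional as F

from src.encoder.utils.monitor import ValidationMetrics

# Stand-ins for predictor outputs and target sentence embeddings,
# L2-normalized the same way the training loop normalizes them.
pred_embeddings = F.normalize(torch.randn(8, 384), p=2, dim=-1)
target_embeddings = F.normalize(torch.randn(8, 384), p=2, dim=-1)

metrics = ValidationMetrics()
metrics.update(pred_embeddings, target_embeddings)
print(metrics.get_metrics())
# -> dict with keys: semantic/avg_similarity, semantic/hit_rate,
#    embeddings/norm, embeddings/diversity

With random inputs like these, hit_rate should sit near 1/8 and avg_similarity near zero; during real training, the monitor is meant to surface how far the model moves away from that baseline.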