├── clt ├── __init__.py ├── utils │ ├── io.py │ └── __init__.py ├── nnsight │ └── __init__.py ├── activations │ ├── __init__.py │ └── registry.py ├── training │ ├── __init__.py │ ├── data │ │ ├── __init__.py │ │ ├── base_store.py │ │ ├── activation_store_factory.py │ │ └── remote_activation_store.py │ ├── utils.py │ ├── distributed_utils.py │ ├── metric_utils.py │ └── diagnostics.py ├── activation_generation │ └── __init__.py ├── config │ ├── __init__.py │ └── data_config.py ├── models │ ├── __init__.py │ ├── base.py │ ├── activations_optimized.py │ └── encoder.py └── parallel │ └── ops.py ├── tests ├── __init__.py ├── unit │ ├── __init__.py │ ├── models │ │ ├── test_decoder_norms.py │ │ ├── test_clt_encode_decode.py │ │ ├── test_decoder.py │ │ ├── test_parallel_layers.py │ │ ├── test_encoder.py │ │ ├── test_theta.py │ │ └── test_parallel_ops.py │ ├── data │ │ └── README.md │ └── training │ │ ├── test_checkpointing.py │ │ ├── data │ │ ├── test_chunk_row_sampler.py │ │ └── test_local_activation_store.py │ │ └── test_loss_manager.py ├── helpers │ ├── __init__.py │ ├── fake_requests.py │ ├── fake_hdf5.py │ └── tiny_configs.py ├── integration │ ├── __init__.py │ ├── README.md │ ├── test_single_gpu_training_step.py │ ├── distributed_training_worker.py │ ├── test_clt_end_to_end.py │ ├── test_distributed_training.py │ ├── test_clt_distributed.py │ └── test_checkpoint_resumption.py ├── __main__.py ├── README.md ├── conftest.py └── models │ └── test_clt_distributed_forward.py ├── clt_server ├── api │ ├── __init__.py │ └── health.py ├── core │ ├── __init__.py │ └── config.py ├── tests │ ├── __init__.py │ ├── core │ │ ├── __init__.py │ │ └── test_config.py │ └── api │ │ └── test_health.py ├── requirements.txt ├── README.md └── main.py ├── pytest.ini ├── .vscode └── settings.json ├── example_clt_config_gpt2_local.json ├── scripts ├── analysis │ └── check_model.py ├── smoke_train_dist.sh ├── compare_norm_stats.py └── experiments │ ├── run_pythia_batchtopk_training.py │ └── 
run_pythia_batchtopk_training_fp16.py ├── LICENSE ├── .github └── workflows │ └── python-tests.yml ├── pyproject.toml └── .gitignore /clt/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /clt/utils/io.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /clt/nnsight/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /clt/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/unit/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /clt/activations/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /clt/training/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /clt_server/api/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /clt_server/core/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /tests/helpers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /clt_server/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/helpers/fake_requests.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/integration/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /clt/activation_generation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /clt_server/tests/core/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/__main__.py: -------------------------------------------------------------------------------- 1 | # This file allows the tests directory to be run as a package. 
2 | -------------------------------------------------------------------------------- /clt/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .clt_config import CLTConfig, TrainingConfig 2 | from .data_config import ActivationConfig 3 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | markers = 3 | integration: marks tests as integration tests that verify multiple components working together \ 4 | require_gpu: marks tests that require a GPU (CUDA or MPS) to run \ -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "editor.rulers": [120], 3 | "python.formatting.provider": "black", 4 | "black-formatter.args": ["--line-length", "120"], 5 | "ruff.args": ["--line-length=120", "--ignore=E501"] 6 | } -------------------------------------------------------------------------------- /clt_server/api/health.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter 2 | 3 | router = APIRouter() 4 | 5 | 6 | @router.get("/health") 7 | async def health_check(): 8 | """Provides a simple health check endpoint.""" 9 | return {"status": "ok"} 10 | -------------------------------------------------------------------------------- /clt_server/requirements.txt: -------------------------------------------------------------------------------- 1 | fastapi 2 | uvicorn[standard] 3 | python-dotenv # For loading config from .env file 4 | pydantic-settings # Needed for core.config 5 | pytest-asyncio # For running async tests 6 | # torch - Install separately based on system/CUDA 7 | requests # Needed if server makes outbound calls (unlikely for now) 8 | aiofiles # For async file operations 
-------------------------------------------------------------------------------- /example_clt_config_gpt2_local.json: -------------------------------------------------------------------------------- 1 | { 2 | "num_features": 24576, 3 | "num_layers": 12, 4 | "d_model": 768, 5 | "model_name": "gpt2", 6 | "normalization_method": "auto", 7 | "activation_fn": "relu", 8 | "jumprelu_threshold": 0.03, 9 | "clt_dtype": null, 10 | "expected_input_dtype": "float32", 11 | "mlp_input_template": "transformer.h.{}.ln_2.input", 12 | "mlp_output_template": "transformer.h.{}.mlp.output" 13 | } -------------------------------------------------------------------------------- /clt/training/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_store import BaseActivationStore 2 | from .local_activation_store import LocalActivationStore 3 | from .remote_activation_store import RemoteActivationStore 4 | from .manifest_activation_store import ManifestActivationStore, ChunkRowSampler, _open_h5, ActivationBatch 5 | 6 | __all__ = [ 7 | "BaseActivationStore", 8 | "LocalActivationStore", 9 | "RemoteActivationStore", 10 | "ManifestActivationStore", 11 | "ChunkRowSampler", 12 | "_open_h5", # If intended to be part of public API from this level 13 | "ActivationBatch", # If intended to be part of public API from this level 14 | ] 15 | -------------------------------------------------------------------------------- /scripts/analysis/check_model.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import torch 3 | import os 4 | import json 5 | from safetensors.torch import load_file 6 | 7 | # Load model from safetensors file 8 | model_path = "/Users/curttigges/Projects/crosslayer-coding/conversion_test/gpt2_32k/full_model.safetensors" 9 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 10 | 11 | if os.path.exists(model_path): 12 | state_dict = load_file(model_path, 
device=device.type) 13 | print(f"Loaded model from {model_path}") 14 | else: 15 | print(f"Model file not found at {model_path}") 16 | 17 | # %% 18 | state_dict.keys() 19 | # %% 20 | state_dict["decoder_module.decoders.0->1.weight"].shape 21 | # %% 22 | -------------------------------------------------------------------------------- /clt/models/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def mark_replicated(param: torch.nn.Parameter): 5 | """Tag a Parameter that is fully replicated across TP ranks. 6 | 7 | Args: 8 | param: The torch.nn.Parameter to mark. 9 | """ 10 | # setattr(param, "_is_replicated", True) # Alternative 11 | param.__dict__["_is_replicated"] = True 12 | 13 | 14 | def is_replicated(param: torch.nn.Parameter) -> bool: 15 | """Check if a Parameter is marked as replicated. 16 | 17 | Args: 18 | param: The torch.nn.Parameter to check. 19 | 20 | Returns: 21 | True if the parameter is marked as replicated, False otherwise. 
22 | """ 23 | return getattr(param, "_is_replicated", False) 24 | -------------------------------------------------------------------------------- /clt/training/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | 3 | import numpy as np 4 | import torch 5 | 6 | 7 | # Helper function to format elapsed time 8 | def _format_elapsed_time(seconds: float) -> str: 9 | """Formats elapsed seconds into HH:MM:SS or MM:SS.""" 10 | td = datetime.timedelta(seconds=int(seconds)) 11 | hours, remainder = divmod(td.seconds, 3600) 12 | minutes, seconds = divmod(remainder, 60) 13 | if td.days > 0 or hours > 0: 14 | return f"{td.days * 24 + hours:02d}:{minutes:02d}:{seconds:02d}" 15 | else: 16 | return f"{minutes:02d}:{seconds:02d}" 17 | 18 | 19 | def torch_bfloat16_to_numpy_uint16(x: torch.Tensor) -> np.ndarray: 20 | return np.frombuffer(x.float().numpy().tobytes(), dtype=np.uint16)[1::2].reshape(x.shape) 21 | -------------------------------------------------------------------------------- /clt_server/core/config.py: -------------------------------------------------------------------------------- 1 | from pydantic_settings import BaseSettings 2 | from pathlib import Path 3 | 4 | 5 | class Settings(BaseSettings): 6 | # Base directory for storing activation datasets and chunks 7 | STORAGE_BASE_DIR: Path = Path("./server_data") 8 | 9 | # Number of extra random chunks to try loading if the first fails 10 | CHUNK_RETRY_ATTEMPTS: int = 3 11 | 12 | # Other potential settings (e.g., logging level, allowed origins for CORS) 13 | LOG_LEVEL: str = "info" 14 | 15 | class Config: 16 | # Optional: Load from a .env file if present 17 | env_file = ".env" 18 | env_file_encoding = "utf-8" 19 | 20 | 21 | # Create a single instance of the settings to be imported elsewhere 22 | settings = Settings() 23 | 24 | # Ensure the storage directory exists on startup (or handle creation elsewhere) 25 | if not settings.STORAGE_BASE_DIR.exists(): 26 | 
print(f"Creating storage directory: {settings.STORAGE_BASE_DIR}") 27 | settings.STORAGE_BASE_DIR.mkdir(parents=True, exist_ok=True) 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Curt Tigges 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /scripts/smoke_train_dist.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # smoke_train_dist.sh 3 | # Launches the smoke_train.py script using torchrun for distributed training. 
4 | 5 | # Ensure the script is executable: chmod +x scripts/smoke_train_dist.sh 6 | 7 | # Number of GPUs to use (override via the NPROC_PER_NODE env var; 2 is good for smoke testing) 8 | NPROC_PER_NODE="${NPROC_PER_NODE:-2}" 9 | 10 | # Path to the smoke_train.py script (relative: run this from the repository root) 11 | SMOKE_SCRIPT="scripts/smoke_train.py" 12 | 13 | echo "Launching distributed smoke test with $NPROC_PER_NODE processes..." 14 | 15 | # Check if smoke_train.py exists 16 | if [ ! -f "$SMOKE_SCRIPT" ]; then 17 | echo "Error: $SMOKE_SCRIPT not found!" 18 | exit 1 19 | fi 20 | 21 | # Clear previous smoke output if it exists to ensure a fresh run 22 | if [ -d "./clt_smoke_output" ]; then 23 | echo "Clearing previous smoke output directory ./clt_smoke_output" 24 | rm -rf "./clt_smoke_output" 25 | fi 26 | 27 | # Launch with torchrun 28 | # No need to pass --distributed explicitly to smoke_train.py, as it auto-detects from env vars set by torchrun. 29 | torchrun --nproc_per_node="$NPROC_PER_NODE" "$SMOKE_SCRIPT" 30 | 31 | EXIT_CODE=$? 32 | 33 | if [ "$EXIT_CODE" -eq 0 ]; then 34 | echo "Distributed smoke test completed successfully." 35 | else 36 | echo "Distributed smoke test failed with exit code $EXIT_CODE."
37 | fi 38 | 39 | exit "$EXIT_CODE" -------------------------------------------------------------------------------- /tests/unit/models/test_decoder_norms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from clt.config import CLTConfig 4 | from clt.models.decoder import Decoder 5 | 6 | 7 | def _create_decoder(num_layers: int = 3, d_model: int = 8, num_features: int = 12): # tiny CPU/float32 Decoder without a process group 8 | config = CLTConfig(num_layers=num_layers, d_model=d_model, num_features=num_features) 9 | return Decoder(config=config, process_group=None, device=torch.device("cpu"), dtype=torch.float32) 10 | 11 | 12 | def test_decoder_norms_shape_and_non_negative(): 13 | """Decoder.get_decoder_norms should return a tensor of shape [num_layers, num_features] with non-negative values.""" 14 | decoder = _create_decoder() 15 | norms = decoder.get_decoder_norms() 16 | 17 | assert norms.shape == (decoder.config.num_layers, decoder.config.num_features) 18 | # All norms should be >= 0 (L2 norms) 19 | assert torch.all(norms >= 0), "Decoder norms should be non-negative" 20 | 21 | 22 | def test_decoder_norms_cached(): 23 | """Subsequent calls to get_decoder_norms should return the cached tensor object (no recomputation).""" 24 | decoder = _create_decoder() 25 | norms_first = decoder.get_decoder_norms() 26 | norms_second = decoder.get_decoder_norms() 27 | 28 | # Should be the *same* tensor object (cached) 29 | assert norms_first is norms_second, "Decoder norms should be cached and identical object on repeated calls" 30 | -------------------------------------------------------------------------------- /.github/workflows/python-tests.yml: -------------------------------------------------------------------------------- 1 | name: Python Tests 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | - develop # Or your primary development branch 8 | pull_request: 9 | branches: 10 | - main 11 | 12 | jobs: 13 | test: 14 | runs-on: ubuntu-latest 15 | strategy: 16 | matrix:
17 | python-version: ["3.9", "3.10", "3.11"] # Specify python versions 18 | 19 | steps: 20 | - uses: actions/checkout@v4 21 | 22 | - name: Set up Python ${{ matrix.python-version }} 23 | uses: actions/setup-python@v5 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | 27 | - name: Install Poetry 28 | run: | 29 | curl -sSL https://install.python-poetry.org | python3 - 30 | echo "$HOME/.local/bin" >> "$GITHUB_PATH" 31 | # Alternatively, if not using Poetry, or have a requirements.txt: 32 | # run: pip install -r requirements.txt 33 | 34 | - name: Install dependencies 35 | run: poetry install --no-interaction --no-root --all-extras 36 | # pytest is declared under [project.optional-dependencies] dev in pyproject.toml, so extras must be installed 37 | # With a Poetry dependency group instead: poetry install --no-interaction --no-root --with dev 38 | # Or if using pip with requirements.txt: 39 | # run: pip install -r requirements-dev.txt # (if you have a separate dev requirements) 40 | # run: pip install pytest # or ensure pytest is in your main requirements 41 | 42 | - name: Run tests with pytest 43 | run: poetry run pytest tests/ 44 | # Or if not using poetry: 45 | # run: pytest tests/ -------------------------------------------------------------------------------- /clt/training/data/base_store.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Dict, List, Tuple, Any 3 | import logging 4 | from abc import ABC, abstractmethod 5 | 6 | logging.basicConfig(level=logging.INFO) # NOTE(review): configures the root logger at import time for every importer; consider moving to entry points 7 | logger = logging.getLogger(__name__) 8 | 9 | ActivationBatchCLT = Tuple[Dict[int, torch.Tensor], Dict[int, torch.Tensor]] # presumably (inputs, targets) keyed by layer index - confirm against stores 10 | 11 | 12 | class BaseActivationStore(ABC): 13 | layer_indices: List[int] 14 | d_model: int 15 | dtype: torch.dtype 16 | device: torch.device 17 | train_batch_size_tokens: int 18 | total_tokens: int 19 | 20 | @abstractmethod 21 | def get_batch(self) -> ActivationBatchCLT: 22 | pass 23 | 24 | @abstractmethod 25 | def state_dict(self) -> Dict[str, Any]: 26 |
pass 27 | 28 | @abstractmethod 29 | def load_state_dict(self, state_dict: Dict[str, Any]): 30 | pass 31 | 32 | @abstractmethod 33 | def close(self): 34 | pass 35 | 36 | def __iter__(self): 37 | return self 38 | 39 | def __next__(self): 40 | try: 41 | return self.get_batch() 42 | except StopIteration: 43 | raise 44 | except Exception as e: 45 | logger.error(f"Error during iteration: {e}", exc_info=True) 46 | raise 47 | 48 | def __len__(self): 49 | if ( 50 | not hasattr(self, "total_tokens") 51 | or self.total_tokens <= 0 52 | or not hasattr(self, "train_batch_size_tokens") 53 | or self.train_batch_size_tokens <= 0 54 | ): 55 | return 0 56 | return (self.total_tokens + self.train_batch_size_tokens - 1) // self.train_batch_size_tokens 57 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 120 3 | target-version = ['py37'] 4 | include = '\.pyi?$' 5 | 6 | [tool.flake8] 7 | max-line-length = 120 8 | ignore = ["E501"] 9 | extend-ignore = "E203" # Ignore whitespace before ':' (conflicts with Black) 10 | per-file-ignores = [] 11 | exclude = [ 12 | ".git", 13 | "__pycache__", 14 | "build", 15 | "dist", 16 | ] 17 | 18 | [build-system] 19 | requires = ["setuptools>=61.0"] 20 | build-backend = "setuptools.build_meta" 21 | 22 | [project] 23 | name = "clt" 24 | version = "0.0.1" 25 | description = "Cross-Layer Transcoder Library" 26 | readme = "README.md" 27 | requires-python = ">=3.8" 28 | classifiers = [ 29 | "Programming Language :: Python :: 3", 30 | "License :: OSI Approved :: MIT License", # Assuming MIT, change if needed 31 | "Operating System :: OS Independent", 32 | ] 33 | dependencies = [ 34 | "torch>=2.0.0", 35 | "torchvision>=0.15.1", 36 | "torchaudio>=2.0.1", 37 | "nnsight>=0.1.0", 38 | "transformers>=4.30.0", 39 | "datasets>=2.0.0", 40 | "wandb>=0.15.0", 41 | "tqdm>=4.64.0", 42 | "numpy>=1.24.0", 43 | 
"zstandard", 44 | "h5py", 45 | "requests", 46 | # pytest is primarily a dev dependency 47 | # Add other core dependencies here as needed 48 | ] 49 | 50 | [project.optional-dependencies] 51 | dev = [ 52 | "pytest", 53 | "pytest-cov", 54 | "black>=23.0.0", 55 | "flake8>=6.0.0", 56 | "mypy>=1.0.0", 57 | # Add other development dependencies like linters, formatters etc. 58 | ] 59 | 60 | [tool.setuptools.packages.find] 61 | where = ["."] # Root directory contains the 'clt' package 62 | include = ["clt*"] 63 | exclude = ["tests*"] # Exclude tests from the package itself -------------------------------------------------------------------------------- /clt_server/README.md: -------------------------------------------------------------------------------- 1 | # Activation Storage Server 2 | 3 | This server provides a RESTful API for storing and retrieving pre-generated model activations, primarily for use with the Cross-Layer Transcoder (CLT) training process. 4 | 5 | ## Features 6 | 7 | - Stores activation chunks uploaded by `ActivationGenerator`. 8 | - Serves dataset metadata (`metadata.json`). 9 | - Serves normalization statistics (`norm_stats.json`). 10 | - Serves random batches of activations for training (`RemoteActivationStore`). 11 | 12 | ## Setup 13 | 14 | 1. **Install Dependencies:** 15 | ```bash 16 | pip install -r requirements.txt 17 | # Install PyTorch matching your system/CUDA version (see pytorch.org) 18 | # e.g., pip install torch torchvision torchaudio 19 | ``` 20 | 21 | 2. **Configure Storage:** 22 | - By default, data is stored in `./server_data`. 23 | - Set the `STORAGE_BASE_DIR` environment variable to change this location. 24 | 25 | 3. **Run the Server:** 26 | ```bash 27 | # For development/testing: 28 | python main.py 29 | 30 | # Or using uvicorn directly: 31 | uvicorn clt_server.main:app --reload --host 0.0.0.0 --port 8000 32 | ``` 33 | The `--reload` flag automatically restarts the server when code changes. 
34 | 35 | ## API 36 | 37 | The API documentation is available via Swagger UI at `http://:8000/docs` when the server is running. 38 | 39 | Refer to `ref_docs/activation_server_api.md` in the main project for the detailed API specification. 40 | 41 | ## TODO 42 | 43 | - Add robust error handling and logging. 44 | - Consider authentication/authorization. 45 | - Add unit and integration tests. 46 | - Implement alternative storage backends (e.g., S3). 47 | - Optimize chunk storage/retrieval. -------------------------------------------------------------------------------- /tests/helpers/fake_hdf5.py: -------------------------------------------------------------------------------- 1 | """Helper to create synthetic HDF5 activation data for testing.""" 2 | 3 | import h5py 4 | import numpy as np 5 | from pathlib import Path 6 | from typing import Union, Any 7 | 8 | 9 | def make_tiny_chunk_files( 10 | path: Union[str, Path], 11 | num_chunks: int = 1, 12 | n_layers: int = 2, 13 | n_tokens: int = 32, 14 | d_model: int = 8, 15 | dtype: Any = np.float16, 16 | ): 17 | """ 18 | Creates one or more HDF5 chunk files with synthetic data. 19 | 20 | Args: 21 | path: Directory to save the chunk files in. 22 | num_chunks: Number of chunk files to create. 23 | n_layers: Number of layers. 24 | n_tokens: Number of tokens (rows) per chunk. 25 | d_model: Feature dimension. 26 | dtype: Numpy dtype for the data. 
27 | """ 28 | path = Path(path) 29 | path.mkdir(parents=True, exist_ok=True) 30 | 31 | rng = np.random.default_rng(seed=42) 32 | 33 | for i in range(num_chunks): 34 | chunk_path = path / f"chunk_{i}.h5" 35 | with h5py.File(chunk_path, "w") as f: 36 | for layer_idx in range(n_layers): 37 | layer_group = f.create_group(f"layer_{layer_idx}") 38 | # "inputs" are from the source model's MLP outputs 39 | # "targets" are from the source model's MLP inputs 40 | # Both have the same shape 41 | shape = (n_tokens, d_model) 42 | inputs_data = (rng.random(size=shape) * 10).astype(dtype) 43 | targets_data = (rng.random(size=shape) * 5).astype(dtype) 44 | layer_group.create_dataset("inputs", data=inputs_data) 45 | layer_group.create_dataset("targets", data=targets_data) 46 | -------------------------------------------------------------------------------- /clt/training/distributed_utils.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | from clt.parallel import ops as dist_ops 3 | 4 | if TYPE_CHECKING: 5 | from clt.models.clt import CrossLayerTranscoder # For type hinting model parameters 6 | 7 | # from clt.models import is_replicated # No longer needed 8 | 9 | 10 | def average_shared_parameter_grads(model: "CrossLayerTranscoder", world_size: int): 11 | """Average gradients of parameters that are **replicated** across ranks. 12 | 13 | Tensor-parallel layers shard their weights so those gradients must **not** be 14 | synchronised. However parameters that are kept identical on every rank – 15 | e.g. the JumpReLU `log_threshold` vector (shape `[num_features]`) and any 16 | unsharded bias vectors – must have their gradients reduced or they will 17 | diverge between ranks. 18 | """ 19 | # This function is called when distributed and world_size > 1 20 | # The original method had a guard: `if not self.distributed or self.world_size == 1: return` 21 | # That check should be done by the caller now. 
22 | for p in model.parameters(): 23 | if p.grad is None: 24 | continue 25 | 26 | # Check if parameter is marked as replicated OR if it's a 1D tensor (for backward compatibility) 27 | # The import for is_replicated will be guarded by TYPE_CHECKING, so use getattr for runtime. 28 | is_rep = getattr(p, "_is_replicated", False) 29 | 30 | # Only average if explicitly marked as replicated. 31 | # The p.dim() == 1 heuristic was too broad and could incorrectly average sharded 1D parameters (e.g., encoder biases). 32 | if is_rep: 33 | dist_ops.all_reduce(p.grad, op=dist_ops.SUM) 34 | p.grad /= world_size 35 | -------------------------------------------------------------------------------- /clt_server/main.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI 2 | import uvicorn 3 | import os # Import os for environment variables 4 | 5 | # Import routers (assuming they exist, we will create them) 6 | from .api import health 7 | from .core.config import settings # Import settings 8 | 9 | # Import the low‑level slice server app to expose /datasets endpoints at root 10 | from .core import storage as slice_server 11 | 12 | app = FastAPI( 13 | title="CLT Activation Storage Server", 14 | description="Stores and serves pre-generated model activations for CLT training.", 15 | version="0.1.0", 16 | ) 17 | 18 | # Include API routers 19 | app.include_router(health.router, prefix="/api/v1", tags=["Health"]) 20 | 21 | # Mount the slice server (raw HDF5 slice endpoints) at root so that 22 | # paths like /datasets/... are served alongside the higher‑level /api/v1 routes. 23 | app.mount("", slice_server.app) 24 | 25 | 26 | @app.get("/", tags=["Root"]) 27 | async def read_root(): 28 | return { 29 | "message": "Welcome to the CLT Activation Storage Server. See /docs for API details." 
30 | } 31 | 32 | 33 | # Optional: Add startup/shutdown events later for resource management 34 | # @app.on_event("startup") 35 | # async def startup_event(): 36 | # print("Server starting up...") 37 | 38 | # @app.on_event("shutdown") 39 | # async def shutdown_event(): 40 | # print("Server shutting down...") 41 | 42 | # Allow running directly for simple testing/development 43 | if __name__ == "__main__": 44 | # Use environment variables for host/port or defaults 45 | host = os.getenv("HOST", "127.0.0.1") 46 | port = int(os.getenv("PORT", "8000")) 47 | log_level = os.getenv("LOG_LEVEL", "info") 48 | 49 | print(f"Starting server on {host}:{port}...") 50 | print(f"Storage base directory: {settings.STORAGE_BASE_DIR}") # Log storage dir 51 | uvicorn.run(app, host=host, port=port, log_level=log_level) 52 | -------------------------------------------------------------------------------- /clt_server/tests/api/test_health.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import pytest_asyncio 3 | from httpx import AsyncClient 4 | import sys 5 | import os 6 | from typing import AsyncGenerator 7 | 8 | # Ensure the project root is in the Python path for imports 9 | project_root = os.path.abspath( 10 | os.path.join(os.path.dirname(__file__), "..", "..", "..") 11 | ) 12 | if project_root not in sys.path: 13 | sys.path.insert(0, project_root) 14 | 15 | # Import the FastAPI app instance 16 | # Adjust the import path if your structure differs 17 | try: 18 | from clt_server.main import app 19 | except ImportError as e: 20 | print(f"Error importing FastAPI app: {e}") 21 | print( 22 | "Ensure the test is run from the project root or PYTHONPATH is set correctly." 
23 | ) 24 | # Optionally re-raise or exit if the app cannot be imported 25 | raise 26 | 27 | # --- Fixtures --- 28 | 29 | 30 | @pytest_asyncio.fixture(scope="function") # Use function scope for client isolation 31 | async def async_client() -> AsyncGenerator[AsyncClient, None]: 32 | """Provides an httpx AsyncClient configured for the test app.""" 33 | # Use the context manager for proper startup/shutdown event handling 34 | async with AsyncClient(app=app, base_url="http://test") as client: 35 | yield client 36 | 37 | 38 | # --- Test Cases --- 39 | 40 | 41 | @pytest.mark.asyncio 42 | async def test_health_check_status_code(async_client: AsyncClient): 43 | """Tests if the health check endpoint returns a 200 OK status.""" 44 | response = await async_client.get("/api/v1/health") 45 | assert response.status_code == 200 46 | 47 | 48 | @pytest.mark.asyncio 49 | async def test_health_check_response_body(async_client: AsyncClient): 50 | """Tests if the health check endpoint returns the correct JSON body.""" 51 | response = await async_client.get("/api/v1/health") 52 | assert response.status_code == 200 53 | assert response.json() == {"status": "ok"} 54 | -------------------------------------------------------------------------------- /tests/README.md: -------------------------------------------------------------------------------- 1 | # CLT Testing Strategy 2 | 3 | This directory contains tests for the Cross-Layer Transcoder (CLT) library. The testing strategy is divided into two main categories: 4 | 5 | ## Unit Tests 6 | 7 | Located in `tests/unit/`, these tests verify that individual components of the library work correctly in isolation. 
Unit tests are designed to: 8 | 9 | - Test specific functionality of a single class or function 10 | - Mock dependencies to isolate the component being tested 11 | - Run quickly and provide good coverage 12 | - Help identify where issues originate 13 | 14 | ## Integration Tests 15 | 16 | Located in `tests/integration/`, these tests verify that multiple components work together correctly. Integration tests are designed to: 17 | 18 | - Test realistic user workflows 19 | - Verify that components interact properly 20 | - Use real (but small-scale) components instead of mocks 21 | - Test data flow between components 22 | 23 | ## Running Tests 24 | 25 | To run all tests: 26 | 27 | ```bash 28 | pytest 29 | ``` 30 | 31 | To run unit tests only: 32 | 33 | ```bash 34 | pytest tests/unit 35 | ``` 36 | 37 | To run integration tests only: 38 | 39 | ```bash 40 | pytest tests/integration 41 | ``` 42 | 43 | ## Test Fixtures 44 | 45 | The `tests/integration/data/` directory contains fixtures for integration tests, including: 46 | 47 | - Sample activation data 48 | - Pre-trained model files 49 | - Helper scripts to generate test data 50 | 51 | These fixtures enable deterministic testing of model loading, inference, and training without requiring external data. 52 | 53 | ## Writing New Tests 54 | 55 | When adding new functionality to the library: 56 | 57 | 1. Add unit tests for the new component in `tests/unit/` 58 | 2. 
Add integration tests in `tests/integration/` for any new workflows 59 | 60 | Use the `@pytest.mark.integration` marker for integration tests: 61 | 62 | ```python 63 | @pytest.mark.integration 64 | def test_my_integration_test(): 65 | # Test code here 66 | pass 67 | ``` -------------------------------------------------------------------------------- /tests/integration/README.md: -------------------------------------------------------------------------------- 1 | # CLT Integration Tests 2 | 3 | This directory contains integration tests for the Cross-Layer Transcoder (CLT) library. Unlike unit tests that verify individual components in isolation, these integration tests verify that multiple components work together correctly. 4 | 5 | ## Testing Strategy 6 | 7 | Our integration tests focus on these key integration points: 8 | 9 | 1. **Configuration Loading & Usage**: Testing that `CLTConfig` and `TrainingConfig` are correctly used by the trainer, models, and loss manager components. 10 | 11 | 2. **Activation Data Pipeline**: Verifying the flow from activation sources → `ActivationStore` → training components. 12 | 13 | 3. **End-to-End Training**: Testing that the full training pipeline works with real components (small-scale). 14 | 15 | 4. **Model Persistence**: Testing save/load functionality for trained models. 16 | 17 | 5. **Config Variants**: Testing different configuration options and ensuring they properly integrate. 18 | 19 | ## Test Fixtures 20 | 21 | The tests use small-scale fixtures to enable quick testing: 22 | 23 | - Small model configurations (few layers, few features) 24 | - Minimal activation datasets 25 | - Short training runs (few steps) 26 | 27 | Some tests utilize pre-generated files in the `data/` directory. 
28 | 29 | ## Running the Tests 30 | 31 | To run the integration tests only: 32 | 33 | ```bash 34 | pytest tests/integration -v 35 | ``` 36 | 37 | To run a specific integration test file: 38 | 39 | ```bash 40 | pytest tests/integration/test_training_pipeline.py -v 41 | ``` 42 | 43 | To run a specific test: 44 | 45 | ```bash 46 | pytest tests/integration/test_activation_store.py::test_activation_store_from_nnsight -v 47 | ``` 48 | 49 | ## Writing New Integration Tests 50 | 51 | When adding new integration tests: 52 | 53 | 1. Use the `@pytest.mark.integration` decorator to mark integration tests 54 | 2. Use temporary directories for any file operations 55 | 3. Keep model sizes and training steps small to ensure tests run quickly 56 | 4. When testing with real components, focus on verifying that they connect correctly, not on the quality of results -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import json 4 | import numpy as np 5 | from pathlib import Path 6 | 7 | from tests.helpers.tiny_configs import create_tiny_clt_config 8 | from tests.helpers.fake_hdf5 import make_tiny_chunk_files 9 | 10 | 11 | def get_available_devices(): 12 | """Returns available devices, including cpu, mps, and cuda if available.""" 13 | devices = ["cpu"] 14 | if torch.cuda.is_available(): 15 | devices.append("cuda") 16 | if torch.backends.mps.is_available(): 17 | devices.append("mps") 18 | return devices 19 | 20 | 21 | DEVICES = get_available_devices() 22 | 23 | 24 | @pytest.fixture(params=DEVICES) 25 | def device(request): 26 | """Fixture to iterate over all available devices.""" 27 | return torch.device(request.param) 28 | 29 | 30 | @pytest.fixture 31 | def tmp_local_dataset(tmp_path: Path) -> Path: 32 | """ 33 | Creates a temporary local dataset directory with metadata, a manifest, 34 | and a dummy HDF5 chunk file for 
testing activation stores. 35 | """ 36 | dataset_path = tmp_path / "tiny_dataset" 37 | dataset_path.mkdir() 38 | 39 | # --- Configs --- 40 | clt_config = create_tiny_clt_config(num_layers=2, d_model=8) 41 | # The tokens here must match the chunk file generation 42 | n_tokens_per_chunk = 32 43 | num_chunks = 2 44 | 45 | # --- Create Fake Data --- 46 | make_tiny_chunk_files( 47 | dataset_path, 48 | num_chunks=num_chunks, 49 | n_layers=clt_config.num_layers, 50 | n_tokens=n_tokens_per_chunk, 51 | d_model=clt_config.d_model, 52 | dtype=np.float16, 53 | ) 54 | 55 | # --- Create Metadata --- 56 | metadata = { 57 | "num_layers": clt_config.num_layers, 58 | "d_model": clt_config.d_model, 59 | "total_tokens": n_tokens_per_chunk * num_chunks, 60 | "chunk_tokens": n_tokens_per_chunk, 61 | "dtype": "float16", 62 | } 63 | with open(dataset_path / "metadata.json", "w") as f: 64 | json.dump(metadata, f) 65 | 66 | # --- Create Manifest (legacy 2-field format) --- 67 | manifest_rows = [] 68 | for chunk_id in range(num_chunks): 69 | for row_id in range(n_tokens_per_chunk): 70 | manifest_rows.append([chunk_id, row_id]) 71 | 72 | manifest_arr = np.array(manifest_rows, dtype=np.uint32) 73 | manifest_arr.tofile(dataset_path / "index.bin") 74 | 75 | return dataset_path 76 | -------------------------------------------------------------------------------- /tests/helpers/tiny_configs.py: -------------------------------------------------------------------------------- 1 | """Minimal, fast-running configs for testing.""" 2 | 3 | from typing import Literal, Optional 4 | from clt.config import CLTConfig, TrainingConfig 5 | 6 | ActivationFn = Literal["jumprelu", "relu", "batchtopk", "topk"] 7 | SparsitySchedule = Literal["linear", "delayed_linear"] 8 | ActivationSource = Literal["local_manifest", "remote"] 9 | Precision = Literal["fp32", "fp16", "bf16"] 10 | ActivationDtype = Literal["bfloat16", "float16", "float32"] 11 | 12 | 13 | def create_tiny_clt_config( 14 | num_layers: int = 2, 15 | 
num_features: int = 8, 16 | d_model: int = 4, 17 | activation_fn: ActivationFn = "relu", 18 | ) -> CLTConfig: 19 | """Creates a minimal CLTConfig for fast tests.""" 20 | return CLTConfig( 21 | num_layers=num_layers, 22 | num_features=num_features, 23 | d_model=d_model, 24 | activation_fn=activation_fn, 25 | # Keep other params at default for simplicity unless needed 26 | ) 27 | 28 | 29 | def create_tiny_training_config( 30 | training_steps: int = 10, 31 | train_batch_size_tokens: int = 16, 32 | learning_rate: float = 1e-4, 33 | sparsity_lambda: float = 0.01, 34 | sparsity_lambda_schedule: SparsitySchedule = "linear", 35 | sparsity_lambda_delay_frac: float = 0.0, 36 | preactivation_coef: float = 0.0, 37 | eval_interval: int = 1000, 38 | checkpoint_interval: int = 1000, 39 | activation_source: ActivationSource = "local_manifest", 40 | activation_path: Optional[str] = None, 41 | activation_dtype: ActivationDtype = "bfloat16", 42 | precision: Precision = "fp32", 43 | dead_feature_window: int = 1000000, # Set very high to disable dead neuron tracking 44 | ) -> TrainingConfig: 45 | """Creates a minimal TrainingConfig for fast tests.""" 46 | return TrainingConfig( 47 | training_steps=training_steps, 48 | train_batch_size_tokens=train_batch_size_tokens, 49 | learning_rate=learning_rate, 50 | sparsity_lambda=sparsity_lambda, 51 | sparsity_lambda_schedule=sparsity_lambda_schedule, 52 | sparsity_lambda_delay_frac=sparsity_lambda_delay_frac, 53 | preactivation_coef=preactivation_coef, 54 | eval_interval=eval_interval, 55 | checkpoint_interval=checkpoint_interval, 56 | activation_source=activation_source, 57 | activation_path=activation_path, 58 | activation_dtype=activation_dtype, 59 | precision=precision, 60 | dead_feature_window=dead_feature_window, 61 | # Keep other params at default 62 | ) 63 | -------------------------------------------------------------------------------- /tests/unit/data/README.md: 
-------------------------------------------------------------------------------- 1 | # Data Integrity Tests 2 | 3 | This directory contains comprehensive tests to ensure data integrity across the activation generation and retrieval pipeline. 4 | 5 | ## Background 6 | 7 | We've experienced data mix-up issues in the past: 8 | 1. **Lexicographic vs Numerical Ordering**: Layers were sorted as strings (layer_10, layer_2, layer_20) instead of numerically 9 | 2. **Layer Data Corruption**: Related layer-ordering bugs caused data from one layer to be attributed to another 10 | 11 | ## Test Coverage 12 | 13 | ### `test_data_integrity.py` 14 | 15 | Comprehensive test suite covering: 16 | 17 | 1. **Layer Ordering** (`test_layer_ordering_numerical_not_lexicographic`) 18 | - Verifies layers are ordered numerically (1, 2, 10, 20, 100), not lexicographically 19 | - Creates layers that would be misordered if sorted as strings 20 | - Validates both HDF5 structure and actual data values 21 | 22 | 2. **Normalization Application** (`test_normalization_application_correctness`) 23 | - Tests that normalization statistics are correctly applied during retrieval 24 | - Creates data with known mean/std, then verifies normalized output 25 | - Ensures each layer's statistics are applied to the correct layer 26 | 27 | 3. **Cross-Chunk Token Ordering** (`test_cross_chunk_token_ordering`) 28 | - Verifies token ordering is preserved across chunk boundaries 29 | - Uses deterministic patterns to track tokens across multiple chunks 30 | - Ensures no tokens are duplicated or lost 31 | 32 | 4.
**Manifest Format Compatibility** (`test_manifest_format_compatibility`) 33 | - Tests both legacy 2-field and new 3-field manifest formats 34 | - Ensures backward compatibility with existing datasets 35 | 36 | ### `test_local_activation_store.py` 37 | 38 | Includes additional test: 39 | - **Layer Data Integrity** (`test_layer_data_integrity`) 40 | - Verifies each layer contains distinct, non-mixed data 41 | - Checks value ranges are layer-specific 42 | - Ensures targets = inputs + 1 relationship is preserved 43 | 44 | ## Running the Tests 45 | 46 | Run all data integrity tests: 47 | ```bash 48 | pytest tests/unit/data/test_data_integrity.py -v 49 | ``` 50 | 51 | Run specific test: 52 | ```bash 53 | pytest tests/unit/data/test_data_integrity.py::TestDataIntegrity::test_layer_ordering_numerical_not_lexicographic -v 54 | ``` 55 | 56 | Run with coverage: 57 | ```bash 58 | pytest tests/unit/data/test_data_integrity.py --cov=clt.activation_generation --cov=clt.training.data -v 59 | ``` 60 | 61 | ## What These Tests Prevent 62 | 63 | 1. **Silent Data Corruption**: Detects if layers get mixed up during generation or retrieval 64 | 2. **Normalization Errors**: Ensures statistics from one layer aren't applied to another 65 | 3. **Token Loss**: Verifies all tokens are accessible and in correct order 66 | 4. **Format Regressions**: Maintains compatibility with existing activation datasets 67 | 68 | ## Adding New Tests 69 | 70 | When adding features that touch activation generation or retrieval: 71 | 1. Add tests that use deterministic, verifiable patterns 72 | 2. Test edge cases (empty chunks, single token, many layers) 73 | 3. Verify both structure (metadata, manifests) and actual data values 74 | 4. 
Consider cross-component interactions -------------------------------------------------------------------------------- /clt_server/tests/core/test_config.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from unittest.mock import MagicMock # Removed patch import 3 | 4 | 5 | # --- Helper to reload settings --- # 6 | def reload_settings(): 7 | """Forces reload of the settings module to pick up env var changes.""" 8 | import importlib 9 | from clt_server.core import config 10 | 11 | importlib.reload(config) 12 | return config.settings 13 | 14 | 15 | # --- Test Cases --- # 16 | 17 | 18 | def test_default_settings(monkeypatch): 19 | """Test that default settings are loaded correctly.""" 20 | # Ensure no env vars are set that might override defaults 21 | monkeypatch.delenv("STORAGE_BASE_DIR", raising=False) 22 | monkeypatch.delenv("CHUNK_RETRY_ATTEMPTS", raising=False) 23 | monkeypatch.delenv("LOG_LEVEL", raising=False) 24 | 25 | # Reload settings to ensure defaults are picked up 26 | settings = reload_settings() 27 | 28 | assert settings.STORAGE_BASE_DIR == Path("./server_data") 29 | assert settings.CHUNK_RETRY_ATTEMPTS == 3 30 | assert settings.LOG_LEVEL == "info" 31 | 32 | 33 | def test_override_settings_from_env(monkeypatch, tmp_path): 34 | """Test that settings can be overridden using environment variables.""" 35 | test_dir = tmp_path / "env_test_data" 36 | test_retries = 5 37 | test_log_level = "debug" 38 | 39 | monkeypatch.setenv("STORAGE_BASE_DIR", str(test_dir)) 40 | monkeypatch.setenv("CHUNK_RETRY_ATTEMPTS", str(test_retries)) 41 | monkeypatch.setenv("LOG_LEVEL", test_log_level) 42 | 43 | settings = reload_settings() 44 | 45 | assert settings.STORAGE_BASE_DIR == test_dir 46 | assert settings.CHUNK_RETRY_ATTEMPTS == test_retries 47 | assert settings.LOG_LEVEL == test_log_level 48 | 49 | # The directory should be created if overridden and doesn't exist 50 | assert test_dir.exists() 51 | assert 
test_dir.is_dir() 52 | 53 | 54 | def test_storage_directory_creation(monkeypatch, tmp_path): 55 | """Test that the storage directory is created if it doesn't exist.""" 56 | test_dir = tmp_path / "creation_test" 57 | assert not test_dir.exists() 58 | 59 | monkeypatch.setenv("STORAGE_BASE_DIR", str(test_dir)) 60 | reload_settings() # This triggers the Settings instantiation and directory check 61 | 62 | assert test_dir.exists() 63 | assert test_dir.is_dir() 64 | 65 | 66 | def test_storage_directory_exists(monkeypatch, tmp_path): 67 | """Test that mkdir is not called if the directory already exists.""" 68 | test_dir = tmp_path / "existing_dir" 69 | # Create the directory BEFORE the settings are potentially reloaded 70 | test_dir.mkdir(parents=True, exist_ok=True) 71 | assert test_dir.exists() # Verify directory creation 72 | 73 | # Set the env var 74 | monkeypatch.setenv("STORAGE_BASE_DIR", str(test_dir)) 75 | 76 | # Explicitly patch Path.mkdir only during the reload_settings call 77 | mock_mkdir = MagicMock() 78 | monkeypatch.setattr(Path, "mkdir", mock_mkdir) 79 | 80 | # Reload settings - this should NOT call the mocked mkdir 81 | try: 82 | reload_settings() 83 | finally: 84 | # Important: undo the patch even if reload_settings fails 85 | monkeypatch.undo() 86 | 87 | # Assert that the mocked mkdir was NOT called 88 | mock_mkdir.assert_not_called() 89 | -------------------------------------------------------------------------------- /tests/integration/test_single_gpu_training_step.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from pathlib import Path 4 | 5 | from clt.config import CLTConfig, TrainingConfig 6 | from clt.training.trainer import CLTTrainer 7 | from tests.helpers.tiny_configs import create_tiny_clt_config, create_tiny_training_config 8 | 9 | 10 | @pytest.fixture 11 | def tiny_clt_config() -> CLTConfig: 12 | return create_tiny_clt_config(num_layers=2, d_model=8, num_features=16) 
13 | 14 | 15 | @pytest.fixture 16 | def tiny_training_config(tmp_local_dataset: Path) -> TrainingConfig: 17 | # Set training steps to a small number for a quick test 18 | return create_tiny_training_config( 19 | training_steps=5, 20 | train_batch_size_tokens=16, 21 | eval_interval=10_000, # Disable eval for this test 22 | checkpoint_interval=10_000, # Disable checkpointing for this test 23 | activation_source="local_manifest", # Specify the source 24 | activation_path=str(tmp_local_dataset), # Use the fixture 25 | activation_dtype="float32", # Use float32 for CPU test consistency 26 | precision="fp32", # Use fp32 for CPU test consistency 27 | ) 28 | 29 | 30 | class TestSingleDeviceTraining: 31 | def test_training_loop_runs(self, tiny_clt_config: CLTConfig, tiny_training_config: TrainingConfig, tmp_path: Path): 32 | """ 33 | Test that the basic training loop can run for a few steps on a single device 34 | without crashing. This is a basic integration test. 35 | """ 36 | log_dir = tmp_path / "test_logs" 37 | 38 | trainer = CLTTrainer( 39 | clt_config=tiny_clt_config, 40 | training_config=tiny_training_config, 41 | log_dir=str(log_dir), 42 | device="cpu", 43 | distributed=False, 44 | ) 45 | 46 | # Check that model parameters have no gradients initially 47 | for p in trainer.model.parameters(): 48 | if p.requires_grad: 49 | assert p.grad is None 50 | 51 | # Run the training for the configured number of steps (5) 52 | trained_model = trainer.train() 53 | 54 | # --- Assertions --- 55 | # 1. The model returned should be the trainer's model 56 | assert trained_model is trainer.model 57 | 58 | # 2. Model parameters should have gradients after training 59 | grads_found = False 60 | for p in trainer.model.parameters(): 61 | if p.requires_grad and p.grad is not None: 62 | # Check that gradients are not all zero 63 | assert torch.any(p.grad != 0) 64 | grads_found = True 65 | assert grads_found, "No gradients were found on model parameters after training." 66 | 67 | # 3. 
Check that log files were created 68 | # The trainer logs metrics and checkpoints 69 | assert log_dir.exists() 70 | # Check for final checkpoint directory created by the trainer 71 | final_checkpoint_dir = log_dir / "final" 72 | assert final_checkpoint_dir.exists() 73 | # Check that the log directory was created, metrics file is optional 74 | assert log_dir.exists() 75 | 76 | # 4. Checkpoint saving was disabled, so the step_4 dir should NOT exist 77 | step_4_checkpoint_dir = log_dir / "step_4" 78 | assert not step_4_checkpoint_dir.exists(), "Checkpoint at step 4 was created but should have been disabled." 79 | -------------------------------------------------------------------------------- /clt/models/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import torch 3 | import torch.nn as nn 4 | from typing import Any, Dict, Optional, Tuple, TypeVar, Type 5 | 6 | from clt.config import CLTConfig 7 | 8 | 9 | # Generic type for Transcoder subclasses 10 | T = TypeVar("T", bound="BaseTranscoder") 11 | 12 | 13 | class BaseTranscoder(nn.Module, ABC): 14 | """Abstract base class for all transcoders.""" 15 | 16 | config: CLTConfig 17 | 18 | def __init__(self, config: CLTConfig): 19 | """Initialize the transcoder with the given configuration. 20 | 21 | Args: 22 | config: Configuration for the transcoder. 23 | """ 24 | super().__init__() 25 | self.config = config 26 | 27 | @abstractmethod 28 | def encode(self, x: torch.Tensor, layer_idx: int) -> torch.Tensor: 29 | """Encode the input activations at the specified layer. 30 | 31 | Args: 32 | x: Input activations [batch_size, seq_len, d_model] 33 | layer_idx: Index of the layer 34 | 35 | Returns: 36 | Encoded activations 37 | """ 38 | pass 39 | 40 | @abstractmethod 41 | def decode(self, a: Dict[int, torch.Tensor], layer_idx: int) -> torch.Tensor: 42 | """Decode the feature activations to reconstruct outputs at the specified layer. 
43 | 44 | Args: 45 | a: Dictionary mapping layer indices to feature activations 46 | layer_idx: Index of the layer to reconstruct outputs for 47 | 48 | Returns: 49 | Reconstructed outputs 50 | """ 51 | pass 52 | 53 | @abstractmethod 54 | def forward(self, inputs: Dict[int, torch.Tensor]) -> Dict[int, torch.Tensor]: 55 | """Process inputs through the transcoder model. 56 | 57 | Args: 58 | inputs: Dictionary mapping layer indices to input activations 59 | 60 | Returns: 61 | Dictionary mapping layer indices to reconstructed outputs 62 | """ 63 | pass 64 | 65 | def save(self, path: str) -> None: 66 | """Save the transcoder model to the specified path. 67 | 68 | Args: 69 | path: Path to save the model 70 | """ 71 | # Ensure config is serializable (e.g., if it's a dataclass) 72 | # If config is a dataclass, convert to dict 73 | config_dict = ( 74 | self.config.__dict__ if hasattr(self.config, "__dict__") else self.config 75 | ) 76 | 77 | checkpoint = {"config": config_dict, "state_dict": self.state_dict()} 78 | torch.save(checkpoint, path) 79 | 80 | @classmethod 81 | def load(cls: Type[T], path: str, device: Optional[torch.device] = None) -> T: 82 | """Load a transcoder model from the specified path. 
83 | 84 | Args: 85 | path: Path to load the model from 86 | device: Device to load the model to 87 | 88 | Returns: 89 | Loaded transcoder model 90 | """ 91 | checkpoint = torch.load(path, map_location=device) 92 | 93 | # Instantiate the config object from the dictionary 94 | config_dict = checkpoint["config"] 95 | # Assuming the config class is CLTConfig for now 96 | # A more robust solution might store the config class name 97 | config = CLTConfig(**config_dict) 98 | 99 | model = cls(config) 100 | model.load_state_dict(checkpoint["state_dict"]) 101 | return model 102 | -------------------------------------------------------------------------------- /scripts/compare_norm_stats.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Compare two norm_stats.json files layer-by-layer. 3 | 4 | Usage: 5 | python scripts/compare_norm_stats.py path/to/a/norm_stats.json path/to/b/norm_stats.json 6 | """ 7 | 8 | from __future__ import annotations 9 | 10 | import argparse 11 | import json 12 | from pathlib import Path 13 | from typing import Dict, Any, Tuple 14 | import numpy as np 15 | 16 | 17 | def _load_norm(path: Path) -> Dict[int, Dict[str, Any]]: 18 | with open(path) as f: 19 | raw = json.load(f) 20 | # cast layer keys to int for convenient lookup 21 | norm: Dict[int, Dict[str, Any]] = {int(k): v for k, v in raw.items()} 22 | return norm 23 | 24 | 25 | def _diff_stats( 26 | a: Dict[int, Dict[str, Any]], b: Dict[int, Dict[str, Any]] 27 | ) -> Dict[int, Dict[str, Tuple[float, float]]]: 28 | """Return {layer: {"inputs_mean": (abs_diff, rel_diff%), ...}}""" 29 | out: Dict[int, Dict[str, Tuple[float, float]]] = {} 30 | for layer in sorted(set(a) | set(b)): 31 | layer_res: Dict[str, Tuple[float, float]] = {} 32 | for section in ("inputs", "targets"): 33 | for field in ("mean", "std"): 34 | key = f"{section}_{field}" 35 | if layer in a and layer in b and section in a[layer] and section in b[layer]: 36 | vec_a = 
np.asarray(a[layer][section][field], dtype=np.float64) 37 | vec_b = np.asarray(b[layer][section][field], dtype=np.float64) 38 | if vec_a.shape != vec_b.shape: 39 | layer_res[key] = (float("nan"), float("nan")) 40 | continue 41 | abs_diff = float(np.mean(np.abs(vec_a - vec_b))) 42 | denom = np.mean(np.abs(vec_a)) + 1e-12 43 | rel_diff = float((abs_diff / denom) * 100.0) 44 | layer_res[key] = (abs_diff, rel_diff) 45 | else: 46 | layer_res[key] = (float("nan"), float("nan")) 47 | out[layer] = layer_res 48 | return out 49 | 50 | 51 | def main(): 52 | parser = argparse.ArgumentParser(description="Compare two norm_stats.json files") 53 | parser.add_argument("file_a", type=Path) 54 | parser.add_argument("file_b", type=Path) 55 | parser.add_argument( 56 | "--top-n", type=int, default=5, help="Show detailed stats for top-N layers with biggest mean differences" 57 | ) 58 | args = parser.parse_args() 59 | 60 | norm_a = _load_norm(args.file_a) 61 | norm_b = _load_norm(args.file_b) 62 | 63 | diffs = _diff_stats(norm_a, norm_b) 64 | 65 | print(f"Compared {len(diffs)} layers\n") 66 | worst_layers = sorted(diffs.items(), key=lambda kv: np.nan_to_num(kv[1]["inputs_mean"][0], nan=0.0), reverse=True) 67 | 68 | print("Layer | inputs_mean | targets_mean | inputs_std | targets_std (abs diff / % rel diff)") 69 | print("------- | ------------- | -------------- | ------------ | ------------") 70 | for layer, stats in worst_layers[: args.top_n]: 71 | im = stats["inputs_mean"] 72 | tm = stats["targets_mean"] 73 | isd = stats["inputs_std"] 74 | tsd = stats["targets_std"] 75 | print( 76 | f"{layer:5d} | {im[0]:10.4g} / {im[1]:6.2f}% | {tm[0]:10.4g} / {tm[1]:6.2f}% | " 77 | f"{isd[0]:10.4g} / {isd[1]:6.2f}% | {tsd[0]:10.4g} / {tsd[1]:6.2f}%" 78 | ) 79 | 80 | print( 81 | "\nTip: large relative differences (>5-10 %) mean you should use the training norm_stats.json during evaluation." 
82 | ) 83 | 84 | 85 | if __name__ == "__main__": 86 | main() 87 | -------------------------------------------------------------------------------- /clt/training/metric_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import logging 4 | from typing import Dict, Optional, Any, Union # Added Union 5 | 6 | # Forward declarations for type hinting 7 | if False: # TYPE_CHECKING 8 | from clt.training.wandb_logger import WandBLogger, DummyWandBLogger 9 | from clt.config import TrainingConfig 10 | 11 | # from clt.training.losses import LossManager # Not directly needed if lambda is passed 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | class MetricLogger: 17 | def __init__( 18 | self, 19 | distributed: bool, 20 | rank: int, 21 | log_dir: str, 22 | wandb_logger: Union["WandBLogger", "DummyWandBLogger"], 23 | training_config: "TrainingConfig", 24 | world_size: int, 25 | ): 26 | self.distributed = distributed 27 | self.rank = rank 28 | self.log_dir = log_dir 29 | self.wandb_logger = wandb_logger 30 | self.training_config = training_config 31 | self.world_size = world_size 32 | # self.loss_manager = loss_manager # Not storing loss_manager 33 | 34 | self.metrics: Dict[str, list] = { 35 | "train_losses": [], 36 | "eval_metrics": [], # This will be populated by the trainer calling a method here 37 | } 38 | 39 | def log_training_step( 40 | self, 41 | step: int, 42 | loss_dict: Dict[str, float], 43 | current_lr: Optional[float], 44 | current_sparsity_lambda: Optional[float], # Pass lambda directly 45 | ): 46 | """Log training metrics for a step, including LR and lambda.""" 47 | # All ranks might update their local copy of train_losses for potential future needs, 48 | # but only rank 0 saves/logs to WandB. 
49 | self.metrics["train_losses"].append({"step": step, **loss_dict}) 50 | 51 | if not self.distributed or self.rank == 0: 52 | total_tokens_processed = self.training_config.train_batch_size_tokens * (step + 1) 53 | 54 | self.wandb_logger.log_step( 55 | step, 56 | loss_dict, 57 | lr=current_lr, 58 | sparsity_lambda=current_sparsity_lambda, # Use passed lambda 59 | total_tokens_processed=total_tokens_processed, 60 | ) 61 | 62 | log_interval = self.training_config.log_interval 63 | if step % log_interval == 0: 64 | self._save_metrics_to_disk() # Renamed for clarity 65 | 66 | def log_evaluation_metrics(self, step: int, eval_metrics_dict: Dict[str, Any]): 67 | """Logs evaluation metrics. Assumes only called on rank 0 if distributed.""" 68 | if not self.distributed or self.rank == 0: 69 | self.metrics["eval_metrics"].append({"step": step, **eval_metrics_dict}) 70 | self.wandb_logger.log_evaluation(step, eval_metrics_dict) 71 | self._save_metrics_to_disk() # Save after eval as well 72 | 73 | def _save_metrics_to_disk(self): 74 | """Save all tracked metrics to disk. Assumes only called on rank 0 if distributed.""" 75 | if self.distributed and self.rank != 0: 76 | # This check is a safeguard, but typically this method should only be called by rank 0 logic. 
77 | return 78 | 79 | metrics_path = os.path.join(self.log_dir, "metrics.json") 80 | try: 81 | with open(metrics_path, "w") as f: 82 | json.dump(self.metrics, f, indent=2, default=str) 83 | except Exception as e: 84 | logger.warning(f"Rank {self.rank}: Failed to save metrics to {metrics_path}: {e}") # Use logger 85 | 86 | def get_metrics_history(self) -> Dict[str, list]: 87 | """Returns the history of all metrics.""" 88 | return self.metrics 89 | -------------------------------------------------------------------------------- /tests/integration/distributed_training_worker.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script is designed to be launched by torchrun for distributed training tests. 3 | It initializes a trainer, runs a few steps, and saves its final model state. 4 | The calling test can then inspect the saved states for consistency. 5 | """ 6 | 7 | import sys 8 | from pathlib import Path 9 | 10 | # Ensure the project root is in the python path 11 | project_root = Path(__file__).resolve().parents[2] 12 | if str(project_root) not in sys.path: 13 | sys.path.insert(0, str(project_root)) 14 | 15 | import torch 16 | import os 17 | import argparse 18 | import json 19 | import numpy as np 20 | 21 | from clt.training.trainer import CLTTrainer 22 | from tests.helpers.tiny_configs import create_tiny_clt_config, create_tiny_training_config 23 | from tests.helpers.fake_hdf5 import make_tiny_chunk_files 24 | 25 | 26 | def main(): 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument("--output-dir", type=str, required=True, help="Directory to save final model states.") 29 | args = parser.parse_args() 30 | 31 | # --- Get Distributed Info --- 32 | rank = int(os.environ.get("RANK", "0")) 33 | 34 | # --- Create a temporary dataset for this run (shared across ranks) --- 35 | dataset_path = Path(args.output_dir) / "test_dataset" 36 | 37 | # Parameters for tiny dataset 38 | num_chunks = 2 39 | num_layers = 2 
40 | d_model = 8 41 | n_tokens_per_chunk = 64 42 | 43 | # Let every rank attempt to create the dataset; implementation is idempotent 44 | make_tiny_chunk_files( 45 | path=dataset_path, 46 | num_chunks=num_chunks, 47 | n_layers=num_layers, 48 | n_tokens=n_tokens_per_chunk, 49 | d_model=d_model, 50 | ) 51 | 52 | # Rank 0 writes metadata and manifest (lightweight JSON/bin files) 53 | if rank == 0: 54 | metadata = { 55 | "num_layers": num_layers, 56 | "d_model": d_model, 57 | "total_tokens": n_tokens_per_chunk * num_chunks, 58 | "chunk_tokens": n_tokens_per_chunk, 59 | "dtype": "float16", 60 | } 61 | (dataset_path / "metadata.json").write_text(json.dumps(metadata)) 62 | 63 | # Create simple legacy 2-field manifest 64 | manifest_rows = [] 65 | for cid in range(num_chunks): 66 | for rid in range(n_tokens_per_chunk): 67 | manifest_rows.append([cid, rid]) 68 | manifest_arr = np.asarray(manifest_rows, dtype=np.uint32) 69 | manifest_arr.tofile(dataset_path / "index.bin") 70 | 71 | # --- Configuration --- 72 | clt_config = create_tiny_clt_config(num_layers=2, d_model=8, num_features=16) 73 | 74 | # When distributed=True, the trainer correctly sets shard_data=False for the activation store, 75 | # which is required for tensor parallelism. 76 | training_config = create_tiny_training_config( 77 | training_steps=5, 78 | train_batch_size_tokens=16, 79 | activation_source="local_manifest", 80 | activation_path=str(dataset_path), 81 | activation_dtype="float32", 82 | precision="fp32", 83 | ) 84 | 85 | # --- Initialize and run trainer --- 86 | # The trainer will automatically initialize the process group based on env variables. 
87 | trainer = CLTTrainer( 88 | clt_config=clt_config, 89 | training_config=training_config, 90 | log_dir=str(Path(args.output_dir) / f"rank_{rank}_logs"), 91 | distributed=True, 92 | ) 93 | 94 | trainer.train() 95 | 96 | # --- Save final model state for verification --- 97 | output_path = Path(args.output_dir) / f"rank_{rank}_final_model.pt" 98 | torch.save(trainer.model.state_dict(), output_path) 99 | 100 | print(f"Rank {rank} finished and saved model to {output_path}") 101 | 102 | 103 | if __name__ == "__main__": 104 | main() 105 | -------------------------------------------------------------------------------- /tests/integration/test_clt_end_to_end.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from clt.config import CLTConfig 5 | from clt.models.clt import CrossLayerTranscoder 6 | 7 | 8 | def get_available_devices(): 9 | """Returns available devices, including cpu, mps, and cuda if available.""" 10 | devices = ["cpu"] 11 | if torch.cuda.is_available(): 12 | devices.append("cuda") 13 | if torch.backends.mps.is_available(): 14 | devices.append("mps") 15 | return devices 16 | 17 | 18 | DEVICES = get_available_devices() 19 | 20 | 21 | @pytest.fixture(params=DEVICES) 22 | def device(request): 23 | """Fixture to iterate over all available devices.""" 24 | return torch.device(request.param) 25 | 26 | 27 | @pytest.fixture 28 | def clt_config(): 29 | """Provides a basic CLTConfig for end-to-end testing.""" 30 | return CLTConfig( 31 | num_layers=2, 32 | d_model=8, 33 | num_features=16, 34 | activation_fn="relu", # Use simple ReLU for gradient checking 35 | ) 36 | 37 | 38 | @pytest.fixture 39 | def clt_model(clt_config, device): 40 | """Provides a CrossLayerTranscoder instance for integration tests.""" 41 | model = CrossLayerTranscoder( 42 | config=clt_config, 43 | process_group=None, 44 | device=device, 45 | ) 46 | # Ensure all parameters have requires_grad=True for the backward pass test 
47 | for param in model.parameters(): 48 | param.requires_grad = True 49 | return model.to(device) 50 | 51 | 52 | @pytest.fixture 53 | def sample_inputs(clt_config, device): 54 | """Provides a sample input dictionary with consistent token counts.""" 55 | total_tokens = 20 56 | return { 57 | 0: torch.randn(total_tokens, clt_config.d_model, device=device), 58 | 1: torch.randn(total_tokens, clt_config.d_model, device=device), 59 | } 60 | 61 | 62 | class TestCLTEndToEnd: 63 | def test_forward_backward_pass(self, clt_model, sample_inputs): 64 | """ 65 | Tests a full forward and backward pass to ensure gradients are computed. 66 | """ 67 | # --- Forward Pass --- 68 | reconstructions = clt_model.forward(sample_inputs) 69 | 70 | # --- Loss Calculation --- 71 | # A simple MSE loss between the reconstructions and the original inputs 72 | loss = torch.tensor(0.0, device=clt_model.device, dtype=torch.float32) 73 | for layer_idx, recon_tensor in reconstructions.items(): 74 | original_tensor = sample_inputs[layer_idx] 75 | loss += torch.mean((recon_tensor - original_tensor) ** 2) 76 | 77 | # --- Backward Pass --- 78 | try: 79 | loss.backward() 80 | except Exception as e: 81 | pytest.fail(f"Backward pass failed with exception: {e}") 82 | 83 | # --- Gradient Check --- 84 | # Check that some gradients have been computed. We check a few key parameters. 
85 | # Encoder weights for layer 0 86 | assert clt_model.encoder_module.encoders[0].weight.grad is not None 87 | assert torch.all(torch.isfinite(clt_model.encoder_module.encoders[0].weight.grad)) 88 | assert not torch.all(clt_model.encoder_module.encoders[0].weight.grad == 0) 89 | 90 | # Decoder weights for 0->1 91 | decoder_key = "0->1" 92 | assert clt_model.decoder_module.decoders[decoder_key].weight.grad is not None 93 | assert torch.all(torch.isfinite(clt_model.decoder_module.decoders[decoder_key].weight.grad)) 94 | assert not torch.all(clt_model.decoder_module.decoders[decoder_key].weight.grad == 0) 95 | 96 | # Decoder bias for 1->1 97 | decoder_key = "1->1" 98 | if clt_model.decoder_module.decoders[decoder_key].bias_param is not None: 99 | assert clt_model.decoder_module.decoders[decoder_key].bias_param.grad is not None 100 | assert torch.all(torch.isfinite(clt_model.decoder_module.decoders[decoder_key].bias_param.grad)) 101 | # Note: Bias gradients can sometimes be zero in simple cases, so we don't assert non-zero 102 | -------------------------------------------------------------------------------- /tests/integration/test_distributed_training.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import subprocess 4 | from pathlib import Path 5 | 6 | 7 | class TestDistributedTraining: 8 | def test_tensor_parallel_training_runs(self, tmp_path: Path): 9 | """ 10 | Tests that training with Tensor Parallelism (distributed) runs successfully 11 | and produces model checkpoints from each rank. 12 | 13 | Note: With tensor parallelism, different ranks hold different shards of 14 | the model weights, so we don't expect identical parameters. 
15 | """ 16 | world_size = 2 17 | 18 | # --- Run the distributed worker script using torchrun --- 19 | cmd = [ 20 | "torchrun", 21 | "--nproc_per_node", 22 | str(world_size), 23 | "-m", 24 | "tests.integration.distributed_training_worker", 25 | "--output-dir", 26 | str(tmp_path), 27 | ] 28 | 29 | try: 30 | process = subprocess.run(cmd, check=True, capture_output=True, text=True, timeout=120) 31 | print("STDOUT:", process.stdout) 32 | print("STDERR:", process.stderr) 33 | except subprocess.CalledProcessError as e: 34 | print("STDOUT:", e.stdout) 35 | print("STDERR:", e.stderr) 36 | pytest.fail(f"Distributed worker script failed with exit code {e.returncode}.") 37 | except subprocess.TimeoutExpired as e: 38 | print("STDOUT:", e.stdout) 39 | print("STDERR:", e.stderr) 40 | pytest.fail("Distributed worker script timed out.") 41 | 42 | # --- Verification --- 43 | # Check that each rank saved its model shard 44 | rank0_state_path = tmp_path / "rank_0_final_model.pt" 45 | rank1_state_path = tmp_path / "rank_1_final_model.pt" 46 | 47 | assert rank0_state_path.exists(), "Rank 0 did not save a model file." 48 | assert rank1_state_path.exists(), "Rank 1 did not save a model file." 
49 | 50 | # Load the states to verify they're valid PyTorch checkpoints 51 | rank0_state = torch.load(rank0_state_path, map_location="cpu") 52 | rank1_state = torch.load(rank1_state_path, map_location="cpu") 53 | 54 | # Basic sanity checks 55 | assert len(rank0_state) > 0, "Rank 0 state dict is empty" 56 | assert len(rank1_state) > 0, "Rank 1 state dict is empty" 57 | 58 | # Check that both ranks have the same keys (structure) 59 | assert set(rank0_state.keys()) == set( 60 | rank1_state.keys() 61 | ), "Model structure differs between ranks (different keys in state dict)" 62 | 63 | # Verify that key tensor parallel layers have expected shapes 64 | # With world_size=2, each rank should have half the features 65 | for key in rank0_state: 66 | if "encoder" in key and "weight" in key: 67 | # ColumnParallelLinear: out_features is sharded 68 | rank0_shape = rank0_state[key].shape 69 | rank1_shape = rank1_state[key].shape 70 | assert rank0_shape == rank1_shape, f"Shape mismatch for {key}: rank0={rank0_shape}, rank1={rank1_shape}" 71 | print(f"✓ {key}: shape {rank0_shape} (sharded across {world_size} ranks)") 72 | elif "decoder" in key and "weight" in key: 73 | # RowParallelLinear: in_features is sharded 74 | rank0_shape = rank0_state[key].shape 75 | rank1_shape = rank1_state[key].shape 76 | assert rank0_shape == rank1_shape, f"Shape mismatch for {key}: rank0={rank0_shape}, rank1={rank1_shape}" 77 | print(f"✓ {key}: shape {rank0_shape} (sharded across {world_size} ranks)") 78 | 79 | # Check that training logs were created 80 | rank0_log_dir = tmp_path / "rank_0_logs" 81 | rank1_log_dir = tmp_path / "rank_1_logs" 82 | assert rank0_log_dir.exists(), "Rank 0 log directory not created" 83 | assert rank1_log_dir.exists(), "Rank 1 log directory not created" 84 | 85 | print("\nDistributed training test passed!") 86 | print("- Both ranks completed training successfully") 87 | print(f"- Model checkpoints saved with {len(rank0_state)} parameters each") 88 | print(f"- Log directories 
created at {rank0_log_dir} and {rank1_log_dir}") 89 | -------------------------------------------------------------------------------- /tests/unit/models/test_clt_encode_decode.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from clt.config import CLTConfig 5 | from clt.models.clt import CrossLayerTranscoder 6 | 7 | 8 | def get_available_devices(): 9 | """Returns available devices, including cpu, mps, and cuda if available.""" 10 | devices = ["cpu"] 11 | if torch.cuda.is_available(): 12 | devices.append("cuda") 13 | if torch.backends.mps.is_available(): 14 | devices.append("mps") 15 | return devices 16 | 17 | 18 | DEVICES = get_available_devices() 19 | 20 | 21 | @pytest.fixture(params=DEVICES) 22 | def device(request): 23 | """Fixture to iterate over all available devices.""" 24 | return torch.device(request.param) 25 | 26 | 27 | @pytest.fixture(params=["relu", "jumprelu", "batchtopk"]) 28 | def activation_fn(request): 29 | return request.param 30 | 31 | 32 | @pytest.fixture 33 | def clt_config(activation_fn): 34 | """Provides a basic CLTConfig for testing, parameterized by activation function.""" 35 | return CLTConfig( 36 | num_layers=2, 37 | d_model=8, 38 | num_features=16, 39 | activation_fn=activation_fn, 40 | jumprelu_threshold=0.5, 41 | batchtopk_k=4, 42 | ) 43 | 44 | 45 | @pytest.fixture 46 | def clt_model(clt_config, device): 47 | """Provides a CrossLayerTranscoder instance.""" 48 | return CrossLayerTranscoder( 49 | config=clt_config, 50 | process_group=None, 51 | device=device, 52 | ).to(device) 53 | 54 | 55 | @pytest.fixture 56 | def sample_inputs(clt_config, device): 57 | """ 58 | Provides a sample input dictionary for the CLT. 59 | Ensures that the total number of tokens is the same for all layers, 60 | which is a key assumption for batchtopk. 
61 | """ 62 | total_tokens = 30 63 | return { 64 | 0: torch.randn(total_tokens, clt_config.d_model, device=device), 65 | 1: torch.randn(total_tokens, clt_config.d_model, device=device), 66 | } 67 | 68 | 69 | class TestCLTEncodeDecode: 70 | def test_get_feature_activations_shapes(self, clt_model, sample_inputs, clt_config): 71 | """Test that get_feature_activations returns activations of the correct shape.""" 72 | activations = clt_model.get_feature_activations(sample_inputs) 73 | 74 | assert isinstance(activations, dict) 75 | assert sorted(activations.keys()) == sorted(sample_inputs.keys()) 76 | 77 | # Check shapes (note that 3D input is flattened in the fixture now) 78 | assert activations[0].shape == (30, clt_config.num_features) 79 | assert activations[1].shape == (30, clt_config.num_features) 80 | 81 | def test_relu_activations_are_non_negative(self, clt_model, sample_inputs): 82 | """Test that ReLU activations are always >= 0.""" 83 | if clt_model.config.activation_fn != "relu": 84 | pytest.skip("Test only for ReLU activation") 85 | 86 | activations = clt_model.get_feature_activations(sample_inputs) 87 | for layer_idx in activations: 88 | assert torch.all(activations[layer_idx] >= 0) 89 | 90 | def test_decode_shapes(self, clt_model, sample_inputs, clt_config): 91 | """Test that decoding from feature activations produces the correct output shape.""" 92 | activations = clt_model.get_feature_activations(sample_inputs) 93 | 94 | # Decode for layer 1, which can see activations from layers 0 and 1 95 | reconstruction = clt_model.decode(activations, layer_idx=1) 96 | 97 | # The output batch dimension should match the input batch dimension for that layer 98 | expected_batch_dim = sample_inputs[1].shape[0] 99 | assert reconstruction.shape == (expected_batch_dim, clt_config.d_model) 100 | 101 | def test_forward_pass_shapes(self, clt_model, sample_inputs, clt_config): 102 | """Test the full forward() method returns a dictionary of reconstructions with correct shapes.""" 
103 | reconstructions = clt_model.forward(sample_inputs) 104 | 105 | assert isinstance(reconstructions, dict) 106 | assert sorted(reconstructions.keys()) == sorted(sample_inputs.keys()) 107 | 108 | # Check shapes 109 | assert reconstructions[0].shape == ( 110 | sample_inputs[0].shape[0], 111 | clt_config.d_model, 112 | ) 113 | assert reconstructions[1].shape == ( 114 | sample_inputs[1].shape[0], 115 | clt_config.d_model, 116 | ) 117 | -------------------------------------------------------------------------------- /clt/parallel/ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from torch.distributed import ProcessGroup, ReduceOp, Work 4 | from typing import Optional, List 5 | 6 | # Re-export common ReduceOp for convenience. 7 | # Users can import these from this module (e.g., from clt.parallel.ops import SUM) 8 | SUM = ReduceOp.SUM 9 | AVG = ReduceOp.AVG 10 | PRODUCT = ReduceOp.PRODUCT 11 | MIN = ReduceOp.MIN 12 | MAX = ReduceOp.MAX 13 | BAND = ReduceOp.BAND 14 | BOR = ReduceOp.BOR 15 | BXOR = ReduceOp.BXOR 16 | 17 | 18 | def is_dist_initialized_and_available() -> bool: 19 | """Checks if torch.distributed is available and initialized.""" 20 | return dist.is_available() and dist.is_initialized() 21 | 22 | 23 | def get_rank(group: Optional[ProcessGroup] = None) -> int: 24 | """Returns the rank of the current process in the group. 25 | Returns 0 if distributed is not initialized or not available. 26 | """ 27 | if not is_dist_initialized_and_available(): 28 | return 0 29 | return dist.get_rank(group=group) 30 | 31 | 32 | def get_world_size(group: Optional[ProcessGroup] = None) -> int: 33 | """Returns the world size of the given process group. 34 | Returns 1 if distributed is not initialized or not available. 
35 | """ 36 | if not is_dist_initialized_and_available(): 37 | return 1 38 | return dist.get_world_size(group=group) 39 | 40 | 41 | def is_main_process(group: Optional[ProcessGroup] = None) -> bool: 42 | """Checks if the current process is the main process (rank 0).""" 43 | return get_rank(group=group) == 0 44 | 45 | 46 | def all_reduce( 47 | tensor: torch.Tensor, 48 | op: ReduceOp = SUM, # Default to SUM 49 | group: Optional[ProcessGroup] = None, 50 | async_op: bool = False, 51 | ) -> Optional[Work]: 52 | """Reduces the tensor data across all machines. 53 | 54 | Args: 55 | tensor: Input and output of the collective. The function operates in-place. 56 | op: The reduction operation (e.g., ReduceOp.SUM, ReduceOp.PRODUCT). 57 | group: The process group to work on. If None, the default process group will be used. 58 | async_op: Whether this op should be an async op. 59 | 60 | Returns: 61 | A Work object if async_op is True, otherwise None. 62 | Returns None if distributed is not initialized or world_size is 1, as no actual communication occurs. 63 | """ 64 | if not is_dist_initialized_and_available() or get_world_size(group=group) == 1: 65 | return None 66 | return dist.all_reduce(tensor, op=op, group=group, async_op=async_op) 67 | 68 | 69 | def broadcast( 70 | tensor: torch.Tensor, 71 | src: int, 72 | group: Optional[ProcessGroup] = None, 73 | async_op: bool = False, 74 | ) -> Optional[Work]: 75 | """Broadcasts the tensor to the whole group. 76 | 77 | Args: 78 | tensor: Data to be sent if src is the rank of current process, 79 | or tensor to be used to save received data otherwise. 80 | src: Source rank. 81 | group: The process group to work on. If None, the default process group will be used. 82 | async_op: Whether this op should be an async op. 83 | 84 | Returns: 85 | A Work object if async_op is True, otherwise None. 86 | Returns None if distributed is not initialized or world_size is 1, as no actual communication occurs. 
87 | """ 88 | if not is_dist_initialized_and_available() or get_world_size(group=group) == 1: 89 | return None 90 | return dist.broadcast(tensor, src=src, group=group, async_op=async_op) 91 | 92 | 93 | def all_gather( 94 | tensor_list: List[torch.Tensor], 95 | tensor: torch.Tensor, 96 | group: Optional[ProcessGroup] = None, 97 | async_op: bool = False, 98 | ) -> Optional[Work]: 99 | """Gathers tensors from the whole group in a list. 100 | 101 | Args: 102 | tensor_list: Output list. It should contain correctly-sized tensors to be used for output of the collective. 103 | tensor: Tensor to be broadcast from current process. 104 | group: The process group to work on. If None, the default process group will be used. 105 | async_op: Whether this op should be an async op. 106 | 107 | Returns: 108 | A Work object if async_op is True, otherwise None. 109 | If distributed is not initialized, it places the input tensor into tensor_list[0] (assuming single process context). 110 | """ 111 | if not is_dist_initialized_and_available(): 112 | rank = get_rank(group) 113 | if rank < len(tensor_list): 114 | tensor_list[rank] = tensor # pyright: ignore[reportGeneralTypeIssues] 115 | return None 116 | 117 | return dist.all_gather(tensor_list, tensor, group=group, async_op=async_op) 118 | -------------------------------------------------------------------------------- /tests/unit/training/test_checkpointing.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from unittest.mock import MagicMock 4 | from unittest.mock import patch 5 | 6 | from clt.training.checkpointing import CheckpointManager 7 | from clt.models.clt import CrossLayerTranscoder 8 | from clt.config import CLTConfig 9 | 10 | # --- Fixtures --- 11 | 12 | 13 | @pytest.fixture 14 | def tiny_model(device) -> CrossLayerTranscoder: 15 | """Provides a tiny, non-distributed model for testing.""" 16 | config = CLTConfig(num_layers=1, d_model=4, num_features=8) 
17 | return CrossLayerTranscoder(config, process_group=None, device=device) 18 | 19 | 20 | @pytest.fixture 21 | def mock_activation_store(): 22 | """Mocks the BaseActivationStore.""" 23 | store = MagicMock() 24 | store.state_dict.return_value = {"sampler_state": "dummy"} 25 | return store 26 | 27 | 28 | @pytest.fixture 29 | def mock_wandb_logger(): 30 | """Mocks the WandB logger.""" 31 | logger = MagicMock() 32 | logger.get_current_wandb_run_id.return_value = "test_run_id_123" 33 | logger.log_artifact = MagicMock() 34 | return logger 35 | 36 | 37 | @pytest.fixture 38 | def checkpoint_manager_components(tmp_path, tiny_model, mock_activation_store, mock_wandb_logger, device): 39 | """A dictionary of components needed to instantiate CheckpointManager.""" 40 | return { 41 | "model": tiny_model, 42 | "activation_store": mock_activation_store, 43 | "wandb_logger": mock_wandb_logger, 44 | "log_dir": str(tmp_path), 45 | "distributed": False, 46 | "rank": 0, 47 | "device": device, 48 | "world_size": 1, 49 | } 50 | 51 | 52 | class TestCheckpointManagerNonDistributed: 53 | def test_save_checkpoint_non_distributed(self, checkpoint_manager_components, tmp_path): 54 | """ 55 | Verifies that _save_checkpoint (non-distributed) creates the correct files. 56 | """ 57 | manager = CheckpointManager(**checkpoint_manager_components) 58 | 59 | step = 100 60 | trainer_state = { 61 | "step": step, 62 | "optimizer_state_dict": {"param_groups": []}, 63 | "scheduler_state_dict": None, 64 | "scaler_state_dict": None, 65 | "n_forward_passes_since_fired": torch.zeros(1, 8), 66 | "wandb_run_id": "test_run_id_123", 67 | "torch_rng_state": torch.get_rng_state(), 68 | "numpy_rng_state": None, 69 | "python_rng_state": None, 70 | } 71 | 72 | # --- Call the save method --- 73 | manager._save_checkpoint(step, trainer_state) 74 | 75 | # --- Assertions --- 76 | # 1. 
Check for the step-specific files 77 | model_path = tmp_path / f"clt_checkpoint_{step}.safetensors" 78 | trainer_state_path = tmp_path / f"trainer_state_{step}.pt" 79 | store_path = tmp_path / f"activation_store_{step}.pt" 80 | 81 | assert model_path.exists(), "Model checkpoint file was not created." 82 | assert trainer_state_path.exists(), "Trainer state checkpoint file was not created." 83 | assert store_path.exists(), "Activation store checkpoint file was not created." 84 | 85 | # 2. Check for the 'latest' symlinks or copies 86 | latest_model_path = tmp_path / "clt_checkpoint_latest.safetensors" 87 | latest_trainer_state_path = tmp_path / "trainer_state_latest.pt" 88 | latest_store_path = tmp_path / "activation_store_latest.pt" 89 | 90 | assert latest_model_path.exists(), "Latest model checkpoint was not created." 91 | assert latest_trainer_state_path.exists(), "Latest trainer state was not created." 92 | assert latest_store_path.exists(), "Latest activation store was not created." 93 | 94 | # 3. Verify content of trainer state can be loaded 95 | loaded_trainer_state = torch.load(trainer_state_path) 96 | assert loaded_trainer_state["step"] == step 97 | assert loaded_trainer_state["wandb_run_id"] == "test_run_id_123" 98 | 99 | def test_save_checkpoint_handles_io_error(self, checkpoint_manager_components, caplog): 100 | """ 101 | Verifies that _save_checkpoint logs a warning and continues if an IO error occurs. 
102 | """ 103 | manager = CheckpointManager(**checkpoint_manager_components) 104 | 105 | # Patch torch.save to raise an IOError 106 | with patch("torch.save", side_effect=IOError("Disk full")): 107 | with caplog.at_level("WARNING"): 108 | manager._save_checkpoint(step=100, trainer_state_to_save={}) 109 | 110 | # Check that a warning was logged 111 | assert "Warning: Failed to save non-distributed checkpoint" in caplog.text 112 | assert "Disk full" in caplog.text 113 | -------------------------------------------------------------------------------- /tests/unit/training/data/test_chunk_row_sampler.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | 4 | from clt.training.data.manifest_activation_store import ChunkRowSampler 5 | 6 | 7 | @pytest.fixture 8 | def sampler_params(): 9 | """Provides default parameters for creating a ChunkRowSampler.""" 10 | return { 11 | "chunk_sizes": {0: 30, 1: 25, 2: 40}, 12 | "num_chunks": 3, 13 | "batch": 10, 14 | "seed": 42, 15 | "epoch": 0, 16 | } 17 | 18 | 19 | class TestChunkRowSampler: 20 | def test_initialization(self, sampler_params): 21 | """Test that the sampler initializes correctly.""" 22 | sampler = ChunkRowSampler(rank=0, world=1, **sampler_params) 23 | assert sampler.batch == 10 24 | assert sampler.rank == 0 25 | assert sampler.world == 1 26 | assert sampler.epoch == 0 27 | assert sampler.seed == 42 28 | assert len(sampler) > 0 29 | 30 | @pytest.mark.parametrize("strategy", ["sequential", "random_chunk"]) 31 | def test_iteration_completeness(self, sampler_params, strategy): 32 | """Test that the sampler yields all possible rows exactly once per epoch without sharding.""" 33 | sampler = ChunkRowSampler(rank=0, world=1, sampling_strategy=strategy, **sampler_params) 34 | 35 | total_rows = sum(sampler_params["chunk_sizes"].values()) 36 | total_batches = total_rows // sampler_params["batch"] 37 | assert len(sampler) == total_batches 38 | 39 | 
yielded_pairs = set() 40 | for batch_indices in sampler: 41 | assert batch_indices.shape == (sampler_params["batch"], 2) 42 | for chunk_id, row_id in batch_indices: 43 | pair = (chunk_id, row_id) 44 | assert pair not in yielded_pairs, "Sampler yielded a duplicate (chunk, row) pair" 45 | yielded_pairs.add(pair) 46 | 47 | assert len(yielded_pairs) == total_batches * sampler_params["batch"] 48 | 49 | def test_sharding_correctness(self, sampler_params): 50 | """Test that data is correctly sharded across ranks with no overlap.""" 51 | world_size = 4 52 | all_yielded_pairs = [] 53 | for rank in range(world_size): 54 | sampler = ChunkRowSampler(rank=rank, world=world_size, **sampler_params) 55 | for batch_indices in sampler: 56 | for chunk_id, row_id in batch_indices: 57 | # Check that the yielded row_id is valid for this rank 58 | assert row_id % world_size == rank 59 | all_yielded_pairs.append((chunk_id, row_id)) 60 | 61 | # Check for duplicates across all yielded pairs from all ranks 62 | assert len(all_yielded_pairs) == len(set(all_yielded_pairs)), "Duplicate pairs were yielded across ranks." 63 | 64 | def test_state_dict_roundtrip(self, sampler_params): 65 | """Test that saving and loading state allows for exact resumption.""" 66 | # 1. Create a sampler and iterate halfway through 67 | sampler1 = ChunkRowSampler(rank=0, world=1, **sampler_params) 68 | mid_point = len(sampler1) // 2 69 | 70 | first_half_batches = [] 71 | for i, batch in enumerate(sampler1): 72 | if i >= mid_point: 73 | break 74 | first_half_batches.append(batch) 75 | 76 | # 2. Save its state 77 | state = sampler1.state_dict() 78 | 79 | # 3. Create a new sampler and load the state 80 | sampler2 = ChunkRowSampler(rank=0, world=1, **sampler_params) 81 | sampler2.load_state_dict(state) 82 | 83 | # 4. 
Ensure the next batch from the new sampler is the same as the one after the midpoint 84 | next_batch_from_1 = next(sampler1) 85 | next_batch_from_2 = next(sampler2) 86 | 87 | np.testing.assert_array_equal(next_batch_from_1, next_batch_from_2) 88 | 89 | def test_epoch_determinism(self, sampler_params): 90 | """Test that different epochs produce different samples, and same epochs produce same samples.""" 91 | # Sampler for epoch 0 92 | sampler_e0_a = ChunkRowSampler(rank=0, world=1, **sampler_params) 93 | batches_e0_a = list(sampler_e0_a) 94 | 95 | # A second sampler for epoch 0 should be identical 96 | sampler_e0_b = ChunkRowSampler(rank=0, world=1, **sampler_params) 97 | batches_e0_b = list(sampler_e0_b) 98 | np.testing.assert_array_equal(np.vstack(batches_e0_a), np.vstack(batches_e0_b)) 99 | 100 | # A sampler for epoch 1 should be different 101 | params_e1 = sampler_params.copy() 102 | params_e1["epoch"] = 1 103 | sampler_e1 = ChunkRowSampler(rank=0, world=1, **params_e1) 104 | batches_e1 = list(sampler_e1) 105 | 106 | assert len(batches_e0_a) == len(batches_e1) 107 | assert not np.array_equal(np.vstack(batches_e0_a), np.vstack(batches_e1)) 108 | -------------------------------------------------------------------------------- /clt/training/data/activation_store_factory.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | 4 | from clt.config import TrainingConfig, CLTConfig 5 | from clt.training.data.base_store import BaseActivationStore 6 | from clt.training.data.local_activation_store import LocalActivationStore 7 | from clt.training.data.remote_activation_store import RemoteActivationStore 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | def create_activation_store( 13 | training_config: TrainingConfig, 14 | clt_config: CLTConfig, 15 | device: torch.device, 16 | rank: int, 17 | world_size: int, 18 | start_time: float, 19 | shard_data: bool = True, 20 | ) -> BaseActivationStore: 21 
| """Create the appropriate activation store based on training config. 22 | 23 | Valid activation_source values: 24 | - "local_manifest": Use LocalActivationStore with local manifest/chunks. 25 | - "remote": Use RemoteActivationStore with remote server. 26 | 27 | Args: 28 | training_config: The training configuration object. 29 | clt_config: The CLT model configuration (currently unused here after removing generate). 30 | device: The torch device to use. 31 | rank: The distributed rank. 32 | world_size: The distributed world size. 33 | start_time: The training start time for elapsed time logging (unused if generate is gone). 34 | shard_data: Whether to include shard data in the store. 35 | 36 | Returns: 37 | Configured instance of a BaseActivationStore subclass. 38 | """ 39 | activation_source = training_config.activation_source 40 | sampling_strategy = training_config.sampling_strategy 41 | 42 | store: BaseActivationStore 43 | 44 | if activation_source == "local_manifest": 45 | logger.info(f"Rank {rank}: Using LocalActivationStore (reading local manifest/chunks).") 46 | if not training_config.activation_path: 47 | raise ValueError( 48 | "activation_path must be set in TrainingConfig when activation_source is 'local_manifest'." 
49 | ) 50 | store = LocalActivationStore( 51 | dataset_path=training_config.activation_path, 52 | train_batch_size_tokens=training_config.train_batch_size_tokens, 53 | device=device, 54 | dtype=training_config.activation_dtype, 55 | rank=rank, 56 | world=world_size, 57 | seed=training_config.seed, 58 | sampling_strategy=sampling_strategy, 59 | normalization_method=training_config.normalization_method, 60 | shard_data=shard_data, 61 | ) 62 | if isinstance(store, LocalActivationStore): 63 | logger.info(f"Rank {rank}: Initialized LocalActivationStore from path: {store.dataset_path}") 64 | if store.apply_normalization: 65 | logger.info(f"Rank {rank}: Normalization ENABLED using loaded norm_stats.json.") 66 | else: 67 | logger.warning(f"Rank {rank}: Normalization DISABLED (processing failed or file incomplete/invalid).") 68 | elif activation_source == "remote": 69 | logger.info(f"Rank {rank}: Using RemoteActivationStore (remote slice server).") 70 | remote_cfg = training_config.remote_config 71 | if remote_cfg is None: 72 | raise ValueError("remote_config dict must be set in TrainingConfig when activation_source is 'remote'.") 73 | server_url = remote_cfg.get("server_url") 74 | dataset_id = remote_cfg.get("dataset_id") 75 | if not server_url or not dataset_id: 76 | raise ValueError("remote_config must contain 'server_url' and 'dataset_id'.") 77 | 78 | store = RemoteActivationStore( 79 | server_url=server_url, 80 | dataset_id=dataset_id, 81 | train_batch_size_tokens=training_config.train_batch_size_tokens, 82 | device=device, 83 | dtype=training_config.activation_dtype, 84 | rank=rank, 85 | world=world_size, 86 | seed=training_config.seed, 87 | timeout=remote_cfg.get("timeout", 60), 88 | sampling_strategy=sampling_strategy, 89 | normalization_method=training_config.normalization_method, 90 | shard_data=shard_data, 91 | ) 92 | if isinstance(store, RemoteActivationStore): 93 | logger.info(f"Rank {rank}: Initialized RemoteActivationStore for dataset: {store.did_raw}") 94 | 
if store.apply_normalization: 95 | logger.info(f"Rank {rank}: Normalization ENABLED using fetched norm_stats.json.") 96 | else: 97 | logger.warning( 98 | f"Rank {rank}: Normalization DISABLED (norm_stats.json not found on server or failed to load)." 99 | ) 100 | else: 101 | raise ValueError(f"Unknown activation_source: {activation_source}. Valid options: 'local_manifest', 'remote'.") 102 | return store 103 | -------------------------------------------------------------------------------- /tests/integration/test_clt_distributed.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import torch.distributed as dist 4 | import torch.multiprocessing as mp 5 | import os 6 | from typing import cast 7 | 8 | from clt.config import CLTConfig 9 | from clt.models.clt import CrossLayerTranscoder 10 | 11 | 12 | def setup_distributed_environment(rank, world_size, port="12356"): 13 | """Initializes the distributed process group.""" 14 | os.environ["MASTER_ADDR"] = "localhost" 15 | os.environ["MASTER_PORT"] = port 16 | dist.init_process_group("gloo", rank=rank, world_size=world_size) 17 | 18 | 19 | def cleanup_distributed_environment(): 20 | """Cleans up the distributed process group.""" 21 | dist.destroy_process_group() 22 | 23 | 24 | def distributed_test_runner(rank, world_size, test_fn, *args): 25 | """A wrapper to run a distributed test function.""" 26 | setup_distributed_environment(rank, world_size) 27 | try: 28 | test_fn(rank, world_size, *args) 29 | finally: 30 | cleanup_distributed_environment() 31 | 32 | 33 | # --- Test Functions (to be run in separate processes) --- 34 | 35 | 36 | def _test_forward_pass_distributed(rank, world_size): 37 | """ 38 | Tests that the forward pass produces identical results on all ranks. 
39 | """ 40 | device = torch.device("cpu") 41 | torch.manual_seed(42) # Ensure same model initialization 42 | 43 | clt_config = CLTConfig(num_layers=2, d_model=8, num_features=16, activation_fn="relu") 44 | model = CrossLayerTranscoder(config=clt_config, process_group=dist.group.WORLD, device=device) 45 | 46 | # All ranks get the same input 47 | torch.manual_seed(123) 48 | sample_inputs = { 49 | 0: torch.randn(20, clt_config.d_model, device=device), 50 | 1: torch.randn(20, clt_config.d_model, device=device), 51 | } 52 | 53 | reconstructions = model.forward(sample_inputs) 54 | loss = torch.mean(reconstructions[0]) # A simple, deterministic loss 55 | 56 | # Gather the loss from all ranks 57 | loss_list = [torch.zeros_like(loss) for _ in range(world_size)] 58 | dist.all_gather(loss_list, loss) 59 | 60 | # The loss, and therefore the forward pass result, should be identical on all ranks 61 | for other_loss in loss_list: 62 | assert torch.allclose(loss, other_loss), "Forward pass results (losses) differ across ranks" 63 | 64 | 65 | def _test_sharded_gradient(rank, world_size): 66 | """ 67 | Tests that sharded parameters receive different gradients on each rank. 
68 | """ 69 | device = torch.device("cpu") 70 | # Use rank-specific seed for weight initialization to ensure different weights 71 | torch.manual_seed(42 + rank) 72 | 73 | clt_config = CLTConfig(num_layers=2, d_model=8, num_features=16, activation_fn="relu") 74 | model = CrossLayerTranscoder(config=clt_config, process_group=dist.group.WORLD, device=device) 75 | 76 | # All ranks get the same input 77 | torch.manual_seed(123) 78 | sample_inputs = {0: torch.randn(5, clt_config.d_model, device=device)} 79 | 80 | # Forward pass 81 | reconstructions = model.forward(sample_inputs) 82 | 83 | # Create a loss that depends on the actual output values 84 | # This will produce different gradients for different weight values 85 | target = torch.randn_like(reconstructions[0]) 86 | loss = torch.nn.functional.mse_loss(reconstructions[0], target) 87 | 88 | # Backward pass 89 | loss.backward() 90 | 91 | # Test gradients of a SHARDED parameter (e.g., Encoder weights) 92 | sharded_grad_optional = model.encoder_module.encoders[0].weight.grad 93 | assert sharded_grad_optional is not None, "Gradient for sharded parameter should exist" 94 | sharded_grad = cast(torch.Tensor, sharded_grad_optional) 95 | 96 | # Gather all gradients to compare 97 | grad_list = [torch.zeros_like(sharded_grad) for _ in range(world_size)] 98 | dist.all_gather(grad_list, sharded_grad) 99 | 100 | # The gradients for a sharded parameter should be DIFFERENT on each rank 101 | # because each rank has different weights and computes different outputs 102 | assert not torch.allclose( 103 | grad_list[0], grad_list[1], rtol=1e-5, atol=1e-8 104 | ), "Gradients for sharded parameters should be different across ranks" 105 | 106 | 107 | # --- Pytest Test Class --- 108 | 109 | 110 | @pytest.mark.integration 111 | @pytest.mark.distributed 112 | @pytest.mark.skipif(not dist.is_available(), reason="torch.distributed not available") 113 | class TestCLTDistributed: 114 | def test_forward_pass(self): 115 | world_size = 2 116 | 
mp.spawn( # type: ignore[attr-defined] 117 | distributed_test_runner, 118 | args=(world_size, _test_forward_pass_distributed), 119 | nprocs=world_size, 120 | join=True, 121 | ) 122 | 123 | def test_gradient_sharding(self): 124 | world_size = 2 125 | mp.spawn( # type: ignore[attr-defined] 126 | distributed_test_runner, 127 | args=(world_size, _test_sharded_gradient), 128 | nprocs=world_size, 129 | join=True, 130 | ) 131 | -------------------------------------------------------------------------------- /clt/activations/registry.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import logging 4 | from typing import Callable, Dict, TYPE_CHECKING, cast 5 | 6 | if TYPE_CHECKING: 7 | from clt.models.clt import CrossLayerTranscoder # To avoid circular import 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | # Type alias for activation functions used in the registry 12 | # They take the CrossLayerTranscoder instance, pre-activations, and layer_idx 13 | ActivationCallable = Callable[["CrossLayerTranscoder", torch.Tensor, int], torch.Tensor] 14 | 15 | ACTIVATION_REGISTRY: Dict[str, ActivationCallable] = {} 16 | 17 | 18 | def register_activation_fn(name: str) -> Callable[[ActivationCallable], ActivationCallable]: 19 | """Decorator to register a new activation function.""" 20 | 21 | def decorator(fn: ActivationCallable) -> ActivationCallable: 22 | if name in ACTIVATION_REGISTRY: 23 | logger.warning(f"Activation function '{name}' is already registered. 
Overwriting.") 24 | ACTIVATION_REGISTRY[name] = fn 25 | return fn 26 | 27 | return decorator 28 | 29 | 30 | @register_activation_fn("relu") 31 | def relu_activation(model: "CrossLayerTranscoder", preact: torch.Tensor, layer_idx: int) -> torch.Tensor: 32 | """Standard ReLU activation.""" 33 | return F.relu(preact) 34 | 35 | 36 | @register_activation_fn("jumprelu") 37 | def jumprelu_activation(model: "CrossLayerTranscoder", preact: torch.Tensor, layer_idx: int) -> torch.Tensor: 38 | """JumpReLU activation.""" 39 | # The model's jumprelu method handles device/dtype and threshold selection 40 | return model.jumprelu(preact, layer_idx) 41 | 42 | 43 | @register_activation_fn("batchtopk") 44 | def batchtopk_per_layer_activation(model: "CrossLayerTranscoder", preact: torch.Tensor, layer_idx: int) -> torch.Tensor: 45 | """BatchTopK activation applied per-layer (not global).""" 46 | from clt.models.activations import BatchTopK # Local import to avoid issues if activations.py imports this registry 47 | 48 | logger.warning( 49 | f"Rank {model.rank}: 'encode' called for BatchTopK on layer {layer_idx}. " 50 | f"This applies TopK per-layer, not globally. Use 'get_feature_activations' for global BatchTopK." 51 | ) 52 | k_val_local_int: int 53 | if model.config.batchtopk_k is not None: 54 | k_val_local_int = int(model.config.batchtopk_k) 55 | else: 56 | # If k is None, default to keeping all features for this layer. 57 | # This might happen if batchtopk_k is not set in the config, 58 | # though it typically should be for BatchTopK. 59 | k_val_local_int = preact.size(1) # Number of features in this layer's preactivation 60 | logger.warning( 61 | f"Rank {model.rank}: batchtopk_k not set in config for per-layer BatchTopK on layer {layer_idx}. " 62 | f"Defaulting to k={k_val_local_int} (all features for this layer)." 63 | ) 64 | 65 | # BatchTopK.apply takes the original preactivation, k, straight_through, and optional ranking tensor. 
66 | # For per-layer application, we don't have a separate normalized ranking tensor readily available here from encode_all_layers, 67 | # so we pass preact itself for ranking if x_for_ranking is None. 68 | # Normalization for ranking, if desired for per-layer, would need to happen here or BatchTopK would need to handle it. 69 | # The global BatchTopK in _apply_batch_topk_helper does normalization. 70 | return cast( 71 | torch.Tensor, BatchTopK.apply(preact, float(k_val_local_int), model.config.batchtopk_straight_through, preact) 72 | ) 73 | 74 | 75 | @register_activation_fn("topk") 76 | def topk_per_layer_activation(model: "CrossLayerTranscoder", preact: torch.Tensor, layer_idx: int) -> torch.Tensor: 77 | """TokenTopK activation applied per-layer (not global).""" 78 | from clt.models.activations import TokenTopK # Local import 79 | 80 | logger.warning( 81 | f"Rank {model.rank}: 'encode' called for TopK (TokenTopK) on layer {layer_idx}. " 82 | f"This applies TopK per-layer, not globally. Use 'get_feature_activations' for global TopK." 83 | ) 84 | k_val_local_float: float 85 | if hasattr(model.config, "topk_k") and model.config.topk_k is not None: 86 | k_val_local_float = float(model.config.topk_k) 87 | else: 88 | # Default to keeping all features for this layer if topk_k not set 89 | k_val_local_float = float(preact.size(1)) 90 | logger.warning( 91 | f"Rank {model.rank}: topk_k not set in config for per-layer TopK on layer {layer_idx}. " 92 | f"Defaulting to k={k_val_local_float} (all features for this layer)." 93 | ) 94 | 95 | straight_through_local = getattr(model.config, "topk_straight_through", True) 96 | # TokenTopK.apply takes preact, k, straight_through, and x_for_ranking. 97 | # Similar to BatchTopK, for per-layer, we use preact for ranking if x_for_ranking is None. 
98 | return cast(torch.Tensor, TokenTopK.apply(preact, k_val_local_float, straight_through_local, preact)) 99 | 100 | 101 | def get_activation_fn(name: str) -> ActivationCallable: 102 | """Retrieve an activation function from the registry.""" 103 | fn = ACTIVATION_REGISTRY.get(name) 104 | if fn is None: 105 | raise ValueError( 106 | f"Activation function '{name}' not found in registry. Available: {list(ACTIVATION_REGISTRY.keys())}" 107 | ) 108 | return fn 109 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
168 | #.idea/ 169 | 170 | # Ruff stuff: 171 | .ruff_cache/ 172 | 173 | # PyPI configuration file 174 | .pypirc 175 | 176 | # Other 177 | .cursor/ 178 | tutorials/clt_training_logs/ 179 | ref_docs/ 180 | clt_training_logs/ 181 | tests/integration/data/ 182 | clt_output_tutorial/ 183 | .DS_Store 184 | mock_log_dir/ 185 | tutorials/tutorial_activations/ 186 | tutorials/tutorial_activations_local/ 187 | temp_tutorial_server_data/ 188 | tutorials/tutorial_activations_small/ 189 | tutorials/tutorial_activations_local_1M/ 190 | tutorials/tutorial_activations_local_1M_v2/ 191 | tutorials/tutorial_activations_local_100k/ 192 | tutorials/tutorial_activations_local_100k_v2/ 193 | tutorials/tutorial_activations_local_1M_pythia/ 194 | tutorials/tutorial_activations_local_100k_pythia/ 195 | tutorials/tutorial_activations_local_1M_float16/ 196 | tutorials/tutorial_activations_local_1M_pythia_ln/ 197 | 2412.06410v1.pdf 198 | tutorials/tutorial_activations_local_1M_pythia_160m/ 199 | tutorials/tutorial_activations_local_1M_mlp_only/ 200 | conversion_test/ 201 | server_data/ 202 | tutorials/old/ 203 | clt_remote_runs/ 204 | *tutorial_activations*/ 205 | vis/ 206 | clt_test_pythia_70m_jumprelu/ 207 | clt_smoke_output_local_wandb_batchtopk/ 208 | clt_smoke_output_remote_wandb/ 209 | wandb/ 210 | scripts/debug 211 | scripts/optimization 212 | sparsify/ 213 | clt-training/ 214 | 215 | # models 216 | *.pt 217 | *.pth 218 | *.pth.tar 219 | *.pth.tar.gz 220 | *.pth.tar.bz2 221 | *.pth.tar.xz 222 | *.pth.tar.lzma -------------------------------------------------------------------------------- /clt/config/data_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Literal, Optional, Dict, Any 3 | import logging 4 | 5 | logger = logging.getLogger(__name__) 6 | 7 | 8 | @dataclass 9 | class ActivationConfig: 10 | """Configuration for generating or locating activation datasets.""" 11 | 12 | # 
--- Model Source Identification --- 13 | model_name: str # Name or path of the Hugging Face transformer model 14 | mlp_input_module_path_template: str # NNsight path template for MLP inputs 15 | mlp_output_module_path_template: str # NNsight path template for MLP outputs 16 | # --- Dataset Source Identification --- 17 | dataset_path: str # Path or name of the Hugging Face dataset 18 | # --- Fields with Defaults --- # 19 | model_dtype: Optional[str] = None # Optional dtype for the model ('float16', 'bfloat16') 20 | activation_dtype: Literal["bfloat16", "float16", "float32"] = "bfloat16" # Precision for storing activations 21 | dataset_split: str = "train" # Dataset split to use 22 | dataset_text_column: str = "text" # Column containing text data 23 | 24 | # --- Generation Parameters --- 25 | context_size: int = 128 # Max sequence length for tokenization/inference 26 | inference_batch_size: int = 512 # Batch size for model inference during generation 27 | exclude_special_tokens: bool = True # Exclude special tokens during generation 28 | prepend_bos: bool = False # Prepend BOS token during generation 29 | 30 | # --- Dataset Handling Parameters (for generation) --- 31 | streaming: bool = True # Use HF dataset streaming during generation 32 | dataset_trust_remote_code: bool = False # Trust remote code for HF dataset 33 | cache_path: Optional[str] = None # Optional cache path for HF dataset (if not streaming) 34 | 35 | # --- Generation Output Control --- 36 | target_total_tokens: Optional[int] = None # Target num tokens to generate (approximate) 37 | 38 | # --- Storage Parameters (for generation output) --- 39 | activation_dir: str = "./activations" # Base directory to save activation datasets 40 | output_format: Literal["hdf5", "npz"] = "hdf5" # Format to save activations 41 | compression: Optional[str] = "gzip" # Compression for saved files ('lz4', 'gzip', None) 42 | chunk_token_threshold: int = 1_000_000 # Minimum tokens to accumulate before saving a chunk 43 | # 
Note: 'storage_type' (local/remote) is handled by the generator script/workflow, 44 | # not stored intrinsically here, as the generated data itself is local. 45 | 46 | # --- Normalization Computation (during generation) --- 47 | compute_norm_stats: bool = True # Compute mean/std during generation and save to norm_stats.json 48 | 49 | # --- Remote Storage Parameters --- 50 | remote_server_url: Optional[str] = None # Base URL of the remote activation server 51 | delete_after_upload: bool = False # Delete local chunk after successful upload 52 | upload_max_retries: int = 5 # Max number of upload retries per chunk 53 | upload_initial_backoff: float = 1.0 # Initial backoff delay in seconds for retries 54 | upload_max_backoff: float = 30.0 # Maximum backoff delay in seconds for retries 55 | 56 | # --- NNsight Parameters (Optional) --- 57 | # Use field to allow mutable default dict 58 | nnsight_tracer_kwargs: Dict[str, Any] = field(default_factory=dict) 59 | nnsight_invoker_args: Dict[str, Any] = field(default_factory=dict) 60 | 61 | # --- Profiling Control (during generation) --- 62 | enable_profiling: bool = False # Whether to enable detailed performance profiling during generation 63 | 64 | # --- Device Parameter --- 65 | # While generation happens on a device, this config is more about the data itself. 66 | # The device used for generation can be passed separately to the generator instance. 67 | # device: Optional[str] = None 68 | 69 | def __post_init__(self): 70 | """Validate configuration parameters.""" 71 | assert self.context_size > 0, "Context size must be positive" 72 | assert self.inference_batch_size > 0, "Inference batch size must be positive" 73 | assert self.chunk_token_threshold > 0, "Chunk token threshold must be positive" 74 | if self.output_format == "hdf5": 75 | try: 76 | import h5py # noqa: F401 - Check if h5py is available if format is hdf5 77 | except ImportError: 78 | raise ImportError("h5py is required for HDF5 output format. 
Install with: pip install h5py") 79 | if self.compression not in ["lz4", "gzip", None, False]: 80 | logger.warning( 81 | f"Warning: Unsupported compression '{self.compression}'. Will attempt without compression for {self.output_format}." 82 | ) 83 | # Allow generator to handle disabling if format doesn't support it. 84 | 85 | # Example: Print a summary or key values 86 | # This is more for user feedback than programmatic use. 87 | logger.info( 88 | "ActivationConfig Summary:\n" 89 | f" Model: {self.model_name}\n" 90 | f" Dataset: {self.dataset_path} (Split: {self.dataset_split})\n" 91 | f" Target Tokens: {self.target_total_tokens}\n" 92 | f" Chunk Threshold: {self.chunk_token_threshold}\n" 93 | f" Activation Dtype: {self.activation_dtype}\n" 94 | f" Output Dir: {self.activation_dir}" 95 | ) 96 | if self.remote_server_url: 97 | logger.info(f" Remote Server URL: {self.remote_server_url}") 98 | if self.delete_after_upload: 99 | logger.info(" Delete after upload: Enabled") 100 | -------------------------------------------------------------------------------- /tests/unit/models/test_decoder.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import logging 4 | 5 | from clt.config import CLTConfig 6 | from clt.models.decoder import Decoder 7 | 8 | 9 | def get_available_devices(): 10 | """Returns available devices, including cpu, mps, and cuda if available.""" 11 | devices = ["cpu"] 12 | if torch.cuda.is_available(): 13 | devices.append("cuda") 14 | if torch.backends.mps.is_available(): 15 | devices.append("mps") 16 | return devices 17 | 18 | 19 | DEVICES = get_available_devices() 20 | GPU_DEVICES = [d for d in DEVICES if d != "cpu"] 21 | 22 | 23 | def require_gpu(func): 24 | """Decorator to skip tests if no GPU is available.""" 25 | return pytest.mark.skipif(not GPU_DEVICES, reason="Test requires a GPU (CUDA or MPS)")(func) 26 | 27 | 28 | @pytest.fixture 29 | def clt_config(): 30 | """Provides a 
basic CLTConfig for testing.""" 31 | return CLTConfig( 32 | num_layers=2, 33 | d_model=8, 34 | num_features=16, 35 | ) 36 | 37 | 38 | @pytest.fixture 39 | def decoder(clt_config, device): 40 | """Provides a Decoder instance.""" 41 | return Decoder( 42 | config=clt_config, 43 | process_group=None, # Non-distributed for unit tests 44 | device=device, 45 | dtype=torch.float32, 46 | ).to(device) 47 | 48 | 49 | class TestDecoder: 50 | def test_decode_single_layer(self, decoder, clt_config, device): 51 | """Test decoding from a single source layer.""" 52 | batch_size = 4 53 | activations = {0: torch.randn(batch_size, clt_config.num_features, device=device)} 54 | target_layer = 0 55 | 56 | reconstruction = decoder.decode(activations, target_layer) 57 | 58 | assert reconstruction.shape == (batch_size, clt_config.d_model) 59 | assert reconstruction.device.type == device.type 60 | assert reconstruction.dtype == torch.float32 61 | 62 | def test_decode_multi_layer_sum(self, decoder, clt_config, device): 63 | """Test that reconstructions are summed across multiple source layers.""" 64 | batch_size = 3 65 | activations = { 66 | 0: torch.ones(batch_size, clt_config.num_features, device=device), 67 | 1: torch.ones(batch_size, clt_config.num_features, device=device) * 2, 68 | } 69 | target_layer = 1 70 | 71 | # Decode with both layers 72 | reconstruction_both = decoder.decode(activations, target_layer) 73 | 74 | # Decode with only the first layer 75 | reconstruction_first = decoder.decode({0: activations[0]}, target_layer) 76 | 77 | # Decode with only the second layer 78 | reconstruction_second = decoder.decode({1: activations[1]}, target_layer) 79 | 80 | # The sum should be close to the combined reconstruction 81 | torch.testing.assert_close(reconstruction_both, reconstruction_first + reconstruction_second) 82 | 83 | def test_decode_empty_activations_dict(self, decoder, clt_config, device): 84 | """Test decode with an empty activation dictionary.""" 85 | reconstruction = 
decoder.decode({}, layer_idx=0) 86 | assert reconstruction.shape == (0, clt_config.d_model) 87 | 88 | def test_decode_activations_with_zero_elements(self, decoder, clt_config, device): 89 | """Test decode with a tensor that has a zero dimension.""" 90 | activations = {0: torch.randn(0, clt_config.num_features, device=device)} 91 | reconstruction = decoder.decode(activations, layer_idx=0) 92 | assert reconstruction.shape == (0, clt_config.d_model) 93 | 94 | def test_decode_mismatched_feature_dim(self, decoder, clt_config, device, caplog): 95 | """Test that activations with wrong feature dimensions are skipped with a warning.""" 96 | batch_size = 4 97 | wrong_features = clt_config.num_features - 4 98 | activations = { 99 | 0: torch.randn(batch_size, wrong_features, device=device), 100 | 1: torch.randn(batch_size, clt_config.num_features, device=device), 101 | } 102 | target_layer = 1 103 | 104 | with caplog.at_level(logging.WARNING): 105 | reconstruction = decoder.decode(activations, target_layer) 106 | 107 | # The reconstruction should only be from layer 1 108 | reconstruction_only_l1 = decoder.decode({1: activations[1]}, target_layer) 109 | torch.testing.assert_close(reconstruction, reconstruction_only_l1) 110 | 111 | assert "incorrect feature dimension" in caplog.text 112 | assert str(wrong_features) in caplog.text 113 | 114 | def test_get_decoder_norms_shape_and_device(self, decoder, clt_config, device): 115 | """Test the shape and device of the decoder norms tensor.""" 116 | norms = decoder.get_decoder_norms() 117 | assert norms.shape == (clt_config.num_layers, clt_config.num_features) 118 | assert norms.device.type == device.type 119 | 120 | def test_get_decoder_norms_caching(self, decoder): 121 | """Test that get_decoder_norms caches the result.""" 122 | norms1 = decoder.get_decoder_norms() 123 | norms2 = decoder.get_decoder_norms() 124 | # The exact same object should be returned 125 | assert id(norms1) == id(norms2) 126 | # Invalidate cache and check again 
127 | decoder._cached_decoder_norms = None 128 | norms3 = decoder.get_decoder_norms() 129 | assert id(norms1) != id(norms3) 130 | 131 | @require_gpu 132 | @pytest.mark.parametrize("device", GPU_DEVICES) 133 | def test_decode_on_gpu(self, clt_config, device): 134 | """Dedicated test to ensure decode runs on GPU.""" 135 | device_obj = torch.device(device) 136 | decoder = Decoder( 137 | config=clt_config, 138 | process_group=None, 139 | device=device_obj, 140 | dtype=torch.float32, 141 | ).to(device_obj) 142 | self.test_decode_single_layer(decoder, clt_config, device_obj) 143 | -------------------------------------------------------------------------------- /tests/models/test_clt_distributed_forward.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | import os 4 | from typing import Dict, Optional, Literal 5 | 6 | from clt.config import CLTConfig 7 | from clt.models.clt import CrossLayerTranscoder 8 | 9 | 10 | # Helper to initialize distributed environment for the test 11 | def setup_distributed_test(rank, world_size, master_port="12355"): 12 | os.environ["MASTER_ADDR"] = "localhost" 13 | os.environ["MASTER_PORT"] = master_port 14 | dist.init_process_group("nccl" if torch.cuda.is_available() else "gloo", rank=rank, world_size=world_size) 15 | if torch.cuda.is_available(): 16 | torch.cuda.set_device(rank) # Set device for this process 17 | 18 | 19 | # Helper to cleanup distributed environment 20 | def cleanup_distributed_test(): 21 | dist.destroy_process_group() 22 | 23 | 24 | def run_forward_pass_test( 25 | rank, world_size, activation_fn: Literal["jumprelu", "relu", "batchtopk", "topk"], batchtopk_k: Optional[int] = None 26 | ): 27 | setup_distributed_test(rank, world_size) 28 | 29 | d_model = 64 # Small d_model for testing 30 | num_features_per_layer = d_model * 2 31 | num_layers = 2 # Small number of layers 32 | batch_size = 4 33 | seq_len = 8 34 | batch_tokens = batch_size * 
seq_len 35 | 36 | clt_config = CLTConfig( 37 | d_model=d_model, 38 | num_features=num_features_per_layer, 39 | num_layers=num_layers, 40 | activation_fn=activation_fn, 41 | batchtopk_k=batchtopk_k, 42 | # jumprelu_threshold is only relevant if activation_fn is jumprelu 43 | jumprelu_threshold=0.01 if activation_fn == "jumprelu" else 0.0, 44 | ) 45 | 46 | # Determine device for the model based on availability and rank 47 | current_device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu") 48 | 49 | # Instantiate model - process_group is automatically handled by CrossLayerTranscoder init if dist is initialized 50 | model = CrossLayerTranscoder(config=clt_config, process_group=None, device=current_device) # PG is WORLD implicitly 51 | model.to(current_device) 52 | model.eval() # Set to eval mode 53 | 54 | # Create identical dummy input data on all ranks 55 | # (batch_tokens, d_model) 56 | dummy_inputs: Dict[int, torch.Tensor] = {} 57 | for i in range(num_layers): 58 | # Ensure identical tensor across ranks using a fixed seed before creating tensor 59 | torch.manual_seed(42 + i) # Same seed for each layer across ranks 60 | dummy_inputs[i] = torch.randn(batch_tokens, d_model, device=current_device, dtype=model.dtype) 61 | 62 | # Perform forward pass 63 | reconstructions = model.forward(dummy_inputs) 64 | 65 | # Assertions 66 | assert isinstance(reconstructions, dict) 67 | assert len(reconstructions) == num_layers 68 | 69 | # Gather all reconstruction tensors to rank 0 for comparison (if more than 1 GPU) 70 | # Or, more simply, each rank asserts its output is identical to a tensor broadcast from rank 0 71 | for layer_idx in range(num_layers): 72 | output_tensor = reconstructions[layer_idx] 73 | assert output_tensor.shape == (batch_tokens, d_model) 74 | assert output_tensor.device == current_device 75 | assert output_tensor.dtype == model.dtype 76 | 77 | # All-reduce the sum of the tensor and sum of squares. If identical, these will be world_size * val. 
78 | # This is a robust way to check for numerical identity across ranks. 79 | sum_val = output_tensor.sum() 80 | sum_sq_val = (output_tensor**2).sum() 81 | 82 | gathered_sum_list = [torch.zeros_like(sum_val) for _ in range(world_size)] 83 | gathered_sum_sq_list = [torch.zeros_like(sum_sq_val) for _ in range(world_size)] 84 | 85 | if world_size > 1: 86 | dist.all_gather(gathered_sum_list, sum_val) 87 | dist.all_gather(gathered_sum_sq_list, sum_sq_val) 88 | else: 89 | gathered_sum_list = [sum_val] 90 | gathered_sum_sq_list = [sum_sq_val] 91 | 92 | # Check if all gathered sums and sum_sq are close to each other 93 | for i in range(1, world_size): 94 | assert torch.allclose( 95 | gathered_sum_list[0], gathered_sum_list[i] 96 | ), f"Rank {rank} Layer {layer_idx} sum mismatch: {gathered_sum_list[0]} vs {gathered_sum_list[i]} (rank {i}) for act_fn {activation_fn}" 97 | assert torch.allclose( 98 | gathered_sum_sq_list[0], gathered_sum_sq_list[i] 99 | ), f"Rank {rank} Layer {layer_idx} sum_sq mismatch: {gathered_sum_sq_list[0]} vs {gathered_sum_sq_list[i]} (rank {i}) for act_fn {activation_fn}" 100 | 101 | if rank == 0: 102 | print(f"Distributed forward test PASSED for rank {rank}, activation_fn='{activation_fn}'") 103 | 104 | cleanup_distributed_test() 105 | 106 | 107 | # Main test execution controlled by torchrun 108 | if __name__ == "__main__": 109 | world_size = int(os.environ.get("WORLD_SIZE", 1)) 110 | rank = int(os.environ.get("RANK", 0)) 111 | 112 | print(f"Starting distributed test on rank {rank} of {world_size}") 113 | 114 | # Test with ReLU 115 | print(f"Rank {rank}: Running test for ReLU") 116 | run_forward_pass_test(rank, world_size, activation_fn="relu") 117 | if world_size > 1: 118 | dist.barrier() # Ensure test finishes before next one 119 | 120 | # Test with BatchTopK 121 | print(f"Rank {rank}: Running test for BatchTopK") 122 | run_forward_pass_test(rank, world_size, activation_fn="batchtopk", batchtopk_k=10) 123 | if world_size > 1: 124 | dist.barrier() 
125 | 126 | # Add more activation functions to test if needed, e.g., jumprelu 127 | # print(f"Rank {rank}: Running test for JumpReLU") 128 | # run_forward_pass_test(rank, world_size, activation_fn="jumprelu") 129 | # if world_size > 1: dist.barrier() 130 | 131 | if rank == 0: 132 | print("All distributed forward tests completed.") 133 | -------------------------------------------------------------------------------- /clt/models/activations_optimized.py: -------------------------------------------------------------------------------- 1 | """Optimized activation functions for better performance.""" 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from typing import Optional, Dict, Any 6 | import logging 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | class OptimizedBatchTopK(torch.autograd.Function): 12 | """Optimized BatchTopK with fused operations and better memory usage.""" 13 | 14 | @staticmethod 15 | def _compute_mask_optimized( 16 | x: torch.Tensor, 17 | k_per_token: int, 18 | x_for_ranking: Optional[torch.Tensor] = None 19 | ) -> torch.Tensor: 20 | """Optimized mask computation with fewer allocations.""" 21 | B = x.size(0) 22 | if k_per_token <= 0: 23 | return torch.zeros_like(x, dtype=torch.bool) 24 | 25 | # Early exit for full selection 26 | F_total_batch = x.numel() 27 | if F_total_batch == 0: 28 | return torch.zeros_like(x, dtype=torch.bool) 29 | 30 | k_total_batch = min(k_per_token * B, F_total_batch) 31 | 32 | # Use the ranking tensor if provided, otherwise use x 33 | ranking_tensor = x_for_ranking if x_for_ranking is not None else x 34 | 35 | # Fused reshape and topk - avoid intermediate allocations 36 | if k_total_batch > 0: 37 | # Get top-k values and indices in one operation 38 | _, flat_indices = torch.topk( 39 | ranking_tensor.view(-1), 40 | k_total_batch, 41 | sorted=False, 42 | largest=True 43 | ) 44 | 45 | # Create mask directly without intermediate tensor 46 | mask = torch.zeros(F_total_batch, dtype=torch.bool, 
device=x.device) 47 | mask[flat_indices] = True 48 | return mask.view_as(x) 49 | else: 50 | return torch.zeros_like(x, dtype=torch.bool) 51 | 52 | @staticmethod 53 | @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) 54 | def forward( 55 | ctx, 56 | x: torch.Tensor, 57 | k: float, 58 | straight_through: bool, 59 | x_for_ranking: Optional[torch.Tensor] = None 60 | ) -> torch.Tensor: 61 | """Forward with mixed precision support.""" 62 | k_per_token = int(k) 63 | 64 | # Compute mask in FP32 for accuracy 65 | with torch.cuda.amp.autocast(enabled=False): 66 | mask = OptimizedBatchTopK._compute_mask_optimized( 67 | x.float(), k_per_token, 68 | x_for_ranking.float() if x_for_ranking is not None else None 69 | ) 70 | 71 | ctx.save_for_backward(mask) 72 | ctx.straight_through = straight_through 73 | 74 | # Apply mask in original dtype 75 | return x * mask.to(x.dtype) 76 | 77 | @staticmethod 78 | @torch.cuda.amp.custom_bwd 79 | def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor | None, None, None, None]: 80 | """Optimized backward pass.""" 81 | if ctx.straight_through: 82 | mask, = ctx.saved_tensors 83 | # Fused multiplication 84 | grad_input = grad_output * mask.to(grad_output.dtype) 85 | else: 86 | mask, = ctx.saved_tensors 87 | grad_input = grad_output * mask.to(grad_output.dtype) 88 | 89 | return grad_input, None, None, None 90 | 91 | 92 | def create_optimized_topk_mask_batched( 93 | concatenated_tensor: torch.Tensor, 94 | k_values: Dict[int, int], 95 | layer_sizes: list[tuple[int, int]] 96 | ) -> torch.Tensor: 97 | """Create masks for different layers in parallel when they have different k values.""" 98 | device = concatenated_tensor.device 99 | dtype = concatenated_tensor.dtype 100 | batch_size, total_features = concatenated_tensor.shape 101 | 102 | # Pre-allocate output mask 103 | mask = torch.zeros_like(concatenated_tensor, dtype=torch.bool) 104 | 105 | # Group layers by k value for batch processing 106 | k_groups = {} 107 | for layer_idx, 
(start_idx, num_features) in enumerate(layer_sizes): 108 | k_val = k_values.get(layer_idx, 0) 109 | if k_val not in k_groups: 110 | k_groups[k_val] = [] 111 | k_groups[k_val].append((layer_idx, start_idx, num_features)) 112 | 113 | # Process each k-value group 114 | for k_val, layer_infos in k_groups.items(): 115 | if k_val <= 0: 116 | continue 117 | 118 | # Gather all features for this k value 119 | indices = [] 120 | for _, start_idx, num_features in layer_infos: 121 | indices.extend(range(start_idx, start_idx + num_features)) 122 | 123 | if not indices: 124 | continue 125 | 126 | # Extract relevant features 127 | group_features = concatenated_tensor[:, indices] 128 | 129 | # Compute top-k for this group 130 | k_total = min(k_val * batch_size, group_features.numel()) 131 | if k_total > 0: 132 | _, top_indices = torch.topk( 133 | group_features.view(-1), 134 | k_total, 135 | sorted=False 136 | ) 137 | 138 | # Convert back to 2D indices 139 | row_indices = top_indices // len(indices) 140 | col_indices = top_indices % len(indices) 141 | 142 | # Map back to original positions 143 | for i, (row, col) in enumerate(zip(row_indices, col_indices)): 144 | original_col = indices[col] 145 | mask[row, original_col] = True 146 | 147 | return mask 148 | 149 | 150 | # Monkey patch for torch.compile compatibility 151 | def make_compile_compatible(): 152 | """Make activation functions compatible with torch.compile.""" 153 | try: 154 | # Check if torch.compile is available (PyTorch 2.0+) 155 | if hasattr(torch, 'compile'): 156 | # Register custom ops for better compilation 157 | torch.fx.wrap('OptimizedBatchTopK._compute_mask_optimized') 158 | except Exception as e: 159 | logger.debug(f"torch.compile compatibility setup skipped: {e}") 160 | 161 | 162 | # Initialize on module load 163 | make_compile_compatible() -------------------------------------------------------------------------------- /tests/unit/models/test_parallel_layers.py: 
# --------------------------------------------------------------------------------
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import os

from clt.models.parallel import ColumnParallelLinear, RowParallelLinear


def setup_distributed_environment(rank, world_size, port="12355"):
    """Initializes the distributed process group."""
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = port
    dist.init_process_group("gloo", rank=rank, world_size=world_size)


def cleanup_distributed_environment():
    """Cleans up the distributed process group."""
    dist.destroy_process_group()


def distributed_test_runner(rank, world_size, test_fn, *args):
    """A wrapper to run a distributed test function."""
    setup_distributed_environment(rank, world_size)
    try:
        test_fn(rank, world_size, *args)
    finally:
        cleanup_distributed_environment()


# --- Non-Distributed Tests (World Size = 1) ---


class TestParallelLayersNonDistributed:
    @pytest.fixture
    def device(self):
        return torch.device("cpu")

    def test_column_parallel_linear_forward(self, device):
        """Test ColumnParallelLinear forward pass without distribution."""
        layer = ColumnParallelLinear(in_features=10, out_features=20, process_group=None, device=device)
        input_tensor = torch.randn(5, 10, device=device)
        output = layer(input_tensor)
        assert output.shape == (5, 20)
        assert output.device.type == device.type

    def test_row_parallel_linear_forward(self, device):
        """Test RowParallelLinear forward pass without distribution."""
        layer = RowParallelLinear(
            in_features=10,
            out_features=20,
            process_group=None,
            d_model_for_init=20,
            num_layers_for_init=1,
            device=device,
        )
        input_tensor = torch.randn(5, 10, device=device)
        output = layer(input_tensor)
        assert output.shape == (5, 20)
        assert output.device.type == device.type


# --- Distributed Tests (World Size = 2) ---


# This function will be run in each process
def _test_column_parallel_distributed_forward(rank, world_size):
    device = torch.device("cpu")
    in_features, out_features, batch_size = 10, 20, 5

    # Each rank has the same seed for weights
    torch.manual_seed(42)
    layer = ColumnParallelLinear(in_features, out_features, process_group=dist.group.WORLD, device=device)

    # Each rank gets the full input tensor
    torch.manual_seed(123)
    input_tensor = torch.randn(batch_size, in_features, device=device)

    output = layer(input_tensor)

    # Output should be gathered and identical on both ranks
    assert output.shape == (batch_size, out_features)

    # A more robust check would involve gathering the full weight matrix.
    # This requires gathering the full weight matrix manually.
    weight_slices = [torch.zeros_like(layer.weight) for _ in range(world_size)]
    dist.all_gather(weight_slices, layer.weight.data)
    full_weight = torch.cat(weight_slices, dim=0)

    full_bias = torch.zeros(out_features, device=device)
    if layer.bias_param is not None:
        bias_slices = [torch.zeros_like(layer.bias_param) for _ in range(world_size)]
        dist.all_gather(bias_slices, layer.bias_param.data)
        full_bias = torch.cat(bias_slices, dim=0)

    manual_output = torch.matmul(input_tensor, full_weight.t()) + full_bias
    torch.testing.assert_close(output, manual_output, rtol=1e-4, atol=1e-5)


def _test_row_parallel_distributed_forward(rank, world_size):
    device = torch.device("cpu")
    in_features, out_features, batch_size = 10, 20, 5

    torch.manual_seed(42)
    layer = RowParallelLinear(
        in_features,
        out_features,
        process_group=dist.group.WORLD,
        d_model_for_init=out_features,
        num_layers_for_init=1,
        device=device,
    )

    torch.manual_seed(123)
    full_input = torch.randn(batch_size, in_features, device=device)

    # Manually split input for each rank
    local_in_features = layer.local_in_features
    # We manually create the slice that this rank is supposed to receive.
    # The `input_is_parallel` flag on the layer is True by default.
    input_slice = full_input[:, rank * local_in_features : (rank + 1) * local_in_features]

    # To correctly test against a reference, we need to gather the sharded weights
    # into a full weight matrix on each rank.
    weight_slices = [torch.zeros_like(layer.weight) for _ in range(world_size)]
    dist.all_gather(weight_slices, layer.weight.data)
    full_weight_manual = torch.cat(weight_slices, dim=1)

    # The bias in RowParallelLinear is replicated, not sharded.
    # The reference calculation should use the layer's actual bias.
    manual_output = torch.matmul(full_input, full_weight_manual.t())
    if layer.bias_param is not None:
        manual_output += layer.bias_param

    # Pass the pre-sharded input slice to the layer
    output = layer(input_slice)

    torch.testing.assert_close(output, manual_output, rtol=1e-4, atol=1e-5)


@pytest.mark.skipif(not dist.is_available(), reason="torch.distributed not available")
class TestParallelLayersDistributed:
    def test_column_parallel(self):
        world_size = 2
        mp.spawn(
            distributed_test_runner,
            args=(world_size, _test_column_parallel_distributed_forward),
            nprocs=world_size,
            join=True,
        )

    def test_row_parallel(self):
        world_size = 2
        mp.spawn(
            distributed_test_runner,
            args=(world_size, _test_row_parallel_distributed_forward),
            nprocs=world_size,
            join=True,
        )

# -------------------------------------------------------------------------------- /clt/training/diagnostics.py: --------------------------------------------------------------------------------
import math
import torch
from typing import Dict, Any, TYPE_CHECKING

if TYPE_CHECKING:
    from clt.models.clt import CrossLayerTranscoder
    from clt.config import TrainingConfig


# (per-layer metric name, layerwise output key, aggregate output key)
_DIAGNOSTIC_KEYS = (
    ("z_median", "layerwise/sparsity_z_median", "sparsity/z_median_agg"),
    ("z_p90", "layerwise/sparsity_z_p90", "sparsity/z_p90_agg"),
    ("mean_tanh", "layerwise/sparsity_mean_tanh", "sparsity/mean_tanh_agg"),
    ("sat_frac", "layerwise/sparsity_sat_frac", "sparsity/sat_frac_agg"),
    ("mean_abs_act", "layerwise/mean_abs_activation", "sparsity/mean_abs_activation_agg"),
    ("mean_dec_norm", "layerwise/mean_decoder_norm", "sparsity/mean_decoder_norm_agg"),
)

# torch.quantile rejects inputs with more than 2**24 elements.
_QUANTILE_MAX_NUMEL = 2**24


def _nan_mean(values) -> float:
    """Return the mean of the non-NaN entries of ``values`` (NaN when none remain)."""
    kept = [v for v in values if not math.isnan(v)]
    return sum(kept) / len(kept) if kept else float("nan")


def _quantile(values: torch.Tensor, q: float) -> float:
    """Linear-interpolation quantile of a 1-D tensor that also works above torch.quantile's size limit."""
    n = values.numel()
    if n <= _QUANTILE_MAX_NUMEL:
        return torch.quantile(values, q).item()
    # Same "linear" interpolation as torch.quantile, computed from two order statistics.
    rank = q * (n - 1)
    lower = math.floor(rank)
    upper = min(lower + 1, n - 1)
    lo_val = torch.kthvalue(values, lower + 1).values.item()
    hi_val = torch.kthvalue(values, upper + 1).values.item()
    return lo_val + (hi_val - lo_val) * (rank - lower)


def _layer_metrics(
    layer_acts: torch.Tensor, dec_norms_l: "torch.Tensor | None", sparsity_c: float
) -> Dict[str, float]:
    """Compute one layer's diagnostics; metrics that cannot be computed are NaN."""
    metrics = {name: float("nan") for name, _, _ in _DIAGNOSTIC_KEYS}
    if layer_acts.numel() == 0:
        return metrics

    metrics["mean_abs_act"] = layer_acts.abs().mean().item()
    if dec_norms_l is None:
        # No decoder norms for this layer index: the z-based metrics stay NaN.
        return metrics
    metrics["mean_dec_norm"] = dec_norms_l.mean().item()

    # Match the activations' device/dtype before combining.
    norms = dec_norms_l.to(layer_acts.device, layer_acts.dtype).unsqueeze(0)  # [1, F]
    z = sparsity_c * norms * layer_acts  # [tokens, F]
    z_on = z[layer_acts > 1e-6]  # small threshold > 0 marks "active" entries
    if z_on.numel() > 0:
        tanh_z_on = torch.tanh(z_on)
        metrics["z_median"] = torch.median(z_on).item()
        metrics["z_p90"] = _quantile(z_on.float(), 0.9)  # float for quantile
        metrics["mean_tanh"] = tanh_z_on.mean().item()
        metrics["sat_frac"] = (tanh_z_on.abs() > 0.99).float().mean().item()  # abs for saturation
    return metrics


@torch.no_grad()
def compute_sparsity_diagnostics(
    model: "CrossLayerTranscoder",
    training_config: "TrainingConfig",
    # Inputs are not needed: decoder norms come from model.get_decoder_norms().
    feature_activations: Dict[int, torch.Tensor],
) -> Dict[str, Any]:
    """Computes detailed sparsity diagnostics (z-scores, tanh saturation, etc.).

    For every layer, ``z = sparsity_c * decoder_norm * activation`` is evaluated
    on entries whose activation exceeds ``1e-6``. The function reports:

    * the median and 90th percentile of ``z``;
    * the mean of ``tanh(z)`` and the fraction with ``|tanh(z)| > 0.99``;
    * the mean absolute activation and the mean decoder norm.

    Layers without activations, or without decoder norms, report NaN for the
    affected metrics. Each aggregate is the mean over layers, ignoring NaN.

    Args:
        model: The CrossLayerTranscoder model instance.
        training_config: The training configuration (provides ``sparsity_c``).
        feature_activations: Dictionary of feature activations (pre-computed),
            keyed by layer index.

    Returns:
        A dictionary with two kinds of entries. Each ``layerwise/...`` key maps to a
        ``{"layer_<idx>": value}`` dict. Each ``sparsity/..._agg`` key maps to a float.
    """
    sparsity_c = training_config.sparsity_c

    # Norms should be cached from the loss calculation earlier in the step,
    # or recomputed if necessary by get_decoder_norms().
    diag_dec_norms = model.get_decoder_norms()  # [L, F]
    num_norm_layers = diag_dec_norms.shape[0]

    per_layer: Dict[str, Dict[str, float]] = {}
    for l_idx, layer_acts in feature_activations.items():
        dec_norms_l = diag_dec_norms[l_idx] if l_idx < num_norm_layers else None
        per_layer[f"layer_{l_idx}"] = _layer_metrics(layer_acts, dec_norms_l, sparsity_c)

    diag_metrics: Dict[str, Any] = {}
    for name, layerwise_key, _ in _DIAGNOSTIC_KEYS:
        diag_metrics[layerwise_key] = {layer_key: vals[name] for layer_key, vals in per_layer.items()}
    for name, _, agg_key in _DIAGNOSTIC_KEYS:
        diag_metrics[agg_key] = _nan_mean(vals[name] for vals in per_layer.values())
    return diag_metrics

# -------------------------------------------------------------------------------- /tests/unit/training/test_loss_manager.py: --------------------------------------------------------------------------------
import pytest
import torch

from clt.training.losses import LossManager
from clt.models.clt import CrossLayerTranscoder
from tests.helpers.tiny_configs import create_tiny_clt_config, create_tiny_training_config


@pytest.fixture
def tiny_model() -> CrossLayerTranscoder:
    """Provides a tiny CLT model for testing."""
    clt_config = create_tiny_clt_config(num_features=8, d_model=4)
    return CrossLayerTranscoder(clt_config, process_group=None)


@pytest.fixture
def sample_data():
    """Provides sample predictions and targets."""
    device = torch.device("cpu")
    preds = {0: torch.randn(10, 4, device=device), 1: torch.randn(10, 4, device=device)}
    targets = {0: torch.randn(10, 4, device=device), 1: torch.randn(10, 4, device=device)}
    inputs = {0: torch.randn(10, 4, device=device), 1: torch.randn(10, 4, device=device)}
    activations = {0: torch.rand(10, 8, device=device), 1: torch.rand(10, 8, device=device)}
    return preds, targets, inputs, activations


class TestLossManager:
    def test_reconstruction_loss_no_denorm(self, sample_data):
        """Test reconstruction loss without de-normalization."""
        preds, targets, _, _ = sample_data
        config = create_tiny_training_config(activation_path="dummy")
        loss_manager = LossManager(config)

        loss = loss_manager.compute_reconstruction_loss(preds, targets)

        expected_loss = torch.nn.functional.mse_loss(preds[0], targets[0]) + torch.nn.functional.mse_loss(
            preds[1], targets[1]
        )

        assert torch.isclose(loss, expected_loss)

    def test_reconstruction_loss_with_denorm(self, sample_data):
        """Test that de-normalization is applied correctly."""
        preds, targets, _, _ = sample_data
        config = create_tiny_training_config(activation_path="dummy")

        mean_tg = {0: torch.tensor([[10.0]]), 1: torch.tensor([[-5.0]])}
        std_tg = {0: torch.tensor([[2.0]]), 1: torch.tensor([[0.5]])}

        loss_manager = LossManager(config, mean_tg=mean_tg, std_tg=std_tg)
        loss = loss_manager.compute_reconstruction_loss(preds, targets)

        pred0_denorm = preds[0] * std_tg[0] + mean_tg[0]
        target0_denorm = targets[0] * std_tg[0] + mean_tg[0]
        pred1_denorm = preds[1] * std_tg[1] + mean_tg[1]
        target1_denorm = targets[1] * std_tg[1] + mean_tg[1]

        expected_loss = torch.nn.functional.mse_loss(pred0_denorm, target0_denorm) + torch.nn.functional.mse_loss(
            pred1_denorm, target1_denorm
        )

        assert torch.isclose(loss, expected_loss)

    @pytest.mark.parametrize("schedule", ["linear", "delayed_linear"])
    def test_sparsity_penalty_schedule(self, tiny_model, sample_data, schedule):
        """Test sparsity penalty follows the lambda schedule."""
        _, _, _, activations = sample_data
        config = create_tiny_training_config(
            sparsity_lambda=1.0,
            sparsity_lambda_schedule=schedule,
            sparsity_lambda_delay_frac=0.5,
            training_steps=100,
            activation_path="dummy",
        )
        loss_manager = LossManager(config)

        # At the beginning (step 0), lambda should be 0
        loss_at_start, lambda_at_start = loss_manager.compute_sparsity_penalty(tiny_model, activations, 0, 100)
        assert torch.isclose(loss_at_start, torch.tensor(0.0))
        assert lambda_at_start == 0.0

        # Halfway through
        loss_mid, lambda_mid = loss_manager.compute_sparsity_penalty(tiny_model, activations, 50, 100)
        if schedule == "linear":
            assert lambda_mid == pytest.approx(0.5)
            assert loss_mid > 0
        else:  # delayed_linear with 0.5 delay
            assert lambda_mid == pytest.approx(0.0)
            assert torch.isclose(loss_mid, torch.tensor(0.0))

        # At the end
        loss_end, lambda_end = loss_manager.compute_sparsity_penalty(tiny_model, activations, 100, 100)
        assert lambda_end == pytest.approx(1.0)
        assert loss_end > loss_mid

    def test_preactivation_loss(self, tiny_model, sample_data):
        """Test pre-activation loss penalizes negative pre-activations."""
        _, _, inputs, _ = sample_data
        config = create_tiny_training_config(preactivation_coef=1.0, activation_path="dummy")
        loss_manager = LossManager(config)

        # Mock get_preactivations to return controlled values
        def mock_get_preactivations(x, layer_idx):
            if layer_idx == 0:
                return torch.tensor([[-0.5, 0.5], [-0.2, 0.8]])  # Has negative values
            return torch.tensor([[0.1, 0.2], [0.3, 0.4]])  # All positive

        tiny_model.get_preactivations = mock_get_preactivations

        loss = loss_manager.compute_preactivation_loss(tiny_model, inputs)

        # Expected penalty is from layer 0: (0.5 + 0.2) / 8 elements total
        expected_loss = (0.5 + 0.2) / 8
        assert torch.isclose(loss, torch.tensor(expected_loss))

    def test_total_loss_computation(self, tiny_model, sample_data, mocker):
        """Test that the total loss is the sum of its components."""
        preds, targets, inputs, activations = sample_data
        config = create_tiny_training_config(sparsity_lambda=0.1, preactivation_coef=0.1, activation_path="dummy")
        loss_manager = LossManager(config)

        mocker.patch.object(tiny_model, "get_feature_activations", return_value=activations)
        # Patch `forward`, not `__call__`: implicit calls look dunder methods up on the
        # type, so an instance-level `__call__` patch never affects `tiny_model(...)`.
        # nn.Module.__call__ dispatches to self.forward, which does see this patch.
        mocker.patch.object(tiny_model, "forward", return_value=preds)

        total_loss, loss_dict = loss_manager.compute_total_loss(
            model=tiny_model, inputs=inputs, targets=targets, current_step=50, total_steps=100
        )

        expected_total = (
            loss_dict["reconstruction"] + loss_dict["sparsity"] + loss_dict["preactivation"] + loss_dict["auxiliary"]
        )

        assert total_loss.item() == pytest.approx(expected_total)
        assert loss_dict["total"] == pytest.approx(expected_total)
        assert "reconstruction" in loss_dict
        assert "sparsity" in loss_dict
        assert "preactivation" in loss_dict
        assert "auxiliary" in loss_dict

# -------------------------------------------------------------------------------- /tests/unit/models/test_encoder.py: --------------------------------------------------------------------------------
import pytest
import torch
import logging

from clt.config import CLTConfig
from clt.models.encoder import Encoder


def get_available_devices():
    """Returns available devices, including cpu, mps, and cuda if available."""
    devices = ["cpu"]
    if torch.cuda.is_available():
        devices.append("cuda")
    if torch.backends.mps.is_available():
        devices.append("mps")
    return devices


DEVICES = get_available_devices()
GPU_DEVICES = [d for d in DEVICES if d != "cpu"]


def require_gpu(func):
    """Decorator to skip tests if no GPU is available."""
    return pytest.mark.skipif(not GPU_DEVICES, reason="Test requires a GPU (CUDA or MPS)")(func)


@pytest.fixture
def clt_config():
    """Provides a basic CLTConfig for testing."""
    return CLTConfig(
        num_layers=2,
        d_model=8,
        num_features=16,
    )


@pytest.fixture
def encoder(clt_config, device):
    """Provides an Encoder instance."""
    return Encoder(
        config=clt_config,
        process_group=None,  # Non-distributed for unit tests
        device=device,
        dtype=torch.float32,
    ).to(device)


class TestEncoder:
    def test_get_preactivations_2d_input(self, encoder, clt_config, device):
        """Test get_preactivations with a 2D tensor."""
        batch_size = 4
        x = torch.randn(batch_size, clt_config.d_model, device=device)
        layer_idx = 0

        preacts = encoder.get_preactivations(x, layer_idx)

        assert preacts.shape == (batch_size, clt_config.num_features)
        assert preacts.device.type == device.type
        assert preacts.dtype == torch.float32

    def test_get_preactivations_3d_input(self, encoder, clt_config, device):
        """Test get_preactivations with a 3D tensor."""
        batch_size = 2
        seq_len = 5
        x = torch.randn(batch_size, seq_len, clt_config.d_model, device=device)
        layer_idx = 1

        preacts = encoder.get_preactivations(x, layer_idx)

        assert preacts.shape == (batch_size * seq_len, clt_config.num_features)
        assert preacts.device.type == device.type
        assert preacts.dtype == torch.float32

    @require_gpu
    @pytest.mark.parametrize("device", GPU_DEVICES)
    def test_get_preactivations_3d_input_gpu(self, clt_config, device):
        """Dedicated GPU test for 3D input to ensure it runs on accelerator."""
        device_obj = torch.device(device)
        # We need to recreate the encoder on the correct GPU device for this specific test
        gpu_encoder = Encoder(
            config=clt_config,
            process_group=None,
            device=device_obj,
            dtype=torch.float32,
        ).to(device_obj)
        self.test_get_preactivations_3d_input(gpu_encoder, clt_config, device_obj)

    def test_get_preactivations_mismatched_d_model(self, encoder, clt_config, device, caplog):
        """Test get_preactivations with mismatched d_model, expecting a warning and zero tensor."""
        batch_size = 2
        seq_len = 5
        wrong_d_model = clt_config.d_model + 4
        x = torch.randn(batch_size, seq_len, wrong_d_model, device=device)
        layer_idx = 0

        with caplog.at_level(logging.WARNING):
            preacts = encoder.get_preactivations(x, layer_idx)

        assert preacts.shape == (batch_size * seq_len, clt_config.num_features)
        assert torch.all(preacts == 0)
        assert "Input d_model" in caplog.text
        assert str(wrong_d_model) in caplog.text

    def test_get_preactivations_invalid_layer_idx(self, encoder, clt_config, device, caplog):
        """Test get_preactivations with an out-of-bounds layer_idx."""
        batch_size = 4
        x = torch.randn(batch_size, clt_config.d_model, device=device)
        invalid_layer_idx = clt_config.num_layers  # num_layers is OOB

        with caplog.at_level(logging.ERROR):
            preacts = encoder.get_preactivations(x, invalid_layer_idx)

        assert preacts.shape == (batch_size, clt_config.num_features)
        assert torch.all(preacts == 0)
        assert "Invalid layer index" in caplog.text
        assert str(invalid_layer_idx) in caplog.text

    def test_encode_all_layers(self, encoder, clt_config, device):
        """Test the encode_all_layers method."""
        inputs = {
            0: torch.randn(4, clt_config.d_model, device=device),
            1: torch.randn(2, 5, clt_config.d_model, device=device),
        }

        preacts_dict, shapes_info = encoder.encode_all_layers(inputs)

        # Check keys
        assert sorted(preacts_dict.keys()) == sorted(inputs.keys())

        # Check shapes and values
        assert preacts_dict[0].shape == (4, clt_config.num_features)
        assert preacts_dict[1].shape == (10, clt_config.num_features)

        # Verify against direct calls
        torch.testing.assert_close(preacts_dict[0], encoder.get_preactivations(inputs[0], 0))
        torch.testing.assert_close(preacts_dict[1], encoder.get_preactivations(inputs[1], 1))

        # Check shapes_info
        assert len(shapes_info) == 2
        # Order should be sorted by layer_idx
        assert shapes_info[0] == (0, 4, 1)  # (layer_idx, batch_size, seq_len)
        assert shapes_info[1] == (1, 2, 5)

    def test_get_preactivations_unusual_input_dims(self, encoder, clt_config, device, caplog):
        """Test get_preactivations with unexpected input dimensions (1D, 4D)."""
        # Test with 1D input
        x_1d = torch.randn(clt_config.d_model, device=device)
        with caplog.at_level(logging.WARNING):
            preacts_1d = encoder.get_preactivations(x_1d, 0)
        assert "Cannot handle input shape" in caplog.text
        # The fallback logic uses shape[0] as the batch dim.
        assert preacts_1d.shape == (x_1d.shape[0], clt_config.num_features)

        caplog.clear()

        # Test with 4D input
        x_4d = torch.randn(2, 3, 4, clt_config.d_model, device=device)
        with caplog.at_level(logging.WARNING):
            preacts_4d = encoder.get_preactivations(x_4d, 0)
        assert "Cannot handle input shape" in caplog.text
        assert preacts_4d.shape == (x_4d.shape[0], clt_config.num_features)

# -------------------------------------------------------------------------------- /tests/unit/training/data/test_local_activation_store.py: --------------------------------------------------------------------------------
import torch
from pathlib import Path
import numpy as np

from clt.training.data.local_activation_store import LocalActivationStore


class TestLocalActivationStore:
    def test_initialization(self, tmp_local_dataset: Path):
        """Test that the store initializes correctly with a local dataset."""
        store = LocalActivationStore(
            dataset_path=str(tmp_local_dataset),
            train_batch_size_tokens=16,
            rank=0,
            world=1,
        )
        assert store.num_layers == 2
        assert store.d_model == 8
        assert store.total_tokens == 64
        assert len(store.sampler) == 64 // 16

    def test_get_batch_unsharded(self, tmp_local_dataset: Path):
        """Test fetching a single batch from the store without sharding."""
        store = LocalActivationStore(
            dataset_path=str(tmp_local_dataset),
            train_batch_size_tokens=16,
            dtype="float16",  # Match the dtype of the data in the fixture
            rank=0,
            world=1,
        )
        inputs, targets = store.get_batch()

        assert isinstance(inputs, dict)
        assert isinstance(targets, dict)
        assert len(inputs) == store.num_layers
        assert len(targets) == store.num_layers

        for i in range(store.num_layers):
            assert i in inputs
            assert i in targets
            assert inputs[i].shape == (16, store.d_model)
            assert targets[i].shape == (16, store.d_model)
            assert inputs[i].dtype == torch.float16
            assert targets[i].device.type == store.device.type

    def test_iteration_unsharded(self, tmp_local_dataset: Path):
        """Test iterating through the entire dataset without sharding."""
        store = LocalActivationStore(
            dataset_path=str(tmp_local_dataset),
            train_batch_size_tokens=16,
            rank=0,
            world=1,
            dtype="float16",
        )

        num_batches = 0
        total_tokens = 0
        for inputs, targets in store:
            num_batches += 1
            # Get token count from the first layer's input tensor
            batch_tokens = next(iter(inputs.values())).shape[0]
            total_tokens += batch_tokens

        assert num_batches == len(store.sampler)
        assert total_tokens == store.total_tokens

    def test_sharding_iteration(self, tmp_local_dataset: Path):
        """Test that sharding splits the data correctly across ranks."""
        world_size = 2
        total_tokens_processed = 0

        for rank in range(world_size):
            store = LocalActivationStore(
                dataset_path=str(tmp_local_dataset),
                train_batch_size_tokens=16,
                rank=rank,
                world=world_size,
                shard_data=True,
                dtype="float16",
            )

            rank_tokens = 0
            for inputs, _ in store:
                rank_tokens += next(iter(inputs.values())).shape[0]

            # Each rank should process roughly half the tokens
            assert rank_tokens == store.total_tokens // world_size
            total_tokens_processed += rank_tokens

        # The sum of tokens processed by all ranks should equal the total tokens
        assert total_tokens_processed == store.total_tokens

    def test_state_dict_roundtrip(self, tmp_local_dataset: Path):
        """Test that the store can be resumed from a saved state."""
        store1 = LocalActivationStore(dataset_path=str(tmp_local_dataset), train_batch_size_tokens=16, dtype="float16")

        # Get first batch
        batch1_inputs, _ = store1.get_batch()

        # Save state
        state = store1.state_dict()

        # Create a new store and load state
        store2 = LocalActivationStore(dataset_path=str(tmp_local_dataset), train_batch_size_tokens=16, dtype="float16")
        store2.load_state_dict(state)

        # Get next batch from the new store
        batch2_inputs, _ = store2.get_batch()

        # Get the second batch from the original store for comparison
        expected_batch2_inputs, _ = store1.get_batch()

        # The batch from the resumed store should match the next batch from the original
        for i in store1.layer_indices:
            torch.testing.assert_close(batch2_inputs[i], expected_batch2_inputs[i])

    def test_layer_data_integrity(self, tmp_local_dataset: Path):
        """Test that each layer's data is distinct and not mixed up."""
        store = LocalActivationStore(
            dataset_path=str(tmp_local_dataset),
            train_batch_size_tokens=16,
            dtype="float16",
        )

        inputs, targets = store.get_batch()

        # Verify each layer has distinct data
        # The fixture creates random data: inputs in [0, 10), targets in [0, 5)
        for layer_id in range(store.num_layers):
            layer_inputs = inputs[layer_id].cpu().numpy()
            layer_targets = targets[layer_id].cpu().numpy()

            # Basic sanity checks on the data
            assert layer_inputs.shape == (16, store.d_model), f"Layer {layer_id} inputs have wrong shape"
            assert layer_targets.shape == (16, store.d_model), f"Layer {layer_id} targets have wrong shape"

            # Check that values are in reasonable ranges (inputs: [0, 10), targets: [0, 5))
            assert layer_inputs.min() >= 0, f"Layer {layer_id} inputs have negative values"
            assert layer_inputs.max() < 15, f"Layer {layer_id} inputs have unexpectedly large values"
            assert layer_targets.min() >= 0, f"Layer {layer_id} targets have negative values"
            assert layer_targets.max() < 10, f"Layer {layer_id} targets have unexpectedly large values"

            # Verify data is different between layers (statistically very unlikely to be identical)
            if layer_id > 0:
                prev_layer_inputs = inputs[layer_id - 1].cpu().numpy()
                # Check that at least some values differ between layers
                assert not np.array_equal(
                    layer_inputs, prev_layer_inputs
                ), f"Layer {layer_id} data identical to layer {layer_id - 1}"

                # Additional check: mean values should differ (statistically)
                mean_diff = abs(layer_inputs.mean() - prev_layer_inputs.mean())
                assert mean_diff > 0.01, f"Layer {layer_id} and {layer_id - 1} have suspiciously similar means"

# -------------------------------------------------------------------------------- /tests/unit/models/test_theta.py: --------------------------------------------------------------------------------
import pytest
import torch
import torch.nn as nn
import logging
from torch.utils.data import IterableDataset

from clt.config import CLTConfig
from clt.models.theta import ThetaManager


def get_available_devices():
    """Returns available devices, including cpu, mps, and cuda if available."""
    devices = ["cpu"]
    if torch.cuda.is_available():
        devices.append("cuda")
    if torch.backends.mps.is_available():
        devices.append("mps")
    return devices


DEVICES = get_available_devices()


@pytest.fixture(params=DEVICES)
def device(request):
    """Fixture to iterate over all available devices."""
    return torch.device(request.param)


@pytest.fixture
def clt_config_jumprelu():
    """Provides a CLTConfig for testing JumpReLU."""
    return CLTConfig(
        num_layers=2,
        d_model=8,
        num_features=16,
        activation_fn="jumprelu",
        jumprelu_threshold=0.5,
    )


@pytest.fixture
def clt_config_batchtopk():
    """Provides a CLTConfig for testing BatchTopK conversion."""
    return CLTConfig(
        num_layers=2,
        d_model=8,
        num_features=4,
        activation_fn="batchtopk",
        batchtopk_k=2,
    )


class MockIterableDataset(IterableDataset):
    def __init__(self, data):
        self.data = data

    def __iter__(self):
        return iter(self.data)


class TestThetaManager:
    def test_initialization_jumprelu(self, clt_config_jumprelu, device):
        """Test ThetaManager initializes correctly for JumpReLU."""
        tm = ThetaManager(clt_config_jumprelu, None, device, torch.float32)
        assert isinstance(tm.log_threshold, nn.Parameter)
        assert tm.log_threshold.shape == (
            clt_config_jumprelu.num_layers,
            clt_config_jumprelu.num_features,
        )
        expected_val = torch.log(torch.tensor(0.5, device=device))
        torch.testing.assert_close(tm.log_threshold.mean(), expected_val)

    def test_initialization_other_activation(self, clt_config_batchtopk, device):
        """Test ThetaManager initializes correctly for other activations."""
        tm = ThetaManager(clt_config_batchtopk, None, device, torch.float32)
        assert tm.log_threshold is None

    def test_jumprelu_activation(self, clt_config_jumprelu, device):
        """Test the JumpReLU activation logic."""
        tm = ThetaManager(clt_config_jumprelu, None, device, torch.float32)
        # Input must match the number of features
        preacts = torch.linspace(-1.0, 1.0, clt_config_jumprelu.num_features, device=device).view(1, -1)
        # threshold is 0.5
        layer_idx = 0
        activated = tm.jumprelu(preacts, layer_idx)

        # Manually compute expected output
        threshold_val = clt_config_jumprelu.jumprelu_threshold
        expected = torch.where(preacts >= threshold_val, preacts, torch.zeros_like(preacts))

        torch.testing.assert_close(activated, expected)
94 | def test_jumprelu_invalid_layer_idx(self, clt_config_jumprelu, device, caplog): 95 | """Test JumpReLU with an out-of-bounds layer index.""" 96 | tm = ThetaManager(clt_config_jumprelu, None, device, torch.float32) 97 | preacts = torch.randn(2, 16, device=device) 98 | with caplog.at_level(logging.ERROR): 99 | activated = tm.jumprelu(preacts, layer_idx=clt_config_jumprelu.num_layers) 100 | assert "Invalid layer_idx" in caplog.text 101 | # Should return the original tensor 102 | torch.testing.assert_close(activated, preacts) 103 | 104 | def test_convert_to_jumprelu_raises_error_if_stats_missing(self, clt_config_batchtopk, device): 105 | """Test that conversion fails if estimation has not been run.""" 106 | tm = ThetaManager(clt_config_batchtopk, None, device, torch.float32) 107 | with pytest.raises(RuntimeError, match="Required buffer .* not found"): 108 | tm.convert_to_jumprelu_inplace() 109 | 110 | def test_estimate_and_convert_posthoc(self, clt_config_batchtopk, device): 111 | """ 112 | Test the full estimate_theta_posthoc and convert_to_jumprelu_inplace flow. 113 | """ 114 | tm = ThetaManager(clt_config_batchtopk, None, device, torch.float32) 115 | num_features = clt_config_batchtopk.num_features 116 | num_layers = clt_config_batchtopk.num_layers 117 | 118 | def mock_encode_all_layers(inputs): 119 | # This mock should return preactivations that look like they came from a real model 120 | # i.e., they have some mean and std dev. The estimation process will normalize them. 
121 | # Layer 0: High values for first k features 122 | preacts_l0 = torch.cat( 123 | [ 124 | torch.randn(4, 2, device=device) + 5, # High values for top-k 125 | torch.randn(4, 2, device=device), # Low values for others 126 | ], 127 | dim=1, 128 | ) 129 | # Layer 1: High values for last k features 130 | preacts_l1 = torch.cat( 131 | [torch.randn(4, 2, device=device), torch.randn(4, 2, device=device) + 5], # Low values # High values 132 | dim=1, 133 | ) 134 | return {0: preacts_l0, 1: preacts_l1}, [] 135 | 136 | # Mock data iterator to yield one batch 137 | mock_data = [({0: torch.randn(1, 8), 1: torch.randn(1, 8)}, None)] 138 | mock_data_iter = MockIterableDataset(mock_data) 139 | 140 | # --- Run estimation --- 141 | tm.estimate_theta_posthoc( 142 | encode_all_layers_fn=mock_encode_all_layers, 143 | data_iter=mock_data_iter, 144 | num_batches=1, 145 | ) 146 | 147 | # --- Check results of conversion --- 148 | assert tm.config.activation_fn == "jumprelu" 149 | assert tm.log_threshold is not None 150 | assert tm.log_threshold.shape == (num_layers, num_features) 151 | 152 | # The post-hoc estimation logic is complex. Instead of asserting exact values, 153 | # which are sensitive to the mock, we check for reasonable behavior: 154 | # 1. Thetas should be positive and finite. 155 | # 2. Thetas for the activated features should be higher than for non-activated ones. 156 | final_thetas = torch.exp(tm.log_threshold) 157 | assert torch.all(torch.isfinite(final_thetas)) 158 | assert torch.all(final_thetas > 0) 159 | 160 | # For layer 0, the first 2 features were consistently active and high. 161 | # Their thresholds should be higher than the last 2 features. 162 | assert torch.all(final_thetas[0, :2] > final_thetas[0, 2:]) 163 | 164 | # For layer 1, the last 2 features were active. 
165 | assert torch.all(final_thetas[1, 2:] > final_thetas[1, :2]) 166 | -------------------------------------------------------------------------------- /tests/integration/test_checkpoint_resumption.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from pathlib import Path 4 | from safetensors.torch import load_file as load_safetensors_file 5 | 6 | from clt.config import CLTConfig, TrainingConfig 7 | from clt.training.trainer import CLTTrainer 8 | from tests.helpers.tiny_configs import create_tiny_clt_config, create_tiny_training_config 9 | 10 | 11 | @pytest.fixture 12 | def tiny_clt_config() -> CLTConfig: 13 | """Provides a basic CLTConfig for testing.""" 14 | return create_tiny_clt_config(num_layers=2, d_model=8, num_features=16) 15 | 16 | 17 | @pytest.fixture 18 | def tiny_training_config_factory(tmp_local_dataset: Path): 19 | """ 20 | Provides a factory function to create TrainingConfig instances, 21 | allowing customization of steps and intervals per test. 22 | """ 23 | 24 | def _factory(training_steps: int, checkpoint_interval: int, eval_interval: int = 10_000) -> TrainingConfig: 25 | return create_tiny_training_config( 26 | training_steps=training_steps, 27 | checkpoint_interval=checkpoint_interval, 28 | eval_interval=eval_interval, 29 | train_batch_size_tokens=16, 30 | activation_source="local_manifest", 31 | activation_path=str(tmp_local_dataset), 32 | activation_dtype="float32", 33 | precision="fp32", 34 | ) 35 | 36 | return _factory 37 | 38 | 39 | class TestCheckpointResumption: 40 | def test_resume_from_checkpoint_produces_identical_state( 41 | self, 42 | tiny_clt_config: CLTConfig, 43 | tiny_training_config_factory, 44 | tmp_path: Path, 45 | ): 46 | """ 47 | Verify that resuming training from a checkpoint produces the exact same 48 | model parameters, optimizer state, and loss as a continuous run. 
49 | """ 50 | # --- Configuration --- 51 | total_steps = 10 52 | checkpoint_step = 9 # Checkpoint is saved at the final step 53 | log_dir_initial = tmp_path / "initial_run" 54 | log_dir_resumed = tmp_path / "resumed_run" 55 | 56 | # === 1. Initial Training Run (to generate a checkpoint) === 57 | initial_config = tiny_training_config_factory(training_steps=total_steps, checkpoint_interval=checkpoint_step) 58 | initial_trainer = CLTTrainer( 59 | clt_config=tiny_clt_config, 60 | training_config=initial_config, 61 | log_dir=str(log_dir_initial), 62 | device="cpu", 63 | ) 64 | initial_trainer.train() 65 | 66 | # Dynamically find the latest *numbered* checkpoint created 67 | trainer_state_files = sorted([p for p in log_dir_initial.glob("trainer_state_*.pt") if "latest" not in p.name]) 68 | assert trainer_state_files, "No numbered trainer state checkpoint files found." 69 | 70 | latest_trainer_state_path = trainer_state_files[-1] 71 | stem = latest_trainer_state_path.stem 72 | checkpoint_step = int(stem.split("_")[-1]) if stem.split("_")[-1].isdigit() else -1 73 | assert checkpoint_step != -1, f"Could not parse step number from filename: {latest_trainer_state_path.name}" 74 | 75 | # Capture the state from the initial run *at the checkpoint step* for later comparison. 76 | model_state_path = log_dir_initial / f"clt_checkpoint_{checkpoint_step}.safetensors" 77 | 78 | assert latest_trainer_state_path.exists(), f"Trainer state file not found at {latest_trainer_state_path}" 79 | assert model_state_path.exists(), f"Model state file not found at {model_state_path}" 80 | 81 | state_from_checkpoint = torch.load(latest_trainer_state_path, map_location="cpu", weights_only=False) 82 | model_state_at_checkpoint = load_safetensors_file(model_state_path, device="cpu") 83 | 84 | # === 2. 
Resumed Training Run === 85 | # Now, create the actual trainer that will resume from the checkpoint 86 | resumed_config = tiny_training_config_factory(training_steps=total_steps, checkpoint_interval=10_000) 87 | resumed_trainer = CLTTrainer( 88 | clt_config=tiny_clt_config, 89 | training_config=resumed_config, 90 | log_dir=str(log_dir_resumed), 91 | device="cpu", 92 | resume_from_checkpoint_path=str(model_state_path), # Resume from the model file 93 | ) 94 | 95 | # === 3. Verification === 96 | # a) Check that the trainer state (step, optimizer, etc.) was loaded correctly 97 | assert resumed_trainer.loaded_trainer_state is not None 98 | assert resumed_trainer.loaded_trainer_state["step"] == checkpoint_step 99 | 100 | # Manually trigger the state loading logic that happens at the start of train() 101 | resumed_trainer.optimizer.load_state_dict(resumed_trainer.loaded_trainer_state["optimizer_state_dict"]) 102 | if resumed_trainer.scheduler: 103 | resumed_trainer.scheduler.load_state_dict(resumed_trainer.loaded_trainer_state["scheduler_state_dict"]) 104 | if resumed_trainer.scaler and resumed_trainer.scaler.is_enabled(): 105 | resumed_trainer.scaler.load_state_dict(resumed_trainer.loaded_trainer_state["scaler_state_dict"]) 106 | 107 | # b) Check optimizer state 108 | # Convert both to dictionaries on CPU for consistent comparison 109 | resumed_optim_state = resumed_trainer.optimizer.state_dict() 110 | checkpoint_optim_state = state_from_checkpoint["optimizer_state_dict"] 111 | 112 | # We can't do a direct tensor comparison due to floating point variations in state like 'exp_avg' 113 | # Instead, we'll check that the structure is the same. 114 | assert resumed_optim_state.keys() == checkpoint_optim_state.keys(), "Optimizer state keys do not match." 115 | assert len(resumed_optim_state["state"]) == len( 116 | checkpoint_optim_state["state"] 117 | ), "Optimizer state dictionary lengths do not match." 
118 | 119 | # c) Check model parameters 120 | resumed_model_state = resumed_trainer.model.state_dict() 121 | for key in model_state_at_checkpoint: 122 | assert key in resumed_model_state, f"Key '{key}' missing in resumed model state." 123 | assert torch.allclose( 124 | model_state_at_checkpoint[key], resumed_model_state[key] 125 | ), f"Model parameter '{key}' does not match after resuming." 126 | 127 | # d) Continue training and verify loss is identical to a continuous run 128 | # We will compare the final model state after 10 steps instead, as it's a simpler and robust check. 129 | resumed_trainer.train() 130 | 131 | # Get the final model from the initial run by loading its final checkpoint 132 | final_initial_model_path = log_dir_initial / "clt_checkpoint_latest.safetensors" 133 | final_model_from_initial_run = load_safetensors_file(final_initial_model_path, device="cpu") 134 | 135 | final_model_from_resumed_run = resumed_trainer.model.state_dict() 136 | 137 | for key in final_model_from_initial_run: 138 | assert torch.allclose( 139 | final_model_from_initial_run[key], final_model_from_resumed_run[key] 140 | ), f"Final model parameter '{key}' does not match between continuous and resumed runs." 141 | -------------------------------------------------------------------------------- /clt/training/data/remote_activation_store.py: -------------------------------------------------------------------------------- 1 | """RemoteActivationStore – manifest‑driven exactly‑once sampler. 2 | 3 | Replaces the old stateless random‑batch client. The workflow is: 4 | 5 | 1. On startup, download `metadata.json`, `index.bin`, optional 6 | `norm_stats.json`. 7 | 2. Build a `ShardedIndexSampler` that shuffles the manifest every epoch 8 | and yields *contiguous slices* of `batch_size` rows (may span 9 | multiple chunks). Each GPU rank consumes a disjoint strided subset. 10 | 3. For each batch: 11 | * Group the next `B` manifest entries by `chunk_id`. 
12 | * For each chunk request `/slice?chunk=X&rows=i,j,k`. 13 | * Parse the raw bf16 bytes into tensors: \[layers\] → inputs, targets. 14 | 4. Apply normalization and return `Dict[layer → Tensor]`. 15 | 16 | This module requires the server refactor (`/slice` endpoint). 17 | """ 18 | 19 | from __future__ import annotations 20 | 21 | import logging 22 | import json 23 | from typing import Dict, Optional, Any 24 | 25 | import numpy as np 26 | import torch 27 | import requests 28 | from urllib.parse import urljoin, quote 29 | import time 30 | 31 | from .manifest_activation_store import ManifestActivationStore 32 | 33 | logger = logging.getLogger(__name__) 34 | 35 | 36 | class RemoteActivationStore(ManifestActivationStore): 37 | """ 38 | Activation store that fetches data from a remote slice server using 39 | a manifest file for deterministic, sharded sampling. 40 | Inherits common logic from ManifestActivationStore. 41 | """ 42 | 43 | def __init__( 44 | self, 45 | server_url: str, 46 | dataset_id: str, 47 | train_batch_size_tokens: int = 4096, 48 | device: torch.device | str | None = None, 49 | dtype: torch.dtype | str = "bfloat16", 50 | rank: int = 0, 51 | world: int = 1, 52 | seed: int = 42, 53 | timeout: int = 60, 54 | sampling_strategy: str = "sequential", 55 | normalization_method: str = "none", 56 | shard_data: bool = True, 57 | ): 58 | self.server = server_url.rstrip("/") + "/" 59 | self.did_enc = quote(dataset_id, safe="") 60 | self.did_raw = dataset_id 61 | self.timeout = timeout 62 | 63 | super().__init__( 64 | train_batch_size_tokens=train_batch_size_tokens, 65 | device=device, 66 | dtype=dtype, 67 | rank=rank, 68 | world=world, 69 | seed=seed, 70 | sampling_strategy=sampling_strategy, 71 | normalization_method=normalization_method, 72 | shard_data=shard_data, 73 | ) 74 | 75 | logger.info( 76 | "RemoteActivationStore initialized for dataset '%s' at %s " 77 | "(Rank %d/%d, Seed %d, Batch %d, Device %s, Dtype %s, Strategy '%s')", 78 | self.did_raw, 79 | 
self.server, 80 | self.rank, 81 | self.world, 82 | self.seed, 83 | self.train_batch_size_tokens, 84 | self.device, 85 | self.dtype, 86 | self.sampling_strategy, 87 | ) 88 | 89 | def _load_metadata(self) -> Optional[Dict[str, Any]]: 90 | return self._get_json("info", required=True, retries=3) 91 | 92 | def _load_manifest(self) -> Optional[np.ndarray]: 93 | max_retries = 3 94 | base_delay = 2 95 | url = urljoin(self.server, f"datasets/{self.did_enc}/manifest") 96 | logger.info(f"Downloading manifest from {url}") 97 | 98 | for attempt in range(max_retries): 99 | try: 100 | r = requests.get(url, timeout=self.timeout) 101 | r.raise_for_status() 102 | data = np.frombuffer(r.content, dtype=np.uint32).reshape(-1, 2) 103 | logger.info(f"Manifest downloaded ({len(data)} rows, {len(r.content) / 1024:.1f} KiB).") 104 | return data 105 | except requests.exceptions.RequestException as e: 106 | logger.warning(f"Attempt {attempt + 1}/{max_retries} failed to download manifest from {url}: {e}") 107 | if attempt + 1 == max_retries: 108 | logger.error(f"Final attempt failed to download manifest. Returning None.") 109 | return None 110 | else: 111 | delay = base_delay * (2**attempt) 112 | logger.info(f"Retrying manifest download in {delay:.1f} seconds...") 113 | time.sleep(delay) 114 | except ValueError as e: 115 | logger.error(f"Error reshaping manifest data (expected Nx2 shape): {e}") 116 | return None 117 | return None 118 | 119 | def _load_norm_stats(self) -> Optional[Dict[str, Any]]: 120 | return self._get_json("norm_stats", required=False, retries=1) 121 | 122 | def _fetch_slice(self, chunk_id: int, row_indices: np.ndarray) -> bytes: 123 | if row_indices.dtype != np.uint32: 124 | logger.warning(f"Row indices dtype is {row_indices.dtype}, expected uint32. 
Casting.") 125 | row_indices = row_indices.astype(np.uint32) 126 | rows_list = row_indices.tolist() 127 | url = urljoin( 128 | self.server, 129 | f"datasets/{self.did_enc}/slice?chunk={chunk_id}", 130 | ) 131 | try: 132 | r = requests.post(url, json={"rows": rows_list}, timeout=self.timeout) 133 | r.raise_for_status() 134 | return r.content 135 | except requests.exceptions.Timeout: 136 | logger.error(f"Timeout fetching slice from {url}") 137 | raise 138 | except requests.exceptions.RequestException as e: 139 | logger.error(f"HTTP error fetching slice from {url}: {e}") 140 | raise RuntimeError(f"Failed to fetch slice for chunk {chunk_id}") from e 141 | 142 | def _get_json(self, endpoint: str, required: bool = True, retries: int = 1) -> Optional[Dict[str, Any]]: 143 | base_delay = 1 144 | url = urljoin(self.server, f"datasets/{self.did_enc}/{endpoint}") 145 | for attempt in range(retries): 146 | try: 147 | r = requests.get(url, timeout=self.timeout) 148 | if not r.ok: 149 | if r.status_code == 404 and not required: 150 | logger.info(f"Optional resource not found at {url} (404)") 151 | return None 152 | else: 153 | r.raise_for_status() 154 | data = r.json() 155 | logger.info(f"Successfully fetched JSON from {url} on attempt {attempt + 1}") 156 | return data 157 | except (requests.exceptions.RequestException, json.JSONDecodeError) as e: 158 | logger.warning(f"Attempt {attempt + 1}/{retries} failed to fetch JSON from {url}: {e}") 159 | if attempt + 1 == retries: 160 | logger.error(f"Final attempt failed to fetch JSON from {url}. 
Returning None or raising error.") 161 | if required: 162 | raise RuntimeError( 163 | f"Failed to fetch required resource {endpoint} after {retries} attempts" 164 | ) from e 165 | else: 166 | return None 167 | else: 168 | delay = base_delay * (2**attempt) 169 | logger.info(f"Retrying JSON fetch from {url} in {delay:.1f} seconds...") 170 | time.sleep(delay) 171 | if required: 172 | raise RuntimeError(f"Failed to fetch required resource {endpoint} after {retries} attempts (logic error)") 173 | return None 174 | -------------------------------------------------------------------------------- /clt/models/encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing import Dict, List, Tuple, Optional 4 | import logging 5 | 6 | from clt.config import CLTConfig 7 | from clt.models.parallel import ColumnParallelLinear 8 | from clt.parallel import ops as dist_ops 9 | from torch.distributed import ProcessGroup 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | class Encoder(nn.Module): 15 | """ 16 | Encapsulates the encoder functionality of the CrossLayerTranscoder. 17 | It holds the stack of encoder layers and provides methods to get 18 | pre-activations. 
19 | """ 20 | 21 | def __init__( 22 | self, 23 | config: CLTConfig, 24 | process_group: Optional[ProcessGroup], 25 | device: torch.device, 26 | dtype: torch.dtype, 27 | ): 28 | super().__init__() 29 | self.config = config 30 | self.process_group = process_group 31 | self.device = device 32 | self.dtype = dtype 33 | 34 | self.world_size = dist_ops.get_world_size(process_group) 35 | self.rank = dist_ops.get_rank(process_group) 36 | 37 | self.encoders = nn.ModuleList( 38 | [ 39 | ColumnParallelLinear( 40 | in_features=config.d_model, 41 | out_features=config.num_features, 42 | bias=True, 43 | process_group=self.process_group, 44 | device=self.device, 45 | dtype=self.dtype, 46 | ) 47 | for _ in range(config.num_layers) 48 | ] 49 | ) 50 | 51 | # Note: feature_offset and feature_scale have been moved to Decoder module 52 | # to match EleutherAI's architecture where they are indexed by target layer 53 | 54 | def get_preactivations(self, x: torch.Tensor, layer_idx: int) -> torch.Tensor: 55 | """Get pre-activation values (full tensor) for features at the specified layer.""" 56 | result: Optional[torch.Tensor] = None 57 | fallback_shape: Optional[Tuple[int, int]] = None 58 | input_for_linear: Optional[torch.Tensor] = None 59 | 60 | # Ensure input is on the correct device and dtype 61 | x = x.to(device=self.device, dtype=self.dtype) 62 | 63 | try: 64 | # 1. 
Check input shape and reshape if necessary 65 | if x.dim() == 2: 66 | input_for_linear = x 67 | elif x.dim() == 3: 68 | batch, seq_len, d_model = x.shape 69 | if d_model != self.config.d_model: 70 | logger.warning( 71 | f"Rank {self.rank}: Input d_model {d_model} != config {self.config.d_model} layer {layer_idx}" 72 | ) 73 | fallback_shape = (batch * seq_len, self.config.num_features) 74 | else: 75 | input_for_linear = x.reshape(-1, d_model) 76 | else: 77 | logger.warning( 78 | f"Rank {self.rank}: Cannot handle input shape {x.shape} for preactivations layer {layer_idx}" 79 | ) 80 | fallback_batch_dim = x.shape[0] if x.dim() > 0 else 0 81 | fallback_shape = (fallback_batch_dim, self.config.num_features) 82 | 83 | # 2. Check d_model match if not already done and input_for_linear was set 84 | if fallback_shape is None and input_for_linear is not None: 85 | if input_for_linear.shape[1] != self.config.d_model: 86 | logger.warning( 87 | f"Rank {self.rank}: Input d_model {input_for_linear.shape[1]} != config {self.config.d_model} layer {layer_idx}" 88 | ) 89 | fallback_shape = (input_for_linear.shape[0], self.config.num_features) 90 | elif fallback_shape is None and input_for_linear is None: 91 | logger.error( 92 | f"Rank {self.rank}: Could not determine input for linear layer {layer_idx} and no fallback shape set. Input x.shape: {x.shape}" 93 | ) 94 | fallback_batch_dim = x.shape[0] if x.dim() > 0 else 0 95 | fallback_shape = (fallback_batch_dim, self.config.num_features) 96 | 97 | # 3. Proceed if no errors so far (i.e. fallback_shape is still None) 98 | if fallback_shape is None and input_for_linear is not None: 99 | # The input_for_linear is already on self.device and self.dtype due to the .to() call at the start of the function 100 | # or because it's derived from x which was moved. 
101 | result = self.encoders[layer_idx](input_for_linear) 102 | elif fallback_shape is None and input_for_linear is None: 103 | logger.error( 104 | f"Rank {self.rank}: Critical logic error in get_preactivations for layer {layer_idx}. input_for_linear is None and fallback_shape is None. Input x.shape: {x.shape}" 105 | ) 106 | fallback_batch_dim = x.shape[0] if x.dim() > 0 else 0 107 | fallback_shape = (fallback_batch_dim, self.config.num_features) 108 | 109 | except IndexError: 110 | logger.error( 111 | f"Rank {self.rank}: Invalid layer index {layer_idx} requested for encoder. Max index is {len(self.encoders) - 1}." 112 | ) 113 | if x.dim() == 2: 114 | fallback_batch_dim = x.shape[0] 115 | elif x.dim() == 3: 116 | fallback_batch_dim = x.shape[0] * x.shape[1] 117 | elif x.dim() > 0: 118 | fallback_batch_dim = x.shape[0] 119 | else: 120 | fallback_batch_dim = 0 121 | fallback_shape = (fallback_batch_dim, self.config.num_features) 122 | 123 | if result is not None: 124 | return result 125 | else: 126 | if fallback_shape is None: 127 | logger.error( 128 | f"Rank {self.rank}: Fallback shape not determined for layer {layer_idx}, and no result. Input x.shape: {x.shape}. Returning empty tensor." 129 | ) 130 | fallback_shape = (0, self.config.num_features) 131 | return torch.zeros(fallback_shape, device=self.device, dtype=self.dtype) 132 | 133 | def encode_all_layers( 134 | self, inputs: Dict[int, torch.Tensor] 135 | ) -> Tuple[Dict[int, torch.Tensor], List[Tuple[int, int, int]]]: 136 | """ 137 | Encodes inputs for all layers using the stored encoders. 138 | Assumes input tensors in `inputs` will be moved to the correct device/dtype 139 | by the `get_preactivations` method. 140 | 141 | Returns: 142 | A tuple containing: 143 | - preactivations_dict: Dictionary mapping layer indices to pre-activation tensors. 144 | - original_shapes_info: List of tuples storing (layer_idx, batch_size, seq_len) 145 | for restoring original 3D shapes if needed. 
146 | """ 147 | preactivations_dict: Dict[int, torch.Tensor] = {} 148 | original_shapes_info: List[Tuple[int, int, int]] = [] 149 | 150 | # Iterate in a deterministic layer order 151 | for layer_idx in sorted(inputs.keys()): 152 | x = inputs[layer_idx] # x will be moved to device/dtype in get_preactivations 153 | 154 | if x.dim() == 3: 155 | batch_size, seq_len, _ = x.shape 156 | original_shapes_info.append((layer_idx, batch_size, seq_len)) 157 | elif x.dim() == 2: 158 | batch_size, _ = x.shape 159 | original_shapes_info.append((layer_idx, batch_size, 1)) # seq_len is 1 for 2D 160 | 161 | preact = self.get_preactivations(x, layer_idx) 162 | preactivations_dict[layer_idx] = preact 163 | 164 | return preactivations_dict, original_shapes_info 165 | -------------------------------------------------------------------------------- /tests/unit/models/test_parallel_ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytest 3 | from unittest.mock import patch 4 | 5 | from clt.parallel import ops as dist_ops 6 | 7 | 8 | # Fixture to simulate dist not being initialized 9 | @pytest.fixture 10 | def mock_dist_not_initialized(): 11 | with patch("torch.distributed.is_available", return_value=True), patch( 12 | "torch.distributed.is_initialized", return_value=False 13 | ): 14 | yield 15 | 16 | 17 | # Fixture to simulate dist being initialized, single process (world_size=1, rank=0) 18 | @pytest.fixture 19 | def mock_dist_initialized_single_process(): 20 | with patch("torch.distributed.is_available", return_value=True), patch( 21 | "torch.distributed.is_initialized", return_value=True 22 | ), patch("torch.distributed.get_rank", return_value=0), patch("torch.distributed.get_world_size", return_value=1): 23 | yield 24 | 25 | 26 | # Fixture to simulate dist being initialized, multi-process (world_size=2, rank=1 as example) 27 | @pytest.fixture 28 | def mock_dist_initialized_multi_process(): 29 | with 
patch("torch.distributed.is_available", return_value=True), patch( 30 | "torch.distributed.is_initialized", return_value=True 31 | ), patch("torch.distributed.get_rank", return_value=1), patch("torch.distributed.get_world_size", return_value=2): 32 | yield 33 | 34 | 35 | def test_is_dist_initialized_and_available_not_initialized(mock_dist_not_initialized): 36 | assert not dist_ops.is_dist_initialized_and_available() 37 | 38 | 39 | def test_is_dist_initialized_and_available_initialized(mock_dist_initialized_single_process): 40 | assert dist_ops.is_dist_initialized_and_available() 41 | 42 | 43 | def test_get_rank_not_initialized(mock_dist_not_initialized): 44 | assert dist_ops.get_rank() == 0 45 | 46 | 47 | def test_get_rank_initialized_single_process(mock_dist_initialized_single_process): 48 | assert dist_ops.get_rank() == 0 49 | 50 | 51 | def test_get_rank_initialized_multi_process(mock_dist_initialized_multi_process): 52 | # Our mock torch.distributed.get_rank is set to return 1 53 | assert dist_ops.get_rank() == 1 54 | 55 | 56 | def test_get_world_size_not_initialized(mock_dist_not_initialized): 57 | assert dist_ops.get_world_size() == 1 58 | 59 | 60 | def test_get_world_size_initialized_single_process(mock_dist_initialized_single_process): 61 | assert dist_ops.get_world_size() == 1 62 | 63 | 64 | def test_get_world_size_initialized_multi_process(mock_dist_initialized_multi_process): 65 | # Our mock torch.distributed.get_world_size is set to return 2 66 | assert dist_ops.get_world_size() == 2 67 | 68 | 69 | def test_is_main_process_not_initialized(mock_dist_not_initialized): 70 | assert dist_ops.is_main_process() 71 | 72 | 73 | def test_is_main_process_initialized_rank_0(mock_dist_initialized_single_process): 74 | # This fixture sets rank to 0 75 | assert dist_ops.is_main_process() 76 | 77 | 78 | def test_is_main_process_initialized_rank_1(mock_dist_initialized_multi_process): 79 | # This fixture sets rank to 1 80 | assert not dist_ops.is_main_process() 81 | 82 | 
# Tests for collective wrappers in non-initialized state


def test_all_reduce_not_initialized(mock_dist_not_initialized):
    """all_reduce returns None and leaves the tensor untouched when dist is not initialized."""
    tensor = torch.tensor([1.0, 2.0])
    original_tensor = tensor.clone()
    work_obj = dist_ops.all_reduce(tensor)
    assert work_obj is None
    assert torch.equal(tensor, original_tensor)  # Should be a no-op


def test_all_reduce_initialized_single_process(mock_dist_initialized_single_process):
    """With world_size == 1, all_reduce short-circuits without calling torch.distributed."""
    tensor = torch.tensor([1.0, 2.0])
    original_tensor = tensor.clone()
    # Patch the real collective so an unexpected call is recorded rather than executed.
    with patch("torch.distributed.all_reduce") as mock_actual_all_reduce:
        work_obj = dist_ops.all_reduce(tensor)
        assert work_obj is None  # Our wrapper returns None for world_size = 1
        mock_actual_all_reduce.assert_not_called()  # Should not call actual dist op
    assert torch.equal(tensor, original_tensor)


def test_broadcast_not_initialized(mock_dist_not_initialized):
    """broadcast returns None and leaves the tensor untouched when dist is not initialized."""
    tensor = torch.tensor([1.0, 2.0])
    original_tensor = tensor.clone()
    work_obj = dist_ops.broadcast(tensor, src=0)
    assert work_obj is None
    assert torch.equal(tensor, original_tensor)  # Should be a no-op


def test_broadcast_initialized_single_process(mock_dist_initialized_single_process):
    """With world_size == 1, broadcast short-circuits without calling torch.distributed."""
    tensor = torch.tensor([1.0, 2.0])
    original_tensor = tensor.clone()
    with patch("torch.distributed.broadcast") as mock_actual_broadcast:
        work_obj = dist_ops.broadcast(tensor, src=0)
        assert work_obj is None
        mock_actual_broadcast.assert_not_called()
    assert torch.equal(tensor, original_tensor)


def test_all_gather_not_initialized(mock_dist_not_initialized):
    """Without an initialized process group, all_gather fills only slot 0 and returns None.

    The output slots start from a known sentinel instead of ``torch.empty_like``:
    empty tensors hold uninitialized memory, so comparing one against a second
    ``empty_like`` result is nondeterministic and made this test flaky.
    """
    tensor = torch.tensor([1.0, 2.0])
    sentinel = torch.full_like(tensor, -1.0)
    # clone() so writes into a slot can never mutate the sentinel itself.
    tensor_list = [sentinel.clone() for _ in range(2)]

    work_obj = dist_ops.all_gather(tensor_list, tensor)
    assert work_obj is None
    # Presumably rank 0 when not initialized (per dist_ops' get_rank logic), so only
    # tensor_list[0] receives the input.
    assert torch.equal(tensor_list[0], tensor)
    # Every other slot must be left untouched.
    assert torch.equal(tensor_list[1], sentinel)


def test_all_gather_initialized_single_process(mock_dist_initialized_single_process):
    """When initialized, all_gather delegates to torch.distributed.all_gather (even for world_size == 1).

    The backend is replaced by a fake that mimics the real collective for rank 0 of a
    single-process group: it copies the input into ``tensor_list[0]`` and, being a
    synchronous call, returns no Work handle.
    """
    tensor = torch.tensor([1.0, 2.0])
    tensor_list = [torch.empty_like(tensor)]  # world_size == 1 -> a single output slot

    # Parameter names mirror torch.distributed.all_gather so the fake also accepts
    # keyword-style calls from the wrapper.
    def fake_all_gather(tensor_list, tensor, group=None, async_op=False):
        tensor_list[0].copy_(tensor)  # write in place, as the real collective does
        return None

    with patch("torch.distributed.all_gather", side_effect=fake_all_gather) as mock_all_gather:
        work_obj = dist_ops.all_gather(tensor_list, tensor)

    assert work_obj is None
    mock_all_gather.assert_called_once()
    assert torch.equal(tensor_list[0], tensor)


# Example of how one might test a specific ReduceOp re-export
def test_sum_op_export():
    """dist_ops re-exports torch.distributed.ReduceOp.SUM as SUM."""
    assert dist_ops.SUM == torch.distributed.ReduceOp.SUM


# --------------------------------------------------------------------------------
# /scripts/experiments/run_pythia_batchtopk_training.py:
# --------------------------------------------------------------------------------
"""Train a BatchTopK cross-layer transcoder (CLT) on EleutherAI/pythia-70m activations.

Activations are generated on the first run (skipped when metadata.json and index.bin
already exist under the expected path), then a CLTTrainer is built and trained.
"""
import os
import sys
import time
import traceback

import torch

# Presumably set to silence the Hugging Face tokenizers fork/parallelism warning.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Make the repository root importable so `clt` resolves without installing the package.
# This file lives in <root>/scripts/experiments/, so the root is two directories above
# the script's own directory (two dirname() calls on the file path stop at <root>/scripts).
_script_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(os.path.dirname(_script_dir))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

# Import components from the clt library
try:
    from clt.config import CLTConfig, TrainingConfig, ActivationConfig
    from clt.activation_generation.generator import ActivationGenerator
    from clt.training.trainer import CLTTrainer
except ImportError as e:
    print(f"ImportError: {e}")
    print("Please ensure the 'clt' library is installed or the clt directory is in your PYTHONPATH.")
    raise

# Device setup: prefer CUDA, then Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f"Using device: {device}")

# Base model whose MLP activations are extracted to train the CLT.
BASE_MODEL_NAME = "EleutherAI/pythia-70m"

# --- CLT Architecture Configuration ---
num_layers, d_model, expansion_factor = 6, 512, 32
clt_num_features = expansion_factor * d_model  # 32 * 512 = 16384 features

# NOTE(review): this fraction is not referenced anywhere else in the script and
# batchtopk_k is left as None below; confirm how k is meant to be chosen.
batchtopk_sparsity_fraction = 0.005

_clt_kwargs = dict(
    num_features=clt_num_features,
    num_layers=num_layers,
    d_model=d_model,
    activation_fn="batchtopk",
    batchtopk_k=None,
    batchtopk_straight_through=True,
)
clt_config = CLTConfig(**_clt_kwargs)
print("CLT Configuration (BatchTopK):", clt_config, sep="\n")

# --- Activation Generation Configuration ---
activation_dir = "./activations_local_10M_pythia"
dataset_name = "monology/pile-uncopyrighted"
_activation_kwargs = dict(
    model_name=BASE_MODEL_NAME,
    mlp_input_module_path_template="gpt_neox.layers.{}.mlp.input",
    mlp_output_module_path_template="gpt_neox.layers.{}.mlp.output",
    model_dtype=None,
    dataset_path=dataset_name,
    dataset_split="train",
    dataset_text_column="text",
    context_size=128,
    inference_batch_size=192,
    exclude_special_tokens=True,
    prepend_bos=True,
    streaming=True,
    dataset_trust_remote_code=False,
    cache_path=None,
    target_total_tokens=10_000_000,
    activation_dir=activation_dir,
    output_format="hdf5",
    compression="gzip",
    chunk_token_threshold=32_000,
    activation_dtype="float32",
    compute_norm_stats=True,
    nnsight_tracer_kwargs={},
    nnsight_invoker_args={},
)
activation_config = ActivationConfig(**_activation_kwargs)
print("Activation Generation Configuration:", activation_config, sep="\n")

# --- Training Configuration ---
# Expected activation location: <activation_dir>/<model_name>/<dataset basename>_<split>.
# Values are read back from activation_config (not the locals above); the layout is
# assumed to match what ActivationGenerator writes — verify if that changes.
_dataset_subdir = "{}_{}".format(
    os.path.basename(activation_config.dataset_path),
    activation_config.dataset_split,
)
expected_activation_path = os.path.join(
    activation_config.activation_dir,
    activation_config.model_name,
    _dataset_subdir,
)

_lr, _batch_size = 1e-4, 1024
_k = clt_config.batchtopk_k

# Run name encodes width, k, batch size and LR, e.g. "16384-width-batchtopk-kNone-1024-batch-1.0e-04-lr".
wdb_run_name = f"{clt_config.num_features}-width-batchtopk-k{_k}-{_batch_size}-batch-{_lr:.1e}-lr"
print("\nGenerated WandB run name: " + wdb_run_name)

training_config = TrainingConfig(
    learning_rate=_lr,
    training_steps=10000,
    seed=42,
    activation_source="local_manifest",
    activation_path=expected_activation_path,
    activation_dtype="float32",
    train_batch_size_tokens=_batch_size,
    sampling_strategy="sequential",
    normalization_method="none",
    sparsity_lambda=0.0,
    sparsity_lambda_schedule="linear",
    sparsity_c=0.0,
    preactivation_coef=0,
    aux_loss_factor=1 / 32,
    apply_sparsity_penalty_to_batchtopk=False,
    optimizer="adamw",
    lr_scheduler="linear_final20",
    optimizer_beta2=0.98,
    log_interval=10,
    eval_interval=50,
    checkpoint_interval=1000,
    dead_feature_window=1000,
    enable_wandb=True,
    wandb_project="clt-hp-sweeps-pythia-70m",
    wandb_run_name=wdb_run_name,
)
print("\nTraining Configuration (BatchTopK):")
print(training_config)

# --- Generate Activations (One-Time Step) ---
print("Step 1: Generating/Verifying Activations (including manifest)...")

# Generation is skipped only when both the metadata file and the manifest already exist.
metadata_path = os.path.join(expected_activation_path, "metadata.json")
manifest_path = os.path.join(expected_activation_path, "index.bin")

if os.path.exists(metadata_path) and os.path.exists(manifest_path):
    print(f"Activations and manifest already found at: {expected_activation_path}")
    print("Skipping generation. Delete the directory to regenerate.")
else:
    print(f"Activations or manifest not found. Generating them now at: {expected_activation_path}")
    try:
        generator = ActivationGenerator(
            cfg=activation_config,
            device=device,
        )
        generation_start_time = time.time()
        generator.generate_and_save()
        generation_end_time = time.time()
        print(f"Activation generation complete in {generation_end_time - generation_start_time:.2f}s.")
    except Exception as gen_err:
        # Report with traceback, then re-raise: training cannot proceed without activations.
        print(f"[ERROR] Activation generation failed: {gen_err}")
        traceback.print_exc()
        raise

# --- Training the CLT with BatchTopK Activation ---
print("Initializing CLTTrainer for training with BatchTopK...")

# Timestamped so repeated runs do not overwrite each other's logs/checkpoints.
log_dir = f"clt_training_logs/clt_pythia_batchtopk_train_{int(time.time())}"
os.makedirs(log_dir, exist_ok=True)
print(f"Logs and checkpoints will be saved to: {log_dir}")

try:
    print("Creating CLTTrainer instance...")
    print(f"- Using device: {device}")
    print(f"- CLT config (BatchTopK): {vars(clt_config)}")
    print(f"- Activation Source: {training_config.activation_source}")
    print(f"- Reading activations from: {training_config.activation_path}")

    trainer = CLTTrainer(
        clt_config=clt_config,
        training_config=training_config,
        log_dir=log_dir,
        device=device,
        distributed=False,
    )
    print("CLTTrainer instance created successfully.")
except Exception as e:
    print(f"[ERROR] Failed to initialize CLTTrainer: {e}")
    traceback.print_exc()
    raise

# Start training
print("Beginning training using BatchTopK activation...")
print(f"Training for {training_config.training_steps} steps.")
print(f"Normalization method set to: {training_config.normalization_method}")
print(
    f"Standard sparsity penalty applied to BatchTopK activations: {training_config.apply_sparsity_penalty_to_batchtopk}"
)

try:
    start_train_time = time.time()
    trained_clt_model = trainer.train(eval_every=training_config.eval_interval)
    end_train_time = time.time()
    print(f"Training finished in {end_train_time - start_train_time:.2f} seconds.")
except Exception as train_err:
    print(f"[ERROR] Training failed: {train_err}")
    traceback.print_exc()
    raise

# --- Saving the Trained Model ---
# No explicit save here: per the original author's note, CLTTrainer writes checkpoints to
# log_dir itself. For an additional explicit copy of the returned model, e.g.:
#     torch.save(trained_clt_model.state_dict(), os.path.join(log_dir, "clt_batchtopk_final.pt"))

print(f"\nContents of log directory ({log_dir}):")
try:
    print(os.listdir(log_dir))
except FileNotFoundError:
    print(f"Log directory {log_dir} not found. This might happen if training failed very early.")


print("\nBatchTopK Training Script Complete!")
print(f"The trained BatchTopK CLT model and logs are saved in: {log_dir}")

# --------------------------------------------------------------------------------
# /scripts/experiments/run_pythia_batchtopk_training_fp16.py:
# --------------------------------------------------------------------------------
"""fp16 variant: train a BatchTopK CLT (k=200) on a small pythia-70m activation set.

Activations are generated in float16 on the first run (skipped when metadata.json and
index.bin already exist), then a short CLTTrainer run is launched with precision="fp16".
"""
import os
import sys
import time
import traceback

import torch

# Presumably set to silence the Hugging Face tokenizers fork/parallelism warning.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Make the repository root importable so `clt` resolves without installing the package.
# This file lives in <root>/scripts/experiments/, so the root is two directories above
# the script's own directory (two dirname() calls on the file path stop at <root>/scripts).
_script_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(os.path.dirname(_script_dir))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

# Import components from the clt library
try:
    from clt.config import CLTConfig, TrainingConfig, ActivationConfig
    from clt.activation_generation.generator import ActivationGenerator
    from clt.training.trainer import CLTTrainer
except ImportError as e:
    print(f"ImportError: {e}")
    print("Please ensure the 'clt' library is installed or the clt directory is in your PYTHONPATH.")
    raise

# Device setup: prefer CUDA, then Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f"Using device: {device}")

# Base model for activation extraction
BASE_MODEL_NAME = "EleutherAI/pythia-70m"

# --- CLT Architecture Configuration ---
num_layers = 6
d_model = 512
expansion_factor = 32
clt_num_features = d_model * expansion_factor

# NOTE(review): not referenced anywhere else in this script; k is fixed explicitly below.
batchtopk_sparsity_fraction = 0.005

clt_config = CLTConfig(
    num_features=clt_num_features,
    num_layers=num_layers,
    d_model=d_model,
    activation_fn="batchtopk",
    batchtopk_k=200,
    batchtopk_straight_through=True,
)
print("CLT Configuration (BatchTopK):")
print(clt_config)

# --- Activation Generation Configuration ---
activation_dir = "./activations_local_100k_pythia_fp16"
dataset_name = "monology/pile-uncopyrighted"
activation_config = ActivationConfig(
    model_name=BASE_MODEL_NAME,
    mlp_input_module_path_template="gpt_neox.layers.{}.mlp.input",
    mlp_output_module_path_template="gpt_neox.layers.{}.mlp.output",
    model_dtype=None,
    dataset_path=dataset_name,
    dataset_split="train",
    dataset_text_column="text",
    context_size=128,
    inference_batch_size=192,
    exclude_special_tokens=True,
    prepend_bos=True,
    streaming=True,
    dataset_trust_remote_code=False,
    cache_path=None,
    target_total_tokens=100_000,
    activation_dir=activation_dir,
    output_format="hdf5",
    compression="gzip",
    chunk_token_threshold=8_000,
    activation_dtype="float16",
    compute_norm_stats=True,
    nnsight_tracer_kwargs={},
    nnsight_invoker_args={},
)
print("Activation Generation Configuration:")
print(activation_config)

# --- Training Configuration ---
# Expected activation location: <activation_dir>/<model_name>/<dataset basename>_<split>
# (assumed to match the layout ActivationGenerator writes — verify if that changes).
expected_activation_path = os.path.join(
    activation_config.activation_dir,
    activation_config.model_name,
    f"{os.path.basename(activation_config.dataset_path)}_{activation_config.dataset_split}",
)

_lr = 1e-4
_batch_size = 1024
_k = clt_config.batchtopk_k

# Run name encodes width, k, batch size and LR, e.g. "16384-width-batchtopk-k200-1024-batch-1.0e-04-lr".
wdb_run_name = f"{clt_config.num_features}-width-batchtopk-k{_k}-{_batch_size}-batch-{_lr:.1e}-lr"
print("\nGenerated WandB run name: " + wdb_run_name)

# NOTE(review): checkpoint_interval and dead_feature_window (1000) exceed training_steps (100),
# so neither triggers within this short run — confirm that is intended for fp16 testing.
training_config = TrainingConfig(
    learning_rate=_lr,
    training_steps=100,
    seed=42,
    activation_source="local_manifest",
    activation_path=expected_activation_path,
    activation_dtype="float16",
    train_batch_size_tokens=_batch_size,
    sampling_strategy="sequential",
    normalization_method="none",
    sparsity_lambda=0.0,
    sparsity_lambda_schedule="linear",
    sparsity_c=0.0,
    preactivation_coef=0,
    aux_loss_factor=1 / 32,
    apply_sparsity_penalty_to_batchtopk=False,
    optimizer="adamw",
    lr_scheduler="linear_final20",
    optimizer_beta2=0.98,
    log_interval=10,
    eval_interval=50,
    checkpoint_interval=1000,
    dead_feature_window=1000,
    enable_wandb=True,
    wandb_project="clt-fp16-testing",
    wandb_run_name=wdb_run_name,
    precision="fp16",
    fp16_convert_weights=True,
    debug_anomaly=False,
)
print("\nTraining Configuration (BatchTopK):")
print(training_config)

# --- Generate Activations (One-Time Step) ---
print("Step 1: Generating/Verifying Activations (including manifest)...")

# Generation is skipped only when both the metadata file and the manifest already exist.
metadata_path = os.path.join(expected_activation_path, "metadata.json")
manifest_path = os.path.join(expected_activation_path, "index.bin")

if os.path.exists(metadata_path) and os.path.exists(manifest_path):
    print(f"Activations and manifest already found at: {expected_activation_path}")
    print("Skipping generation. Delete the directory to regenerate.")
else:
    print(f"Activations or manifest not found. Generating them now at: {expected_activation_path}")
    try:
        generator = ActivationGenerator(
            cfg=activation_config,
            device=device,
        )
        generation_start_time = time.time()
        generator.generate_and_save()
        generation_end_time = time.time()
        print(f"Activation generation complete in {generation_end_time - generation_start_time:.2f}s.")
    except Exception as gen_err:
        # Report with traceback, then re-raise: training cannot proceed without activations.
        print(f"[ERROR] Activation generation failed: {gen_err}")
        traceback.print_exc()
        raise

# --- Training the CLT with BatchTopK Activation ---
print("Initializing CLTTrainer for training with BatchTopK...")

# Timestamped so repeated runs do not overwrite each other's logs/checkpoints.
# NOTE(review): uses the same prefix as the fp32 script, so fp16 runs are told apart
# only by timestamp.
log_dir = f"clt_training_logs/clt_pythia_batchtopk_train_{int(time.time())}"
os.makedirs(log_dir, exist_ok=True)
print(f"Logs and checkpoints will be saved to: {log_dir}")

try:
    print("Creating CLTTrainer instance...")
    print(f"- Using device: {device}")
    print(f"- CLT config (BatchTopK): {vars(clt_config)}")
    print(f"- Activation Source: {training_config.activation_source}")
    print(f"- Reading activations from: {training_config.activation_path}")

    trainer = CLTTrainer(
        clt_config=clt_config,
        training_config=training_config,
        log_dir=log_dir,
        device=device,
        distributed=False,
    )
    print("CLTTrainer instance created successfully.")
except Exception as e:
    print(f"[ERROR] Failed to initialize CLTTrainer: {e}")
    traceback.print_exc()
    raise

# Start training
print("Beginning training using BatchTopK activation...")
print(f"Training for {training_config.training_steps} steps.")
print(f"Normalization method set to: {training_config.normalization_method}")
print(
    f"Standard sparsity penalty applied to BatchTopK activations: {training_config.apply_sparsity_penalty_to_batchtopk}"
)

try:
    start_train_time = time.time()
    trained_clt_model = trainer.train(eval_every=training_config.eval_interval)
    end_train_time = time.time()
    print(f"Training finished in {end_train_time - start_train_time:.2f} seconds.")
except Exception as train_err:
    print(f"[ERROR] Training failed: {train_err}")
    traceback.print_exc()
    raise

# --- Saving the Trained Model ---
# No explicit save here: per the original author's note, CLTTrainer writes checkpoints to
# log_dir itself. For an additional explicit copy of the returned model, e.g.:
#     torch.save(trained_clt_model.state_dict(), os.path.join(log_dir, "clt_batchtopk_final.pt"))

print(f"\nContents of log directory ({log_dir}):")
try:
    print(os.listdir(log_dir))
except FileNotFoundError:
    print(f"Log directory {log_dir} not found. This might happen if training failed very early.")


print("\nBatchTopK Training Script Complete!")
print(f"The trained BatchTopK CLT model and logs are saved in: {log_dir}")

# --------------------------------------------------------------------------------