├── .gitignore ├── LICENSE ├── README.md ├── configs ├── README.md ├── commanda_111b.yaml ├── gemma2_27b.yaml ├── llama3.3_70b.yaml ├── mistral_24b.yaml ├── qwq_32b.yaml └── r1_distill_qwen_32b.yaml ├── pipeline ├── config.py ├── data │ ├── dataloader.py │ ├── dataset.py │ └── datasets.json ├── generate.py ├── setup.py ├── utils │ ├── cli.py │ └── ray_utils.py └── vault │ ├── __init__.py │ ├── hook_uploader.py │ └── rcache.py ├── pyproject.toml ├── retrieve.py ├── s3 ├── shell │ ├── __init__.py │ ├── history.py │ ├── s3_operations.py │ ├── sanity.py │ └── shell.py └── utils.py ├── scripts ├── collect.slurm ├── run_ray_jobs.py └── run_slurm_jobs.sh ├── stash.py └── uv.lock /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .venv/ 3 | .env 4 | .vscode/ 5 | logs/ 6 | *.out 7 | *.err 8 | *temp* 9 | *.egg-info 10 | 11 | benchmark_results/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Activault 2 | **Activault** is an activation data engine that dramatically reduces costs for training interpreter models on frontier LLMs: 3 | - Collects and stores model activations (the model's "mental state") efficiently using S3 object storage, reducing activation management costs by 4-8x. 4 | - Enables reproducible and shareable interpretability research through standardized object storage. 
5 | - Maintains peak efficiency and throughput while handling petabyte-scale activation datasets. 6 | 7 | 8 | You can read about Activault in our [blog post](https://www.tilderesearch.com/blog/activault). 9 | 10 | > ⚠️ **CRITICAL WARNING** 11 | > Streaming/storing activations with Activault can be expensive ($$$) and slow if care is not taken before launching large-scale jobs. We recommend users set up their compute environment in the same region/data center as their s3 solution to ensure minimal latency and avoid egress fees. We also **strongly** recommend users consult the pricing page for their s3 solution to ensure they understand the costs associated with their jobs. 12 | 13 | ## 💾 Activault principles 14 | 15 | When designing Activault, we considered the tradeoffs between computing activations on-the-fly vs storing them on disk vs storing them in S3. Here's how the approaches compare: 16 | 17 | | Aspect | On-the-fly | Local Cache | Naive S3 | Activault | 18 | |--------|------------|-------------|----------|-----------| 19 | | Setup Complexity | ✅ Easy | ✅ Easy | ❌ Hard | ✅ Easy | 20 | | Write Performance | ✅ Fast | ✅ Fast | ❌ Slow | ✅ Fast | 21 | | Read Performance | ✅ Fast | ✅ Fast | ❌ Slow | ✅ Fast | 22 | | Efficiency | ❌ Enormously inefficient as activations must be regenerated across runs | ✅ Efficient | ✅ Efficient | ✅ Efficient | 23 | | Reproducibility | ❌ Poor | ✅ Good | ✅ Guaranteed | ✅ Guaranteed | 24 | | Token Context | ❌ Autointerp requires recomputing | ✅ Good | ❌ Poor (no tokens saved) | ✅ Tokens saved with data | 25 | | Shareability | ❌ Vanishes after training | ❌ Terrible | ✅ Guaranteed | ✅ Guaranteed | 26 | | Storage Cost | ✅ None | ❌ Very expensive | ✅ Cheap | ✅ Cheap | 27 | | Storage Availability | ✅ N/A | ❌ Very low | ✅ High | ✅ High | 28 | 29 | ## Table of Contents 30 | 31 | - [🔧 Setup](#setup) - Installation and AWS credential configuration 32 | - [📊 Collecting Activations](#collecting-activations) - Core pipeline for gathering model activations 33 | - [🚀 Running Collection Jobs](#running-collection-jobs) 34 | - [Using Slurm](#using-slurm-traditional-method) - Run on HPC clusters 35 | - [Using Ray](#using-ray-distributed-computing) - Distributed computing setup 36 | - [Running Locally](#running-locally) - Single machine execution 37 | - [🔍 Checking the Outputs: S3 Shell](#checking-the-outputs-s3-shell) - Tools for inspecting collected data 38 | - [📈 Using Activations with RCache](#using-your-activations-with-rcache) - Efficient streaming interface 39 | - [❓ FAQs](#faqs) - Common questions and answers 40 | - [💾 Local vs S3 Storage](#local-storage-vs-s3-storage) - Storage approach comparisons 41 | - [👥 Credits](#credits) - Attribution and inspiration 42 | 43 | ## 🔧 Setup 44 | 45 | ``` 46 | pip install uv 47 | uv sync --no-build-isolation 48 | uv pip install -e . 49 | ``` 50 | 51 | Make sure your AWS credentials are set. 52 | ``` 53 | export AWS_ACCESS_KEY_ID= 54 | export AWS_SECRET_ACCESS_KEY= 55 | export S3_ENDPOINT_URL= 56 | ``` 57 | 58 | ## 📊 Collecting Activations 59 | 60 | Use one of the pre-existing configs in `configs/` or create your own. We provide configs for several frontier open-weight models out-of-box. 61 | 62 | The collection pipeline: 63 | 64 | 1. Loads a transformer model and hooks into the specified layers and modules 65 | 2. Streams text data according to a specified data_key (mappings defined in `pipeline/data/datasets.json`) through the model in batches 66 | 3. 
For each hook (e.g., residual stream, attention outputs): 67 | - Collects activations and their corresponding input tokens 68 | - Concatenates multiple batches into "megabatch" files 69 | - Computes running statistics (mean, std, norm) 70 | - Uploads to S3 asynchronously 71 | 72 | Each hook's data is stored in its own directory: 73 | ``` 74 | s3://{bucket}/{run_name}/ 75 | ├── cfg.json # Collection config and model info 76 | └── {hook_name}/ 77 | ├── metadata.json # Shape and dtype info 78 | ├── statistics.json # Running statistics 79 | └── {uuid}--{n}.pt # Megabatch files 80 | ``` 81 | 82 | ## 🚀 Running Collection Jobs 83 | 84 | > 📢 **IMPORTANT** 85 | > Ensure `num_runs` in the config file is set to the total number of runs you want to launch before running large-scale distributed jobs. If this is not done, you will generate redundant data. 86 | 87 | ### Using Slurm (Traditional method) 88 | 89 | #### 1. Basic Single-Node Usage 90 | 91 | For a simple job on a single Slurm node: 92 | ```bash 93 | sbatch scripts/collect.slurm configs/your_config.yaml 94 | ``` 95 | 96 | #### 2. Running Distributed Jobs 97 | 98 | To run multiple distributed jobs across different nodes: 99 | ```bash 100 | ./scripts/run_slurm_jobs.sh configs/your_config.yaml 8 0 7 101 | ``` 102 | 103 | **Key Arguments:** 104 | - `configs/your_config.yaml`: Path to configuration file 105 | - `8`: Total number of workers to spawn 106 | - `0 7`: Start and end indices for worker assignment (will launch jobs for indices 0-7) 107 | 108 | The script will generate a log file mapping machine indices to Slurm job IDs. 109 | 110 | #### 3. Configuration 111 | 112 | Slurm job parameters (CPUs, GPUs, memory, etc.) can be adjusted by editing `scripts/collect.slurm`. Important parameters: 113 | ```bash 114 | #SBATCH --cpus-per-task=16 # CPUs per task 115 | #SBATCH --gres=gpu:1 # GPUs per node 116 | #SBATCH --mem=250G # Memory per node 117 | ``` 118 | 119 | ### Using Ray (Distributed Computing) 120 | 121 | Be sure to start a Ray cluster first: 122 | ```bash 123 | # Start Ray locally 124 | ray start --head 125 | 126 | # On head node 127 | ray start --head --port=6379 128 | 129 | # On worker nodes 130 | ray start --address=<head-node-ip>:6379 131 | ``` 132 | 133 | Running a single worker: 134 | ```bash 135 | python scripts/run_ray_jobs.py configs/your_config.yaml 1 0 0 --resources '{"CPU": 32, "GPU": 2}' --wait 136 | ``` 137 | 138 | Running distributed jobs (8 workers from index 0-7): 139 | ```bash 140 | python scripts/run_ray_jobs.py configs/your_config.yaml 8 0 7 --resources '{"CPU": 32, "GPU": 2}' --wait 141 | ``` 142 | 143 | **Key Arguments:** 144 | - `configs/your_config.yaml`: Path to configuration file 145 | - `8`: Total number of workers to spawn 146 | - `0 7`: Start and end indices for worker assignment 147 | - `--resources`: CPU and GPU allocation per worker (JSON format) 148 | - `--address`: Optional Ray cluster address (if not using environment variable) 149 | - `--wait`: Wait for all jobs to complete and show results 150 | 151 | Check Ray's dashboard periodically (typically at http://localhost:8265) for cluster status. 
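The `--resources` flag corresponds to Ray's standard per-task resource requests. The sketch below is illustrative only; the actual scheduling lives in `scripts/run_ray_jobs.py`, and the `collect_worker` stub and its arguments are placeholders rather than Activault APIs:

```python
import ray

ray.init(address="auto")  # attach to the cluster started with `ray start`

# Illustrative only: reserve 32 CPUs and 2 GPUs per worker, mirroring
# --resources '{"CPU": 32, "GPU": 2}' above. This placeholder stub just
# returns its index; the real worker runs the Activault collection pipeline.
@ray.remote(num_cpus=32, num_gpus=2)
def collect_worker(config_path: str, machine_index: int) -> int:
    return machine_index

futures = [collect_worker.remote("configs/your_config.yaml", i) for i in range(8)]
print(ray.get(futures))  # blocks until all eight workers have finished
```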
152 | 153 | ### Running locally 154 | 155 | To run the pipeline locally, you can use the Activault CLI: 156 | ```bash 157 | activault collect --config configs/your_config.yaml 158 | ``` 159 | 160 | Alternatively, you can run it directly: 161 | ```bash 162 | python stash.py --config configs/your_config.yaml 163 | ``` 164 | 165 | For distributed execution, specify the machine index: 166 | ```bash 167 | activault collect --config configs/your_config.yaml --machine 0 168 | ``` 169 | 170 | ## 🔍 Checking the Outputs: S3 Shell 171 | 172 | After running the pipeline, you can check the outputs by using our S3 shell. 173 | 174 | First, make sure your S3 bucket name is set: 175 | ```bash 176 | export S3_BUCKET_NAME= 177 | ``` 178 | 179 | Then, launch the S3 shell using the Activault CLI: 180 | ```bash 181 | activault s3 182 | ``` 183 | 184 | In the S3 shell, navigate to your run directory and use these commands: 185 | - `ls` - List files and directories 186 | - `cd directory_name` - Change directory 187 | - `filecount` - Count the number of files in the current directory and subdirectories 188 | - `sizecheck` - Calculate the total size of files in the current directory 189 | - `inspect <file_index>` - Inspect a specific megabatch file 190 | 191 | Example inspection output: 192 | 193 | ``` 194 | s3://main/testing/models.layers.24.mlp.post> inspect 1 195 | Inspecting file: /tmp/0f909221-ff28-4a94-a43f-cfe973e835cf--5_0.saved.pt 196 | 197 | PT File Inspection: 198 | ---------------------------------------- 199 | Model: meta-llama/Llama-3.3-70B-Instruct 200 | 201 | Tensor Shapes: 202 | states: [32, 2048, 8192] 203 | input_ids: [32, 2048] 204 | 205 | States Tensor Check: 206 | No NaNs: ✅ 207 | No Infs: ✅ 208 | Value range: [-6.941, 4.027] 209 | 210 | First 4 batches (first 250 chars each): 211 | ---------------------------------------- 212 | Batch 0: Given a triangle... (truncated) 213 | Batch 1: Neonatal reviewers indicated... (truncated) 214 | Batch 2: Is there a method... (truncated) 215 | Batch 3: John visits three different... (truncated) 216 | 217 | Enter batch number (0-31) to view full text, or 'q' to quit: 218 | ``` 219 | 220 | ## 📈 Using Activations with RCache 221 | 222 | RCache provides a simple interface for efficiently streaming large activation datasets from S3 without memory or I/O bottlenecks. 223 | 224 | 1. RCache maintains a small buffer (default: 2 files) in memory 225 | 2. While you process the current megabatch, the next ones are downloaded asynchronously 226 | 3. After a brief initial load (<30s), processing should never be bottlenecked by downloads or streaming 227 | 228 | ### Quick usage: 229 | ```python 230 | cache = S3RCache.from_credentials( 231 | aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"], 232 | aws_secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"], 233 | s3_prefix="run_name/hook_name", 234 | device="cuda", # or "cpu" 235 | return_ids=True # if you need the input tokens 236 | ) 237 | 238 | for batch in cache: 239 | states = batch["states"] # shape: [n_batches, seq_len, d_model] 240 | input_ids = batch["input_ids"] # shape: [n_batches, seq_len] 241 | # ... process batch ... 242 | 243 | cache.finalize() # clean up 244 | ``` 245 | 246 | See `retrieve.py` for a complete example. 247 | 248 | ## ❓ FAQs 249 | 250 | ### Can I use this to train my own sparse autoencoder? 251 | 252 | Yes! This is the intended goal. RCache can be used out of the box in SAE training workflows. It supports blazing fast throughput to ensure training is always FLOP-bottlenecked, not IO-bottlenecked. 
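As a rough illustration, a training loop driven by RCache might look like the sketch below. The `SparseAutoencoder` module, the import path for `S3RCache`, and the dimensions are assumptions for demonstration only, not part of Activault:

```python
import os
import torch
from pipeline.vault.rcache import S3RCache  # import path assumed from the repo layout


class SparseAutoencoder(torch.nn.Module):
    """Placeholder single-layer SAE; substitute your own architecture."""

    def __init__(self, d_model: int, d_hidden: int):
        super().__init__()
        self.encoder = torch.nn.Linear(d_model, d_hidden)
        self.decoder = torch.nn.Linear(d_hidden, d_model)

    def forward(self, x):
        latents = torch.relu(self.encoder(x))
        return self.decoder(latents), latents


cache = S3RCache.from_credentials(
    aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"],
    aws_secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"],
    s3_prefix="run_name/hook_name",
    device="cuda",
)

# d_model assumed to match the collected hook (e.g. 8192 for Llama 3.3 70B above)
sae = SparseAutoencoder(d_model=8192, d_hidden=8 * 8192).cuda()
opt = torch.optim.Adam(sae.parameters(), lr=1e-4)

for batch in cache:
    acts = batch["states"].flatten(0, 1).float()  # [n_batches * seq_len, d_model]
    recon, latents = sae(acts)
    # Reconstruction loss plus a simple L1 sparsity penalty
    loss = (recon - acts).pow(2).mean() + 1e-3 * latents.abs().mean()
    opt.zero_grad()
    loss.backward()
    opt.step()

cache.finalize()
```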
253 | 254 | ### What platforms is Activault built for? 255 | 256 | Activault is designed to be compatible with any S3-style object storage solution. We performed most of our testing on Nebius S3 and have also tested on AWS S3. It is possible that other platforms may encounter issues, and we welcome contributions to expand support. 257 | 258 | ### Why does Activault use `transformers` instead of a more efficient inference library such as `vllm`? 259 | 260 | A few reasons: 261 | 1. The main reason is that the bottleneck is upload speed not throughput. We experimented with using much faster internal serving engines but the main process ran far ahead of the save processes and there was no real gain in overall time. 262 | 2. Activault does not use the `generate` method and prefill speeds are more comparable between the public libraries. 263 | 3. Activault should be compatible with as many models as possible. 264 | 4. `vllm` does not play nice with procuring internal states. 265 | 266 | That said, we welcome contributions to expand Activault's support for more efficient inference libraries. 267 | 268 | ### Why does Activault not use `nnsight` or `transformer-lens`? 269 | 270 | We do not use libraries such as `nnsight` or `transformer-lens` to minimize dependencies and potential failure points, and to ensure maximal compatibility with a wide range of models. 271 | 272 | ### Activault doesn't (support vision models/get the activations I need/work for my storage solution)! 273 | 274 | We welcome contributions! Please open an issue or PR. We are releasing Activault as a community tool to enable low-resource users to collect activations, run experiments, and share data to analyze frontier open-weight models. 275 | 276 | 277 | ## 👥 Credits 278 | 279 | This repo was originally inspired by [Lewington-pitsos/sache](https://github.com/Lewington-pitsos/sache), which is linked in the LessWrong post [here](https://www.lesswrong.com/posts/AtJdPZkMdsakLc6mB/training-a-sparse-autoencoder-in-less-than-30-minutes-on). 280 | 281 | ## 📄 License 282 | 283 | Activault is licensed under the [Apache License 2.0](LICENSE). 284 | 285 | This is a permissive license that allows you to: 286 | - Use the code commercially 287 | - Modify the code 288 | - Distribute your modifications 289 | - Use patent claims of contributors (if applicable) 290 | - Sublicense and/or distribute the code 291 | 292 | Key requirements: 293 | - Include a copy of the license in any redistribution 294 | - Clearly mark any changes you make to the code 295 | - Include the original copyright notices 296 | 297 | The full text of the license is available in the [LICENSE](LICENSE) file. 298 | -------------------------------------------------------------------------------- /configs/README.md: -------------------------------------------------------------------------------- 1 | # Configuration Files 2 | 3 | This directory contains YAML configuration files for the ActiVault pipeline, which is used to extract and store model activations. 4 | 5 | ## Configuration Structure 6 | 7 | Each configuration file follows this general structure: 8 | 9 | ```yaml 10 | run_name: model_name_dataset # Format: model_family.model_size.dataset[.template] 11 | num_runs: 1 # Number of parallel runs to split work across. If you are running distributed, set this to the number of jobs. 
12 | transformer_config: # Model-specific configurations 13 | model_name: path/to/model # HuggingFace model ID or path 14 | dtype: float16 # Precision (float16, bfloat16, float32) 15 | cache_dir: /cache # Local cache directory for models 16 | max_per_device_memory: 55GB # Maximum GPU memory per device 17 | kwargs: # Additional model-specific configurations 18 | load_in_8bit: True 19 | data_config: # Dataset and processing configurations 20 | bucket_name: main # S3 bucket for storing results 21 | data_key: dataset:TEMPLATE # Dataset key with optional template 22 | n_batches: 100 # Number of batches to process 23 | seq_length: 2048 # Maximum sequence length 24 | batch_size: 4 # Batch size for processing 25 | seed: 42 # Random seed for reproducibility 26 | skip_cache: False # Skip caching very long texts 27 | start_batch: 0 # Starting batch number 28 | clean_added_tokens: True # Whether to skip hidden states for special tokens cleaned from outputs 29 | clean_default_system_prompt: True # Whether to skip hidden states for default system prompt (will ignore custom system prompts) 30 | upload_config: # Upload behavior and hook specifications 31 | batches_per_upload: 8 # Number of batches to accumulate before upload 32 | hooks: # List of activation hooks to capture 33 | - "models.layers.24.mlp.post" 34 | - "models.layers.36.mlp.post" 35 | ``` 36 | 37 | ## Configuration Fields 38 | 39 | ### Root Level 40 | 41 | - `run_name` (required): Unique identifier for this run. Format: `model_family.model_size.dataset[.template]` 42 | - `num_runs` (required): Number of parallel runs to split the work across (for distributed processing) 43 | 44 | ### Transformer Config 45 | 46 | - `model_name` (required): HuggingFace model identifier or path 47 | - `dtype` (optional, default: float16): Model precision (float16, bfloat16, float32) 48 | - `cache_dir` (optional): Directory to cache downloaded models 49 | - `max_per_device_memory` (optional): Maximum GPU memory allocated per device 50 | 51 | ### Data Config 52 | 53 | - `bucket_name` (required): S3 bucket name for storing results 54 | - `data_key` (required): Dataset key with optional template format 55 | - `n_batches` (required): Total number of batches to process 56 | - `seq_length` (required): Maximum sequence length for processing 57 | - `batch_size` (required): Batch size for processing 58 | - `seed` (optional, default: 42): Random seed for reproducibility 59 | - `skip_cache` (optional, default: False): Whether to skip caching very long texts 60 | - `start_batch` (optional, default: 0): Starting batch index (for resuming or parallel jobs) 61 | - `clean_added_tokens` (optional, default: False): Whether to ignore special tokens in the output 62 | 63 | ### Upload Config 64 | 65 | - `batches_per_upload` (required): Number of batches to accumulate before uploading to S3 66 | - `hooks` (required): List of activation hooks to capture (format: "models.layers.{layer}.{self_attn|mlp}.{pre|post}") 67 | 68 | ## Example Configurations 69 | 70 | 1. **Llama 3.3 70B** ([llama3.3_70b.yaml](llama3.3_70b.yaml)) 71 | - 70 billion parameter model from Meta 72 | - FP16 precision 73 | - Captures activations from 4 different layers 74 | 75 | 2. **Gemma 2 27B** ([gemma2_27b.yaml](gemma2_27b.yaml)) 76 | - 27 billion parameter model from Google 77 | - BF16 precision 78 | 79 | ## Creating a New Configuration 80 | 81 | To create a new configuration: 82 | 83 | 1. Copy an existing configuration file that most closely matches your needs 84 | 2. 
Modify the following fields: 85 | - `run_name`: Change to match your new model and dataset 86 | - `model_name`: Update to the correct HuggingFace model ID 87 | - Adjust batch sizes, sequence lengths, and other parameters as needed 88 | - Update hook positions based on the model's architecture 89 | 90 | ## Usage 91 | 92 | Configuration files are used with the ActiVault pipeline: 93 | 94 | ```bash 95 | python stash.py --config configs/your_config.yaml 96 | ``` 97 | 98 | For distributed processing across multiple machines: 99 | 100 | ```bash 101 | python stash.py --config configs/your_config.yaml --machine 0 102 | python stash.py --config configs/your_config.yaml --machine 1 103 | # ... and so on 104 | ``` -------------------------------------------------------------------------------- /configs/commanda_111b.yaml: -------------------------------------------------------------------------------- 1 | run_name: commanda_111b 2 | num_runs: 1 # Increase this! 3 | transformer_config: 4 | model_name: CohereForAI/c4ai-command-a-03-2025 5 | dtype: bfloat16 6 | cache_dir: /cache 7 | max_per_device_memory: 35GB 8 | data_config: 9 | bucket_name: main 10 | data_key: edu 11 | n_batches: 64 # 250000 12 | seq_length: 2048 13 | batch_size: 4 14 | seed: 42 15 | skip_cache: False 16 | start_batch: 0 17 | clean_added_tokens: True 18 | # clean_default_system_prompt: True # Cohere has a long system prompt 19 | upload_config: 20 | batches_per_upload: 16 21 | hooks: [ 22 | "models.layers.12.self_attn.pre", "models.layers.12.self_attn.post", 23 | "models.layers.12.mlp.pre", "models.layers.12.mlp.post", 24 | "models.layers.24.self_attn.pre", "models.layers.24.self_attn.post", 25 | "models.layers.24.mlp.pre", "models.layers.24.mlp.post", 26 | "models.layers.36.self_attn.pre", "models.layers.36.self_attn.post", 27 | "models.layers.36.mlp.pre", "models.layers.36.mlp.post" 28 | ] 29 | -------------------------------------------------------------------------------- /configs/gemma2_27b.yaml: -------------------------------------------------------------------------------- 1 | run_name: gemma2_27b 2 | num_runs: 1 # Increase this! 3 | transformer_config: 4 | model_name: google/gemma-2-27b-it 5 | dtype: bfloat16 6 | cache_dir: /cache 7 | max_per_device_memory: 60GB 8 | data_config: 9 | bucket_name: main 10 | data_key: web-instruct:CHAT_TEMPLATE 11 | n_batches: 150 # 250000 12 | seq_length: 2048 13 | batch_size: 4 14 | seed: 42 15 | skip_cache: False 16 | start_batch: 0 17 | clean_added_tokens: True 18 | clean_default_system_prompt: True 19 | upload_config: 20 | batches_per_upload: 32 21 | hooks: [ 22 | "models.layers.6.self_attn.pre", "models.layers.6.self_attn.post", 23 | "models.layers.6.mlp.pre", "models.layers.6.mlp.post", 24 | "models.layers.12.self_attn.pre", "models.layers.12.self_attn.post", 25 | "models.layers.12.mlp.pre", "models.layers.12.mlp.post", 26 | "models.layers.18.self_attn.pre", "models.layers.18.self_attn.post", 27 | "models.layers.18.mlp.pre", "models.layers.18.mlp.post", 28 | "models.layers.24.self_attn.pre", "models.layers.24.self_attn.post", 29 | "models.layers.24.mlp.pre", "models.layers.24.mlp.post" 30 | ] -------------------------------------------------------------------------------- /configs/llama3.3_70b.yaml: -------------------------------------------------------------------------------- 1 | run_name: llama3.3_70b 2 | num_runs: 1 # Increase this! 
3 | transformer_config: 4 | model_name: meta-llama/Llama-3.3-70B-Instruct 5 | dtype: float16 6 | cache_dir: /cache 7 | max_per_device_memory: 55GB 8 | data_config: 9 | bucket_name: main 10 | data_key: infinity-instruct:CHAT_TEMPLATE 11 | n_batches: 256 # 250000 12 | seq_length: 2048 13 | batch_size: 4 14 | seed: 42 15 | skip_cache: False 16 | start_batch: 0 17 | clean_added_tokens: True 18 | clean_default_system_prompt: True 19 | upload_config: 20 | batches_per_upload: 32 21 | hooks: [ 22 | "models.layers.24.self_attn.pre", "models.layers.24.self_attn.post", 23 | "models.layers.24.mlp.pre", "models.layers.24.mlp.post", 24 | "models.layers.36.self_attn.pre", "models.layers.36.self_attn.post", 25 | "models.layers.36.mlp.pre", "models.layers.36.mlp.post", 26 | "models.layers.48.self_attn.pre", "models.layers.48.self_attn.post", 27 | "models.layers.48.mlp.pre", "models.layers.48.mlp.post", 28 | "models.layers.56.self_attn.pre", "models.layers.56.self_attn.post", 29 | "models.layers.56.mlp.pre", "models.layers.56.mlp.post" 30 | ] 31 | -------------------------------------------------------------------------------- /configs/mistral_24b.yaml: -------------------------------------------------------------------------------- 1 | run_name: mistral_24b 2 | num_runs: 1 # Increase this! 3 | transformer_config: 4 | model_name: mistralai/Mistral-Small-24B-Instruct-2501 5 | dtype: bfloat16 6 | cache_dir: /cache 7 | max_per_device_memory: 50GB 8 | data_config: 9 | bucket_name: main 10 | data_key: lmsys:CHAT_TEMPLATE 11 | n_batches: 150 12 | seq_length: 2048 13 | batch_size: 4 14 | seed: 42 15 | skip_cache: False 16 | start_batch: 0 17 | clean_added_tokens: True 18 | # clean_default_system_prompt: True # Mistral has a long system prompt 19 | upload_config: 20 | batches_per_upload: 32 21 | hooks: [ 22 | "models.layers.4.self_attn.pre", "models.layers.4.self_attn.post", 23 | "models.layers.4.mlp.pre", "models.layers.4.mlp.post", 24 | "models.layers.12.self_attn.pre", "models.layers.12.self_attn.post", 25 | "models.layers.12.mlp.pre", "models.layers.12.mlp.post", 26 | "models.layers.20.self_attn.pre", "models.layers.20.self_attn.post", 27 | "models.layers.20.mlp.pre", "models.layers.20.mlp.post", 28 | "models.layers.28.self_attn.pre", "models.layers.28.self_attn.post", 29 | "models.layers.28.mlp.pre", "models.layers.28.mlp.post" 30 | ] -------------------------------------------------------------------------------- /configs/qwq_32b.yaml: -------------------------------------------------------------------------------- 1 | run_name: qwq_32b 2 | num_runs: 1 # Increase this! 
3 | transformer_config: 4 | model_name: Qwen/QwQ-32B 5 | dtype: bfloat16 6 | cache_dir: /cache 7 | max_per_device_memory: 30GB 8 | data_config: 9 | bucket_name: main 10 | data_key: qwq_cot:CHAT_TEMPLATE 11 | n_batches: 150 # 250000 12 | seq_length: 8192 13 | batch_size: 2 14 | seed: 42 15 | skip_cache: True 16 | clean_added_tokens: True 17 | clean_default_system_prompt: True 18 | start_batch: 0 19 | upload_config: 20 | batches_per_upload: 32 21 | hooks: [ 22 | "models.layers.24.self_attn.pre", "models.layers.24.self_attn.post", 23 | "models.layers.24.mlp.pre", "models.layers.24.mlp.post", 24 | "models.layers.36.self_attn.pre", "models.layers.36.self_attn.post", 25 | "models.layers.36.mlp.pre", "models.layers.36.mlp.post", 26 | "models.layers.48.self_attn.pre", "models.layers.48.self_attn.post", 27 | "models.layers.48.mlp.pre", "models.layers.48.mlp.post", 28 | "models.layers.56.self_attn.pre", "models.layers.56.self_attn.post", 29 | "models.layers.56.mlp.pre", "models.layers.56.mlp.post" 30 | ] -------------------------------------------------------------------------------- /configs/r1_distill_qwen_32b.yaml: -------------------------------------------------------------------------------- 1 | run_name: r1_distill_qwen_32b 2 | num_runs: 1 # Increase this! 3 | transformer_config: 4 | model_name: deepseek-ai/DeepSeek-R1-Distill-Qwen-32B 5 | dtype: bfloat16 6 | cache_dir: /cache 7 | max_per_device_memory: 40GB 8 | data_config: 9 | bucket_name: main 10 | data_key: s1.1 11 | n_batches: 150 # 500 12 | seq_length: 8192 13 | batch_size: 2 14 | seed: 42 15 | skip_cache: True 16 | start_batch: 0 17 | clean_added_tokens: True 18 | clean_default_system_prompt: True 19 | upload_config: 20 | batches_per_upload: 32 21 | hooks: [ 22 | "models.layers.24.self_attn.pre", "models.layers.24.self_attn.post", 23 | "models.layers.24.mlp.pre", "models.layers.24.mlp.post", 24 | "models.layers.36.self_attn.pre", "models.layers.36.self_attn.post", 25 | "models.layers.36.mlp.pre", "models.layers.36.mlp.post", 26 | "models.layers.48.self_attn.pre", "models.layers.48.self_attn.post", 27 | "models.layers.48.mlp.pre", "models.layers.48.mlp.post", 28 | "models.layers.56.self_attn.pre", "models.layers.56.self_attn.post", 29 | "models.layers.56.mlp.pre", "models.layers.56.mlp.post" 30 | ] -------------------------------------------------------------------------------- /pipeline/config.py: -------------------------------------------------------------------------------- 1 | """Copyright (2025) Tilde Research Inc. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | """ 15 | 16 | from dataclasses import dataclass 17 | from typing import Optional, Dict, Any 18 | import yaml 19 | import argparse 20 | import json 21 | import boto3 22 | import threading 23 | import re 24 | 25 | 26 | @dataclass 27 | class Config: 28 | """Configuration manager for the ActiVault pipeline. 29 | 30 | This class handles all configuration aspects of the pipeline, including loading from YAML, 31 | command line arguments, and S3 integration. 
It supports both synchronous and asynchronous 32 | saving to S3. 33 | 34 | Attributes: 35 | run_name: Unique identifier for this pipeline run (format: model_family.model_size.dataset[.template]) 36 | transformer_config: Configuration settings for the transformer model 37 | data_config: Settings for data loading and processing 38 | upload_config: Configuration for upload behavior and hooks 39 | num_runs: Number of parallel runs to execute 40 | total_tokens: Counter for processed tokens 41 | d_model: Model dimension if applicable 42 | n_total_files: Counter for processed files 43 | batches_processed: Counter for processed batches 44 | machine_index: Index for distributed processing 45 | 46 | Example: 47 | ```python 48 | # Load from YAML 49 | config = Config.from_yaml("configs/default.yaml") 50 | 51 | # Load from command line 52 | config = Config.from_args() 53 | 54 | # Save to S3 55 | config.save_to_s3(s3_client) 56 | ``` 57 | """ 58 | 59 | run_name: str 60 | transformer_config: Dict[str, Any] 61 | data_config: Dict[str, Any] 62 | upload_config: Dict[str, Any] 63 | num_runs: int 64 | total_tokens: int = 0 65 | d_model: Optional[int] = None 66 | n_total_files: int = 0 67 | batches_processed: int = 0 68 | _save_thread: Optional[threading.Thread] = None 69 | machine_index: int = 0 70 | _config_key: Optional[str] = None # Track the config file path once created 71 | 72 | def to_dict(self) -> dict: 73 | """Convert configuration to a dictionary format for serialization. 74 | 75 | Returns: 76 | dict: Configuration in dictionary format 77 | """ 78 | return { 79 | "run_name": self.run_name, 80 | "transformer_config": self.transformer_config, 81 | "data_config": self.data_config, 82 | "upload_config": self.upload_config, 83 | "total_tokens": self.total_tokens, 84 | "d_model": self.d_model, 85 | "n_total_files": self.n_total_files, 86 | "batches_processed": self.batches_processed, 87 | "num_runs": self.num_runs, 88 | } 89 | 90 | @classmethod 91 | def from_yaml(cls, path: str = "configs/default.yaml") -> "Config": 92 | """Create configuration from a YAML file. 93 | 94 | Args: 95 | path: Path to YAML configuration file 96 | 97 | Returns: 98 | Config: New configuration instance 99 | """ 100 | with open(path) as f: 101 | config_dict = yaml.safe_load(f) 102 | return cls(**config_dict) 103 | 104 | @classmethod 105 | def from_args(cls) -> "Config": 106 | """Create configuration from command line arguments. 107 | 108 | Supports --config and --machine arguments. 
109 | 110 | Returns: 111 | Config: New configuration instance 112 | """ 113 | parser = argparse.ArgumentParser() 114 | parser.add_argument( 115 | "--config", type=str, default="configs/llama3.3_70b.yaml", help="Path to config file" 116 | ) 117 | parser.add_argument( 118 | "--machine", type=int, default=0, help="Index of the machine for distributed processing" 119 | ) 120 | args = parser.parse_args() 121 | config = cls.from_yaml(args.config) 122 | config.machine_index = args.machine 123 | return config 124 | 125 | def _get_next_config_number(self, s3_client: boto3.client) -> int: 126 | """Find the next available config number by checking existing configs.""" 127 | try: 128 | # List all objects in the run directory 129 | response = s3_client.list_objects_v2( 130 | Bucket=self.data_config["bucket_name"], Prefix=f"{self.run_name}/cfg" 131 | ) 132 | 133 | # Find all numbered configs (cfg1.json, cfg2.json, etc) 134 | numbers = [0] # 0 represents the base cfg.json 135 | pattern = re.compile(r"cfg(\d+)\.json$") 136 | 137 | if "Contents" in response: 138 | for obj in response["Contents"]: 139 | match = pattern.search(obj["Key"]) 140 | if match: 141 | numbers.append(int(match.group(1))) 142 | 143 | return max(numbers) + 1 144 | except Exception: 145 | return 1 146 | 147 | def _get_latest_config_key(self, s3_client: boto3.client) -> str: 148 | """Find the latest config file (highest number).""" 149 | try: 150 | response = s3_client.list_objects_v2( 151 | Bucket=self.data_config["bucket_name"], Prefix=f"{self.run_name}/cfg" 152 | ) 153 | 154 | latest_num = -1 155 | latest_key = f"{self.run_name}/cfg.json" 156 | pattern = re.compile(r"cfg(\d+)\.json$") 157 | 158 | if "Contents" in response: 159 | for obj in response["Contents"]: 160 | match = pattern.search(obj["Key"]) 161 | if match: 162 | num = int(match.group(1)) 163 | if num > latest_num: 164 | latest_num = num 165 | latest_key = obj["Key"] 166 | elif obj["Key"].endswith("/cfg.json") and latest_num == -1: 167 | latest_key = obj["Key"] 168 | 169 | return latest_key 170 | except Exception: 171 | return f"{self.run_name}/cfg.json" 172 | 173 | def _save_config_thread(self, s3_client: boto3.client) -> None: 174 | """Thread target for saving config to S3.""" 175 | config_dict = self.to_dict() 176 | 177 | # Only determine the config file name if we haven't already 178 | if self._config_key is None: 179 | self._config_key = f"{self.run_name}/cfg.json" 180 | 181 | s3_client.put_object( 182 | Body=json.dumps(config_dict), 183 | Bucket=self.data_config["bucket_name"], 184 | Key=self._config_key, 185 | ) 186 | 187 | def save_to_s3(self, s3_client: boto3.client, blocking: bool = False) -> None: 188 | """Save configuration to S3. 189 | 190 | Args: 191 | s3_client: Boto3 S3 client 192 | blocking: If True, wait for previous save to complete before starting new one 193 | 194 | Note: 195 | For a new run, the first save will create a new config file. 196 | Subsequent saves will overwrite the same file. 
197 | """ 198 | # Wait for previous save to complete if requested 199 | if blocking and self._save_thread is not None: 200 | self._save_thread.join() 201 | 202 | # Start new save thread 203 | self._save_thread = threading.Thread(target=self._save_config_thread, args=(s3_client,)) 204 | self._save_thread.start() 205 | 206 | # If blocking, wait for this save to complete 207 | if blocking: 208 | self._save_thread.join() 209 | 210 | @classmethod 211 | def load_from_s3(cls, s3_client: boto3.client, bucket_name: str) -> Optional["Config"]: 212 | """Load existing configuration from S3. 213 | 214 | Args: 215 | s3_client: Boto3 S3 client 216 | run_name: Name of the run to load 217 | bucket_name: S3 bucket name 218 | 219 | Returns: 220 | Optional[Config]: Loaded configuration or None if not found 221 | """ 222 | try: 223 | # Create temporary config to access methods 224 | temp_config = cls( 225 | run_name="temp", transformer_config={}, data_config={}, upload_config={}, num_runs=1 226 | ) 227 | temp_config.data_config["bucket_name"] = bucket_name 228 | 229 | # Get the latest config key 230 | key = temp_config._get_latest_config_key(s3_client) 231 | 232 | response = s3_client.get_object(Bucket=bucket_name, Key=key) 233 | config_dict = json.loads(response["Body"].read()) 234 | 235 | # Convert from old format to new 236 | if "max_length" in config_dict: 237 | config_dict["seq_length"] = config_dict.pop("max_length") 238 | if "batches_per_cache" in config_dict: 239 | config_dict["batches_per_upload"] = config_dict.pop("batches_per_cache") 240 | 241 | # Move n_batches to data_config if it exists at root level 242 | if "n_batches" in config_dict: 243 | if "data_config" not in config_dict: 244 | config_dict["data_config"] = {} 245 | config_dict["data_config"]["n_batches"] = config_dict.pop("n_batches") 246 | 247 | # Add missing fields with defaults 248 | config_dict.setdefault("num_runs", 1) # Default to 1 if missing 249 | 250 | # Remove thread field if present 251 | config_dict.pop("_save_thread", None) 252 | 253 | return cls(**config_dict) 254 | except s3_client.exceptions.NoSuchKey: 255 | return None 256 | 257 | @staticmethod 258 | def load_hook_statistics( 259 | s3_client: boto3.client, run_name: str, hook: str, bucket_name: str 260 | ) -> Optional[Dict[str, Any]]: 261 | """Load existing statistics for a hook from S3.""" 262 | try: 263 | response = s3_client.get_object( 264 | Bucket=bucket_name, Key=f"{run_name}/{hook}/statistics.json" 265 | ) 266 | return json.loads(response["Body"].read()) 267 | except s3_client.exceptions.NoSuchKey: 268 | return None 269 | -------------------------------------------------------------------------------- /pipeline/data/dataloader.py: -------------------------------------------------------------------------------- 1 | """Copyright (2025) Tilde Research Inc. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | """ 15 | 16 | from typing import Iterator, List, Dict 17 | import torch 18 | import numpy as np 19 | from transformers import PreTrainedTokenizer, DataCollatorWithFlattening 20 | from datasets import IterableDataset 21 | from tqdm import tqdm 22 | import logging 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | # If your samples are very short and you are cleaning tokens, you may need to increase this. 27 | # Alternatively, you can reduce this for higher efficiency. 28 | MAX_N_PACKED = 12 29 | 30 | 31 | class DataLoader: 32 | """ 33 | A streaming dataloader that packs clean token sequences for activation collection. 34 | Removes redundant/anomalous tokens (special tokens, chat templates, etc.) from the dataset. 35 | """ 36 | 37 | def __init__( 38 | self, 39 | dataset: IterableDataset, 40 | tokenizer: PreTrainedTokenizer, 41 | max_length: int, 42 | batch_size: int, 43 | start_batch_skip: int = 0, 44 | batches_per_machine: int = None, 45 | dataset_key: str = None, 46 | skip_cache: bool = False, 47 | clean_added_tokens: bool = True, 48 | clean_default_system_prompt: bool = True, 49 | ): 50 | self.dataset = dataset 51 | self.tokenizer = tokenizer 52 | self.batch_size = batch_size 53 | self.batches_per_machine = batches_per_machine 54 | self.apply_chat_template = dataset_key.endswith(":CHAT_TEMPLATE") 55 | self.skip_cache = skip_cache 56 | self.clean_added_tokens = clean_added_tokens 57 | self.clean_default_system_prompt = clean_default_system_prompt 58 | logger.info(f"NOTICE: Cleaning default system prompt: {clean_default_system_prompt}") 59 | # Initialize dataset iterator 60 | self.dataset_iter = iter(self.dataset) 61 | 62 | # Setup invalid tokens list 63 | self.invalid_tokens = [ 64 | self.tokenizer.eos_token_id, 65 | self.tokenizer.pad_token_id, 66 | self.tokenizer.bos_token_id, 67 | -100, 68 | 128000, # Chat template token 69 | ] 70 | 71 | if hasattr(self.tokenizer, "added_tokens_decoder") and self.clean_added_tokens: 72 | self.invalid_tokens.extend(self.tokenizer.added_tokens_decoder.keys()) 73 | 74 | # Add buffer to account for token replacement 75 | self.buffer_size = 0 76 | if self.clean_added_tokens: 77 | self.buffer_size += 20 * MAX_N_PACKED 78 | 79 | # Chat-specific setup 80 | if self.clean_default_system_prompt: 81 | # Get the start string pattern by applying template to a dummy message 82 | dummy_msg = [{"role": "user", "content": "test"}] 83 | template_text = tokenizer.apply_chat_template(dummy_msg, tokenize=False) 84 | 85 | pos = template_text.find("test") 86 | self.begin_str = template_text[:pos] 87 | 88 | # Tokenize string to remove 89 | self.tokenized_system_prompt = tokenizer( 90 | self.begin_str, truncation=False, return_attention_mask=False 91 | )["input_ids"][1:] 92 | 93 | self.buffer_size += len(self.tokenized_system_prompt) * MAX_N_PACKED 94 | 95 | self.base_max_length = max_length 96 | self.max_length = max_length + self.buffer_size 97 | logger.info(f"Buffer size: {self.buffer_size}") 98 | 99 | # Initialize state 100 | self.current_batch_tokens = 0 101 | self.current_texts = [] 102 | self.cached_text = None 103 | 104 | # Setup collator 105 | self.collator = DataCollatorWithFlattening(return_position_ids=True, separator_id=-100) 106 | 107 | # Initial skip based on start batch 108 | start_tokens = 25 * start_batch_skip * batch_size * max_length 109 | self.skip_tokens(start_tokens) 110 | 111 | def skip_tokens(self, num_tokens_to_skip: int) -> None: 112 | """Skip the specified number of tokens.""" 113 | tokens_so_far = 0 114 | n_seqs = 0 115 | 116 | # Heuristic 
-> average seq_len = 256 117 | if num_tokens_to_skip > 1e7: 118 | base_token_volume = 64 119 | self.dataset = self.dataset.skip((num_tokens_to_skip - 1e7) // base_token_volume) 120 | tokens_so_far = num_tokens_to_skip - 1e7 121 | n_seqs = (num_tokens_to_skip - 1e7) // base_token_volume 122 | 123 | pbar = tqdm(total=num_tokens_to_skip, desc="Skipping tokens") 124 | pbar.update(tokens_so_far) 125 | 126 | while tokens_so_far < num_tokens_to_skip: 127 | text_batch = [] 128 | for _ in range(64): 129 | try: 130 | text = next(self.dataset_iter)["text"] 131 | if self.apply_chat_template: 132 | text_batch.append( 133 | [self._convert_chat_format(text[i]) for i in range(len(text))] 134 | ) 135 | else: 136 | text_batch.append(text) 137 | except Exception as e: 138 | if isinstance(e, StopIteration): 139 | break 140 | else: 141 | logger.error(f"Error skipping tokens: {e}") 142 | 143 | if not text_batch: 144 | break 145 | 146 | if self.apply_chat_template: 147 | text_batch = self.tokenizer.apply_chat_template(text_batch, tokenize=False) 148 | 149 | tokenized = self.tokenizer(text_batch, truncation=False, return_attention_mask=False) 150 | 151 | for input_ids in tokenized["input_ids"]: 152 | n_tokens = len(self._clean_sequence(input_ids)[0]) 153 | tokens_so_far += n_tokens 154 | n_seqs += 1 155 | pbar.update(n_tokens) 156 | 157 | if tokens_so_far >= num_tokens_to_skip: 158 | break 159 | 160 | pbar.close() 161 | logger.info(f"Skipped {tokens_so_far} tokens from {n_seqs} sequences") 162 | 163 | @staticmethod 164 | def _convert_chat_format(msg: Dict[str, str]) -> Dict[str, str]: 165 | """Unify chat message formats.""" 166 | try: 167 | # If already in correct format, validate fields 168 | if "role" in msg: 169 | if not isinstance(msg.get("content"), str): 170 | # Default to empty string if content is None or invalid 171 | msg["content"] = "" 172 | if not isinstance(msg.get("role"), str): 173 | # Default to user if role is None or invalid 174 | msg["role"] = "user" 175 | return msg 176 | 177 | # Convert from old format 178 | if not isinstance(msg.get("from"), str) or not isinstance(msg.get("value"), str): 179 | # Default to empty message if fields are invalid 180 | return {"role": "user", "content": ""} 181 | 182 | role_map = {"human": "user", "gpt": "assistant", "system": "system"} 183 | role = role_map.get(msg["from"], "user") # Default to user if unknown role 184 | return {"role": role, "content": msg["value"]} 185 | except Exception as e: 186 | logger.warning(f"Error converting chat format for message: {msg}. 
Error: {e}") 187 | return {"role": "user", "content": ""} # Return safe default 188 | 189 | def _clean_sequence(self, input_ids: torch.Tensor) -> tuple[List[int], List[bool]]: 190 | """Clean a sequence by removing invalid tokens.""" 191 | ids = input_ids.cpu().numpy() if torch.is_tensor(input_ids) else np.array(input_ids) 192 | length = len(ids) 193 | 194 | # Initialize mask array 195 | mask = np.ones(length, dtype=bool) 196 | 197 | # Remove invalid tokens 198 | invalid_mask = np.isin(ids, self.invalid_tokens) 199 | mask &= ~invalid_mask 200 | 201 | if self.clean_default_system_prompt: 202 | # Remove both sys prompt 203 | for sys_str in [self.tokenized_system_prompt]: 204 | cutoff_len = len(sys_str) 205 | if cutoff_len > 0 and len(ids) >= cutoff_len: 206 | windows = np.lib.stride_tricks.sliding_window_view(ids, cutoff_len) 207 | matches = np.all(windows == sys_str, axis=1) 208 | match_positions = np.where(matches)[0] 209 | for pos in match_positions: 210 | mask[pos : pos + cutoff_len] = False 211 | 212 | return ids[mask].tolist(), mask.tolist() 213 | 214 | def clean_batch( 215 | self, input_ids: torch.Tensor, states: torch.Tensor 216 | ) -> tuple[torch.Tensor, torch.Tensor]: 217 | """ 218 | Clean a batch of sequences by removing invalid tokens. 219 | 220 | Args: 221 | input_ids: tensor of shape [batch_size, seq_len] 222 | states: tensor of shape [batch_size, seq_len, hidden_dim] 223 | Returns: 224 | cleaned input_ids and states 225 | """ 226 | batch_size, seq_len = input_ids.shape 227 | true_input_ids = [] 228 | true_states = [] 229 | 230 | for i in range(batch_size): 231 | valid_ids, valid_mask = self._clean_sequence(input_ids[i]) 232 | valid_positions = torch.where(torch.tensor(valid_mask))[0] 233 | 234 | if len(valid_positions) > 0: 235 | true_input_ids.append( 236 | input_ids[i][valid_positions].contiguous()[: self.base_max_length] 237 | ) 238 | true_states.append(states[i][valid_positions].contiguous()[: self.base_max_length]) 239 | try: 240 | return torch.stack(true_input_ids), torch.stack(true_states) 241 | except Exception as e: 242 | logger.error( 243 | f"Error cleaning batch. \nThis is likely due to too small a buffer. 
Consider increasing MAX_N_PACKED in dataloader.py" 244 | ) 245 | raise e 246 | 247 | def __len__(self) -> int: 248 | return len(self.dataset) // (self.batch_size * self.max_length) 249 | 250 | def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: 251 | self.processed_batches = 0 252 | return self 253 | 254 | @staticmethod 255 | def _fix_chat_sequence(messages: List[Dict[str, str]]) -> List[Dict[str, str]]: 256 | """Ensure chat messages follow the required alternating pattern.""" 257 | if not messages: 258 | return [{"role": "user", "content": ""}] 259 | 260 | fixed = [] 261 | # Handle optional system message 262 | if messages[0]["role"] == "system": 263 | fixed.append(messages[0]) 264 | messages = messages[1:] 265 | 266 | # Start with user if no messages or last was assistant 267 | if not messages or (fixed and fixed[-1]["role"] == "assistant") or messages[0]["role"] == "assistant": 268 | fixed.append({"role": "user", "content": ""}) 269 | 270 | # Process remaining messages ensuring alternation 271 | for msg in messages: 272 | # Skip if would violate alternation 273 | if fixed and msg["role"] == fixed[-1]["role"]: 274 | continue 275 | # Add dummy message if needed to maintain alternation 276 | if fixed and fixed[-1]["role"] not in ["user", "assistant"]: 277 | fixed.append({"role": "user" if msg["role"] == "assistant" else "assistant", "content": ""}) 278 | fixed.append(msg) 279 | 280 | # Ensure we end with assistant response 281 | if not fixed or fixed[-1]["role"] == "user": 282 | fixed.append({"role": "assistant", "content": ""}) 283 | 284 | return fixed 285 | 286 | def __next__(self) -> Dict[str, torch.Tensor]: 287 | if ( 288 | self.batches_per_machine is not None 289 | and self.processed_batches >= self.batches_per_machine 290 | ): 291 | raise StopIteration 292 | 293 | # Initialize sequence buffer 294 | seq_buffer = [] 295 | 296 | # Keep collecting sequences until we have batch_size or hit StopIteration 297 | while len(seq_buffer) < self.batch_size: 298 | try: 299 | # Get next sequence 300 | if self.cached_text and not self.apply_chat_template and not self.skip_cache: 301 | sample = self.cached_text 302 | self.cached_text = None 303 | else: 304 | sample = next(self.dataset_iter)["text"] 305 | 306 | if self.apply_chat_template: 307 | messages = [self._convert_chat_format(s) for s in sample] 308 | messages = self._fix_chat_sequence(messages) 309 | sample = self.tokenizer.apply_chat_template(messages, tokenize=False) 310 | 311 | tokenized = self.tokenizer(sample, truncation=False, return_attention_mask=False) 312 | 313 | n_tokens = len(tokenized["input_ids"]) 314 | self.current_texts.append(tokenized) 315 | 316 | space_left = self.max_length - self.current_batch_tokens 317 | if n_tokens > space_left: 318 | if n_tokens > self.max_length * 2: 319 | self.cached_text = self.tokenizer.decode( 320 | self.current_texts[-1]["input_ids"][space_left:] 321 | ) 322 | self.current_texts[-1]["input_ids"] = self.current_texts[-1]["input_ids"][ 323 | :space_left 324 | ] 325 | 326 | # Add packed sequence to buffer 327 | texts = self.current_texts 328 | self.current_texts = [] 329 | self.current_batch_tokens = 0 330 | seq_buffer.append(self.collator(texts)) 331 | 332 | else: 333 | self.current_batch_tokens += n_tokens 334 | if self.current_batch_tokens >= self.max_length: 335 | # Add packed sequence to buffer 336 | texts = self.current_texts 337 | self.current_texts = [] 338 | self.current_batch_tokens = 0 339 | seq_buffer.append(self.collator(texts)) 340 | 341 | except StopIteration: 342 | if 
self.current_texts: 343 | # Add final packed sequence to buffer 344 | seq_buffer.append(self.collator(self.current_texts)) 345 | self.current_texts = [] 346 | self.current_batch_tokens = 0 347 | if not seq_buffer: 348 | raise StopIteration 349 | break 350 | 351 | # Stack all sequences in the buffer into a batch 352 | self.processed_batches += 1 353 | return { 354 | "input_ids": torch.cat([seq["input_ids"] for seq in seq_buffer], dim=0), 355 | "position_ids": torch.cat([seq["position_ids"] for seq in seq_buffer], dim=0), 356 | } 357 | -------------------------------------------------------------------------------- /pipeline/data/dataset.py: -------------------------------------------------------------------------------- 1 | """Copyright (2025) Tilde Research Inc. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | """ 15 | 16 | import json 17 | from pathlib import Path 18 | from typing import Dict, Any, List, Union, Callable 19 | from datasets import load_dataset, Dataset, IterableDataset 20 | import logging 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | def load_dataset_config() -> Dict[str, Any]: 26 | """Load dataset configurations from datasets.json. 27 | 28 | The configuration defines how each dataset should be loaded, 29 | including templates and filtering conditions. 30 | 31 | Returns: 32 | dict: Mapping of dataset names to their configurations 33 | 34 | Raises: 35 | FileNotFoundError: If datasets.json is not found 36 | json.JSONDecodeError: If datasets.json is invalid 37 | """ 38 | config_path = Path(__file__).parent / "datasets.json" 39 | with open(config_path) as f: 40 | return json.load(f) 41 | 42 | 43 | def apply_filter_conditions( 44 | dataset: Union[Dataset, IterableDataset], conditions: Dict[str, Any] 45 | ) -> Union[Dataset, IterableDataset]: 46 | """Apply filtering conditions to a dataset. 47 | 48 | Args: 49 | dataset: Input dataset to filter 50 | conditions: Dictionary mapping column names to required values 51 | 52 | Returns: 53 | Dataset: Filtered dataset containing only matching rows 54 | 55 | Example: 56 | ```python 57 | filtered_ds = apply_filter_conditions(ds, {"language": "en"}) 58 | ``` 59 | """ 60 | for column, value in conditions.items(): 61 | dataset = dataset.filter(lambda x: x[column] == value) 62 | return dataset 63 | 64 | 65 | def deduplicate_dataset( 66 | dataset: Union[Dataset, IterableDataset], conditions: Dict[str, Any] 67 | ) -> Union[Dataset, IterableDataset]: 68 | """Deduplicate dataset rows based on specified columns. 
69 | 70 | Args: 71 | dataset: Input dataset to deduplicate 72 | conditions: Dictionary specifying columns to check for duplicates 73 | 74 | Returns: 75 | Dataset: Deduplicated dataset 76 | 77 | Example: 78 | ```python 79 | deduped_ds = deduplicate_dataset(ds, {"url": True}) 80 | ``` 81 | """ 82 | for column, value in conditions.items(): 83 | seen_values = set() 84 | 85 | def dedup_filter(example): 86 | value = example[column] 87 | if value in seen_values: 88 | return False 89 | seen_values.add(value) 90 | return True 91 | 92 | dataset = dataset.filter(dedup_filter) 93 | return dataset 94 | 95 | 96 | def make_template_fn(template: str, columns: List[str], dataset_name: str) -> Callable: 97 | """Create a function that formats examples using a template. 98 | 99 | Args: 100 | template: String template with {column} placeholders 101 | columns: List of column names to use in template 102 | dataset_name: Name of dataset (for error messages) 103 | 104 | Returns: 105 | Callable: Function that formats examples using the template 106 | 107 | Raises: 108 | KeyError: If required columns are missing 109 | ValueError: If template produces empty text 110 | 111 | Example: 112 | ```python 113 | template_fn = make_template_fn( 114 | "Question: {question}\nAnswer: {answer}", 115 | ["question", "answer"], 116 | "qa_dataset" 117 | ) 118 | ``` 119 | """ 120 | 121 | def apply_template(example): 122 | try: 123 | text = template.format(**{col: example[col] for col in columns}) 124 | if not text or text.isspace(): 125 | raise ValueError(f"Empty text generated for {dataset_name}") 126 | return {"text": text} 127 | except KeyError as e: 128 | raise KeyError(f"Missing required column {e} for dataset {dataset_name}") 129 | 130 | return apply_template 131 | 132 | 133 | def make_chat_template_fn(template: List[str], columns: List[str], dataset_name: str) -> Callable: 134 | """Create a function that formats examples into chat format. 135 | 136 | Args: 137 | template: List of role names ("user", "assistant", etc.) 138 | columns: List of column names containing messages 139 | dataset_name: Name of dataset (for error messages) 140 | 141 | Returns: 142 | Callable: Function that formats examples into chat format 143 | 144 | Raises: 145 | KeyError: If required columns are missing 146 | 147 | Example: 148 | ```python 149 | chat_fn = make_chat_template_fn( 150 | ["user", "assistant"], 151 | ["question", "answer"], 152 | "dialogue_dataset" 153 | ) 154 | ``` 155 | """ 156 | 157 | def apply_chat_template(example): 158 | try: 159 | entry = [] 160 | for role, column in zip(template, columns): 161 | entry.append({"role": role, "content": example[column]}) 162 | return {"text": entry} 163 | except KeyError as e: 164 | raise KeyError(f"Missing required column {e} for dataset {dataset_name}") 165 | 166 | return apply_chat_template 167 | 168 | 169 | def load_dataset_by_key( 170 | dataset_key: str, split: str = "train", streaming: bool = True 171 | ) -> IterableDataset: 172 | """Load a single dataset according to its configuration. 173 | 174 | This function loads a dataset and applies templates and filters as configured. 175 | 176 | Args: 177 | dataset_key: Name of the dataset configuration to use 178 | split: Dataset split to load ("train", "validation", etc.) 
179 | streaming: Whether to use streaming mode 180 | 181 | Returns: 182 | IterableDataset: Processed dataset 183 | 184 | Raises: 185 | KeyError: If dataset_key is not found 186 | Exception: If dataset loading or processing fails 187 | 188 | Example: 189 | ```python 190 | dataset = load_dataset_by_key( 191 | dataset_key="web_instruct", 192 | split="train", 193 | streaming=True 194 | ) 195 | ``` 196 | """ 197 | config = load_dataset_config()[dataset_key] 198 | logger.info(f"\nLoading dataset: {dataset_key}") 199 | 200 | # Load dataset 201 | ds = load_dataset(config["name"], config.get("subset"), split=split, streaming=streaming) 202 | 203 | # Apply filters if specified 204 | if "filter" in config: 205 | ds = apply_filter_conditions(ds, config["filter"]) 206 | 207 | if "deduplicate" in config: 208 | ds = deduplicate_dataset(ds, config["deduplicate"]) 209 | 210 | # Handle different column structures 211 | if "template" in config: 212 | # Apply template if specified 213 | ds = ds.map(make_template_fn(config["template"], config["columns"], dataset_key)) 214 | elif "chat_template" in config: 215 | ds = ds.map(make_chat_template_fn(config["chat_template"], config["columns"], dataset_key)) 216 | else: 217 | # If no template, map the specified column to 'text' 218 | column = config["columns"][0] 219 | if column != "text": 220 | ds = ds.map(lambda x: {"text": x[column]}) 221 | 222 | # Select only the text column 223 | ds = ds.select_columns(["text"]) 224 | logger.info(f"Successfully loaded {dataset_key}") 225 | 226 | return ds 227 | -------------------------------------------------------------------------------- /pipeline/data/datasets.json: -------------------------------------------------------------------------------- 1 | { 2 | "edu": { 3 | "name": "HuggingFaceFW/fineweb-edu", 4 | "subset": "sample-100BT", 5 | "columns": [ 6 | "text" 7 | ] 8 | }, 9 | "web-instruct:CHAT_TEMPLATE": { 10 | "name": "chargoddard/WebInstructSub-prometheus", 11 | "chat_template": [ 12 | "user", 13 | "assistant" 14 | ], 15 | "columns": [ 16 | "instruction", 17 | "generation" 18 | ] 19 | }, 20 | "infinity-instruct:CHAT_TEMPLATE": { 21 | "name": "BAAI/Infinity-Instruct", 22 | "subset": "3M", 23 | "columns": [ 24 | "conversations" 25 | ], 26 | "filter": { 27 | "langdetect": "en" 28 | } 29 | }, 30 | "lmsys:CHAT_TEMPLATE": { 31 | "name": "OpenLeecher/lmsys_chat_1m_clean", 32 | "subset": "default", 33 | "columns": [ 34 | "conversations" 35 | ] 36 | }, 37 | "smoltalk:CHAT_TEMPLATE": { 38 | "name": "HuggingFaceTB/smoltalk", 39 | "subset": "all", 40 | "columns": [ 41 | "messages" 42 | ] 43 | }, 44 | "s1.1": { 45 | "name": "simplescaling/s1K-1.1_tokenized", 46 | "subset": "default", 47 | "columns": [ 48 | "text" 49 | ] 50 | }, 51 | "qwq_cot:CHAT_TEMPLATE": { 52 | "name": "amphora/QwQ-LongCoT-130K", 53 | "subset": "default", 54 | "columns": [ 55 | "problem", 56 | "qwq" 57 | ], 58 | "chat_template": [ 59 | "user", 60 | "assistant" 61 | ] 62 | } 63 | } -------------------------------------------------------------------------------- /pipeline/generate.py: -------------------------------------------------------------------------------- 1 | """Copyright (2025) Tilde Research Inc. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 
5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | """ 15 | 16 | import torch 17 | from tqdm import tqdm 18 | from transformers import AutoModelForCausalLM 19 | from typing import Dict 20 | from pipeline.data.dataloader import DataLoader 21 | from pipeline.config import Config 22 | from pipeline.vault import HookUploader 23 | from s3.utils import create_s3_client 24 | import logging 25 | from uuid import uuid4 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | 30 | def generate_activations( 31 | model: AutoModelForCausalLM, 32 | loader: DataLoader, 33 | config: Config, 34 | uploaders: Dict[str, HookUploader] = None, 35 | hook_activations: Dict[str, Dict[str, torch.Tensor]] = None, 36 | ) -> None: 37 | """ 38 | Main activation generation loop. 39 | 40 | Args: 41 | model: The transformer model (already on correct device) 42 | loader: DataLoader instance 43 | config: Configuration object 44 | uploaders: Dictionary mapping hook names to their uploaders 45 | """ 46 | # Set d_model in config 47 | config.d_model = model.config.hidden_size 48 | 49 | # Create S3 client for config saving 50 | s3_client = create_s3_client() 51 | 52 | # Load existing config and stats if available 53 | existing_config = Config.load_from_s3(s3_client, config.data_config["bucket_name"]) 54 | if existing_config: 55 | logger.info(f"Resuming run {config.run_name} from {existing_config.total_tokens} tokens") 56 | config.total_tokens = existing_config.total_tokens 57 | config.n_total_files = existing_config.n_total_files 58 | config.batches_processed = existing_config.batches_processed 59 | 60 | # Skip tokens based on existing total tokens with an offset 61 | tokens_to_skip = existing_config.total_tokens + 3000 62 | loader.skip_tokens(tokens_to_skip) 63 | 64 | # Initialize statistics tracking 65 | hooks = config.upload_config["hooks"] 66 | means = {hook: torch.zeros(model.config.hidden_size, device=model.device) for hook in hooks} 67 | 68 | # M2 stores sum of squared differences from the mean (for Welford's algorithm) 69 | M2s = {hook: torch.zeros(model.config.hidden_size, device=model.device) for hook in hooks} 70 | counts = {hook: 0 for hook in hooks} # Track number of samples per dimension 71 | norm_sums = {hook: torch.zeros(1, device=model.device) for hook in hooks} # Track sum of norms 72 | norm_counts = {hook: 0 for hook in hooks} # Track count for norms 73 | 74 | # Load existing statistics if available 75 | for hook in hooks: 76 | stats = Config.load_hook_statistics( 77 | s3_client, config.run_name, hook, config.data_config["bucket_name"] 78 | ) 79 | if stats: 80 | logger.info(f"Loading existing statistics for {hook}") 81 | means[hook] = torch.tensor(stats["mean"], device=model.device) 82 | if "M2" in stats: 83 | M2s[hook] = torch.tensor(stats["M2"], device=model.device) 84 | else: 85 | # If M2 not available, approximate from std (for backward compatibility) 86 | std = torch.tensor(stats["std"], device=model.device) 87 | M2s[hook] = std * std * config.batches_processed 88 | 89 | counts[hook] = config.batches_processed 90 | norm_sums[hook] = stats.get("norm", 0.0) * config.batches_processed 91 | norm_counts[hook] = 
config.batches_processed 92 | 93 | # Prepare for activation collection 94 | layers = {hook: int(hook.split(".")[2]) for hook in hooks} 95 | 96 | # Initialize batches processed from config 97 | batches_processed = config.batches_processed 98 | 99 | # Calculate tokens to skip based on batches processed 100 | tokens_to_skip = ( 101 | config.batches_processed 102 | * config.data_config["batch_size"] 103 | * config.data_config["seq_length"] 104 | ) 105 | loader.skip_tokens(tokens_to_skip) 106 | 107 | # Main loop 108 | model.eval() 109 | with torch.no_grad(): 110 | total_batches = ( 111 | loader.batches_per_machine 112 | if loader.batches_per_machine is not None 113 | else config.data_config["n_batches"] 114 | ) 115 | pbar = tqdm(total=total_batches) 116 | pbar.update(batches_processed) 117 | 118 | # Generate a new UUID for each batch group 119 | current_group_uuid = str(uuid4()) 120 | group_batch_count = 0 121 | 122 | for batch_idx, batch in enumerate(loader): 123 | if batch_idx >= config.data_config["n_batches"]: 124 | break 125 | 126 | # Move batch to model's device 127 | batch = {k: v.to(device=model.device) for k, v in batch.items()} 128 | 129 | # Forward pass 130 | outputs = model(**batch, output_hidden_states=True) 131 | 132 | # Extract activations 133 | activations = {} 134 | for hook in hooks: 135 | if hook in hook_activations: 136 | activations[hook] = { 137 | "states": hook_activations[hook], 138 | "input_ids": batch["input_ids"], 139 | } 140 | else: 141 | # Handle pre and post hooks separately 142 | layer_idx = int(hook.split(".")[2]) 143 | if "pre" in hook: 144 | activations[hook] = { 145 | "states": outputs.hidden_states[layer_idx], 146 | "input_ids": batch["input_ids"], 147 | } 148 | elif "post" in hook: 149 | activations[hook] = { 150 | "states": outputs.hidden_states[layer_idx + 1], 151 | "input_ids": batch["input_ids"], 152 | } 153 | 154 | try: 155 | # Clean special tokens from activations (e.g. 
BOS) 156 | cleaned_activations = {} 157 | for hook in hooks: 158 | cleaned_input_ids, cleaned_states = loader.clean_batch( 159 | activations[hook]["input_ids"], activations[hook]["states"] 160 | ) 161 | cleaned_activations[hook] = { 162 | "states": cleaned_states, 163 | "input_ids": cleaned_input_ids, 164 | } 165 | except Exception as e: 166 | logger.warning(f"SKIPPING BATCH {batch_idx} due to error: {e}") 167 | continue 168 | 169 | # Update total tokens 170 | config.total_tokens += ( 171 | config.data_config["batch_size"] * config.data_config["seq_length"] 172 | ) 173 | 174 | # Compute statistics and move to CPU 175 | if uploaders: 176 | any_file_uploaded = False 177 | for hook in hooks: 178 | # Get current batch activations 179 | states = cleaned_activations[hook]["states"] 180 | N, T = states.shape[0], states.shape[1] 181 | total_tokens = N * T 182 | 183 | # Update statistics using Welford's online algorithm 184 | counts[hook] += total_tokens 185 | 186 | # Calculate deltas for Welford's algorithm (vectorized) 187 | delta = states.mean(dim=(0, 1)) - means[hook] 188 | means[hook] += delta * (total_tokens / counts[hook]) 189 | 190 | # Calculate contribution to M2 191 | delta2 = states.mean(dim=(0, 1)) - means[hook] 192 | M2s[hook] += total_tokens * delta * delta2 193 | 194 | # Update norm statistics 195 | norm_sums[hook] += torch.norm(states, dim=2).sum().item() 196 | norm_counts[hook] += total_tokens 197 | 198 | # Move to CPU for saving 199 | cpu_activations = { 200 | "states": states.to(device="cpu", non_blocking=True), 201 | "input_ids": cleaned_activations[hook]["input_ids"].to( 202 | device="cpu", non_blocking=True 203 | ), 204 | } 205 | 206 | # Append to uploader with the current group UUID 207 | file_id = uploaders[hook].append(cpu_activations, current_group_uuid) 208 | if file_id: 209 | any_file_uploaded = True 210 | 211 | if any_file_uploaded: 212 | config.n_total_files += 1 213 | # Generate new UUID for next group since we just uploaded 214 | current_group_uuid = str(uuid4()) 215 | group_batch_count = 0 216 | else: 217 | group_batch_count += 1 218 | 219 | # Update batches processed 220 | batches_processed += 1 221 | 222 | # Save config periodically only from machine index 0 223 | if ( 224 | config.machine_index == 0 225 | and batches_processed % config.upload_config["batches_per_upload"] == 0 226 | ): 227 | # Update config with the current state 228 | config.batches_processed = batches_processed 229 | 230 | # Save config in a non-blocking way 231 | config.save_to_s3(s3_client, blocking=False) 232 | 233 | # Save statistics 234 | if uploaders: 235 | for hook in hooks: 236 | # Extract final statistics from running calculations 237 | mean = means[hook].cpu() 238 | 239 | # Calculate standard deviation from M2 240 | variance = M2s[hook] / counts[hook] 241 | std = torch.sqrt(variance).cpu() 242 | 243 | # Calculate average norm 244 | norm = norm_sums[hook] / norm_counts[hook] if norm_counts[hook] > 0 else 0.0 245 | 246 | # Also save M2 for future resumption 247 | uploaders[hook].save_stats(mean, std, norm, M2=M2s[hook].cpu()) 248 | 249 | pbar.update(1) 250 | 251 | pbar.close() 252 | 253 | # Save final config and statistics only from machine index 0 254 | if config.machine_index == 0: 255 | config.batches_processed = batches_processed 256 | config.save_to_s3(s3_client, blocking=True) # Block on final save 257 | 258 | if uploaders: 259 | for hook in hooks: 260 | # Extract final statistics from running calculations 261 | mean = means[hook].cpu() 262 | 263 | # Calculate standard deviation 
from M2 264 | variance = M2s[hook] / counts[hook] 265 | std = torch.sqrt(variance).cpu() 266 | 267 | # Calculate average norm 268 | norm = norm_sums[hook] / norm_counts[hook] if norm_counts[hook] > 0 else 0.0 269 | 270 | # Also save M2 for future resumption 271 | uploaders[hook].save_stats(mean, std, norm, M2=M2s[hook].cpu()) 272 | uploaders[hook].finalize() 273 | -------------------------------------------------------------------------------- /pipeline/setup.py: -------------------------------------------------------------------------------- 1 | """Copyright (2025) Tilde Research Inc. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | """ 15 | 16 | import torch 17 | from transformers import AutoModelForCausalLM, AutoTokenizer 18 | from typing import Tuple, List, Dict 19 | import os 20 | from pipeline.vault import HookUploader 21 | from accelerate import dispatch_model, infer_auto_device_map 22 | import logging 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | 27 | def setup_model_and_tokenizer( 28 | transformer_config: dict, 29 | hooks: List[str] = None, 30 | ) -> Tuple[AutoModelForCausalLM, AutoTokenizer]: 31 | """Set up and configure a transformer model and tokenizer for activation extraction. 32 | 33 | This function handles model loading, dtype configuration, model truncation based on hooks, 34 | and device mapping for efficient memory usage. 
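Hook names are expected to follow the pattern models.layers.{idx}.{mlp,self_attn}.{pre,post} used elsewhere in the pipeline; the layer index is parsed from the third dot-separated field, and if no hooks are supplied every layer's mlp.post hook is used by default.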
35 | 36 | Args: 37 | transformer_config: Configuration dictionary containing: 38 | - model_name: Name/path of the HuggingFace model 39 | - dtype: Model precision ("float16", "bfloat16", or "float32") 40 | hooks: List of activation hook names, used to determine model truncation 41 | 42 | Returns: 43 | tuple: (model, tokenizer) 44 | - model: Configured AutoModelForCausalLM instance 45 | - tokenizer: Associated AutoTokenizer instance 46 | 47 | Example: 48 | ```python 49 | config = { 50 | "model_name": "gpt2", 51 | "dtype": "float16" 52 | } 53 | model, tokenizer = setup_model_and_tokenizer(config, hooks=["models.layers.24.mlp.post"]) 54 | ``` 55 | """ 56 | tokenizer = AutoTokenizer.from_pretrained(transformer_config["model_name"]) 57 | 58 | # Determine dtype 59 | dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}.get( 60 | transformer_config.get("dtype", "float16"), torch.float16 61 | ) 62 | 63 | model = AutoModelForCausalLM.from_pretrained( 64 | transformer_config["model_name"], 65 | torch_dtype=dtype, 66 | device_map="cpu", 67 | cache_dir=transformer_config.get("cache_dir", None), 68 | attn_implementation="flash_attention_2", 69 | trust_remote_code=True, 70 | **transformer_config.get("kwargs", {}), 71 | ) 72 | 73 | if hooks: 74 | import gc 75 | 76 | # Find the highest layer number from hooks 77 | layer_numbers = [int(hook.split(".")[2]) for hook in hooks if hook.startswith("models.layers")] 78 | max_layer = max(layer_numbers) 79 | 80 | # Print model size before truncation 81 | total_params_before = sum(p.numel() for p in model.parameters()) 82 | logger.info(f"Model parameters before truncation: {total_params_before:,}") 83 | 84 | # Truncate the model 85 | num_layers = len(model.model.layers) 86 | removed_layers = model.model.layers[max_layer + 1 :] 87 | model.model.layers = model.model.layers[: max_layer + 1] 88 | 89 | del removed_layers 90 | del model.lm_head 91 | torch.cuda.empty_cache() 92 | gc.collect() 93 | 94 | model.lm_head = torch.nn.Identity() 95 | 96 | # Print model size after truncation 97 | total_params_after = sum(p.numel() for p in model.parameters()) 98 | logger.info(f"Model parameters after truncation: {total_params_after:,}") 99 | logger.info( 100 | f"Removed {num_layers - (max_layer + 1)} layers, keeping first {max_layer + 1} layers" 101 | ) 102 | logger.info( 103 | f"Memory saved: {(total_params_before - total_params_after) * dtype.itemsize / (1024**3):.2f} GB" 104 | ) 105 | else: 106 | logger.info("No hooks provided, using all layers") 107 | hooks = [f"models.layers.{i}.mlp.post" for i in range(model.config.num_hidden_layers)] 108 | 109 | decoder_cls = model.model.layers[0].__class__.__name__ 110 | 111 | num_gpus = torch.cuda.device_count() 112 | 113 | device_map = infer_auto_device_map( 114 | model, 115 | max_memory={ 116 | i: transformer_config.get("max_per_device_memory", "55GB") for i in range(num_gpus) 117 | }, 118 | no_split_module_classes=[decoder_cls], 119 | ) 120 | 121 | model = dispatch_model(model, device_map=device_map) 122 | 123 | torch.cuda.empty_cache() 124 | gc.collect() 125 | 126 | return model, tokenizer, hooks 127 | 128 | 129 | def maybe_add_mlp_attn_hooks(model: AutoModelForCausalLM, hooks: List[str] = None) -> List[str]: 130 | """Add MLP and attention hooks to the model. 
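    During each forward pass the registered hooks write the captured tensors into a shared
    hook_activations dictionary keyed by hook name, and the hook handles are returned alongside it.
    A minimal sketch (hook name and batch are illustrative):
    ```python
    pytorch_hooks, hook_activations = maybe_add_mlp_attn_hooks(
        model, hooks=["models.layers.24.mlp.post"]
    )
    outputs = model(**batch, output_hidden_states=True)  # forward pass populates hook_activations
    states = hook_activations["models.layers.24.mlp.post"]
    ```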
131 | 132 | Args: 133 | model: The transformer model 134 | hooks: List of hook names to check for mlp/attn hooks 135 | 136 | Returns: 137 | List[str]: Potentially expanded list of hooks including new mlp/attn hooks 138 | """ 139 | if not hooks: 140 | return hooks 141 | 142 | pytorch_hooks = [] 143 | hook_activations = {} 144 | 145 | for hook in list(hooks): 146 | layer_idx = int(hook.split(".")[2]) 147 | 148 | if "mlp.pre" in hook: 149 | def get_mlp_pre_hook(layer_idx): 150 | def pre_hook(module, input): 151 | if input and isinstance(input, tuple) and len(input) > 0: 152 | hook_activations[f"models.layers.{layer_idx}.mlp.pre"] = input[0] 153 | return pre_hook 154 | 155 | pytorch_hook = get_mlp_pre_hook(layer_idx) 156 | model.model.layers[layer_idx].mlp.register_forward_pre_hook(pytorch_hook) 157 | pytorch_hooks.append(pytorch_hook) 158 | logger.info(f"Added MLP pre hook for layer {layer_idx}") 159 | 160 | if "mlp.post" in hook: 161 | def get_mlp_post_hook(layer_idx): 162 | def post_hook(module, input, output): 163 | # MLP output is a direct tensor 164 | hook_activations[f"models.layers.{layer_idx}.mlp.post"] = output 165 | return post_hook 166 | 167 | pytorch_hook = get_mlp_post_hook(layer_idx) 168 | model.model.layers[layer_idx].mlp.register_forward_hook(pytorch_hook) 169 | pytorch_hooks.append(pytorch_hook) 170 | logger.info(f"Added MLP post hook for layer {layer_idx}") 171 | 172 | if "self_attn.pre" in hook: 173 | def get_attn_pre_hook(layer_idx): 174 | def pre_hook(module, input): 175 | if input and isinstance(input, tuple) and len(input) > 0: 176 | hook_activations[f"models.layers.{layer_idx}.self_attn.pre"] = input[0] 177 | return pre_hook 178 | 179 | pytorch_hook = get_attn_pre_hook(layer_idx) 180 | model.model.layers[layer_idx].self_attn.register_forward_pre_hook(pytorch_hook) 181 | pytorch_hooks.append(pytorch_hook) 182 | logger.info(f"Added attention pre hook for layer {layer_idx}") 183 | 184 | if "self_attn.post" in hook: 185 | def get_attn_post_hook(layer_idx): 186 | def post_hook(module, input, output): 187 | # Attention output is a tuple, we want the first element 188 | if isinstance(output, tuple) and len(output) > 0: 189 | hook_activations[f"models.layers.{layer_idx}.self_attn.post"] = output[0] 190 | return post_hook 191 | 192 | pytorch_hook = get_attn_post_hook(layer_idx) 193 | model.model.layers[layer_idx].self_attn.register_forward_hook(pytorch_hook) 194 | pytorch_hooks.append(pytorch_hook) 195 | logger.info(f"Added attention post hook for layer {layer_idx}") 196 | 197 | return pytorch_hooks, hook_activations 198 | 199 | 200 | def setup_uploaders( 201 | run_name: str, hooks: List[str], batches_per_upload: int, bucket_name: str 202 | ) -> Dict[str, HookUploader]: 203 | """Create S3 uploaders for storing activation data from each hook. 
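    Credentials are read from the AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment
    variables, and bucket_name must refer to an existing bucket (HookUploader verifies it with
    head_bucket at construction time).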
204 | 205 | Args: 206 | run_name: Unique identifier for this run 207 | hooks: List of hook names requiring uploaders 208 | batches_per_upload: Number of batches to accumulate before upload 209 | bucket_name: S3 bucket name for storage 210 | 211 | Returns: 212 | dict: Mapping of hook names to their respective uploaders 213 | 214 | Example: 215 | ```python 216 | uploaders = setup_uploaders( 217 | "experiment_1", 218 | hooks=["models.layers.0.mlp.post"], 219 | batches_per_upload=10 220 | ) 221 | ``` 222 | """ 223 | uploaders = {} 224 | for hook in hooks: 225 | uploader = HookUploader.from_credentials( 226 | access_key_id=os.environ.get("AWS_ACCESS_KEY_ID"), 227 | secret=os.environ.get("AWS_SECRET_ACCESS_KEY"), 228 | prefix_path=f"{run_name}/{hook}", 229 | batches_per_upload=batches_per_upload, 230 | bucket_name=bucket_name, 231 | ) 232 | uploaders[hook] = uploader 233 | 234 | return uploaders 235 | 236 | 237 | def calculate_machine_params( 238 | machine_index: int, total_batches: int, num_runs: int 239 | ) -> Tuple[int, int]: 240 | """Calculate batch distribution for distributed processing. 241 | 242 | Determines how many batches each machine should process and their starting points. 243 | 244 | Args: 245 | machine_index: Index of the current machine (0 to num_runs-1) 246 | total_batches: Total number of batches to process 247 | num_runs: Total number of parallel runs 248 | 249 | Returns: 250 | tuple: (batches_per_machine, start_batch) 251 | - batches_per_machine: Number of batches for this machine 252 | - start_batch: Starting batch index 253 | 254 | Note: 255 | The returned start_batch should be added to any global start_batch parameter. 256 | 257 | Example: 258 | ```python 259 | batches, start = calculate_machine_params(0, 1000, 4) 260 | # Returns (250, 0) for first machine 261 | ``` 262 | """ 263 | batches_per_machine = total_batches // num_runs 264 | start_batch = machine_index * batches_per_machine 265 | end_batch = start_batch + batches_per_machine - 1 266 | 267 | if num_runs > 1: 268 | logger.info( 269 | f"Run {machine_index + 1}/{num_runs} processing batches {start_batch} to {end_batch} " 270 | f"({batches_per_machine} out of {total_batches} total batches)" 271 | ) 272 | else: 273 | logger.info( 274 | f"Single run processing all batches from {start_batch} to {end_batch} " 275 | f"({total_batches} total batches)" 276 | ) 277 | 278 | return batches_per_machine, start_batch 279 | 280 | 281 | def display_job_stats(model, hooks, config, batches_per_machine): 282 | """Display job statistics including space usage and token count. 283 | 284 | Calculates and presents a box containing key statistics about 285 | the current job, including model dimensions, token counts, and estimated storage 286 | requirements. 
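    The estimate is d_model * n_hooks * dtype_size * (n_batches * batch_size * seq_length) bytes;
    for illustration, a single bfloat16 hook (2 bytes) on a d_model=8192 model over 1000 batches of
    16 x 4096 tokens comes to roughly 1000 GB.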
287 | 288 | Args: 289 | model: The transformer model 290 | hooks: List of activation hooks 291 | config: Configuration object 292 | batches_per_machine: Number of batches to be processed by this machine 293 | """ 294 | # Calculate space usage statistics 295 | d_model = model.config.hidden_size 296 | n_hooks = len(hooks) 297 | dtype_size = {"float16": 2, "bfloat16": 2, "float32": 4}.get( 298 | config.transformer_config.get("dtype", "float16"), 2 299 | ) # Size in bytes 300 | n_batches = batches_per_machine 301 | batch_size = config.data_config["batch_size"] 302 | seq_length = config.data_config["seq_length"] 303 | 304 | total_tokens = n_batches * batch_size * seq_length 305 | total_space_bytes = d_model * n_hooks * dtype_size * total_tokens 306 | 307 | # Convert to human-readable format 308 | space_gb = total_space_bytes / (1024**3) 309 | 310 | # Create beautiful stats display 311 | logger.info("┌─" + "─" * 60 + "─┐") 312 | logger.info("│ " + "Job Statistics".center(60) + " │") 313 | logger.info("├─" + "─" * 60 + "─┤") 314 | logger.info(f"│ Model Hidden Dimension: {d_model:,}".ljust(62) + "│") 315 | logger.info(f"│ Number of Hooks: {n_hooks}".ljust(62) + "│") 316 | logger.info( 317 | f"│ Data Type: {config.transformer_config.get('dtype', 'float16')} ({dtype_size} bytes)".ljust( 318 | 62 319 | ) 320 | + "│" 321 | ) 322 | logger.info(f"│ Batch Size: {batch_size}".ljust(62) + "│") 323 | logger.info(f"│ Sequence Length: {seq_length}".ljust(62) + "│") 324 | logger.info(f"│ Number of Batches: {n_batches:,}".ljust(62) + "│") 325 | logger.info(f"│ Total Tokens: {total_tokens:,}".ljust(62) + "│") 326 | logger.info(f"│ Estimated Storage Required: {space_gb:.2f} GB".ljust(62) + "│") 327 | logger.info("└─" + "─" * 60 + "─┘") 328 | -------------------------------------------------------------------------------- /pipeline/utils/cli.py: -------------------------------------------------------------------------------- 1 | """Copyright (2025) Tilde Research Inc. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | This module provides the main entry point for the activault CLI and handles subcommands. 
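Typical invocations (config path illustrative):

    activault collect --config configs/qwq_32b.yaml --machine 0
    activault s3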
10 | """ 11 | 12 | import sys 13 | import os 14 | import argparse 15 | import logging 16 | import random 17 | from pathlib import Path 18 | 19 | # Add the project root to sys.path so we can import stash.py 20 | project_root = str(Path(__file__).resolve().parent.parent.parent) 21 | if project_root not in sys.path: 22 | sys.path.insert(0, project_root) 23 | 24 | logging.basicConfig(level=logging.INFO, format="[%(asctime)s] %(levelname)s: %(message)s") 25 | logger = logging.getLogger(__name__) 26 | 27 | TILDE_LOGO = """ 28 | ___ 29 | / _ \_/\ 30 | \/ \___/ 31 | """ 32 | 33 | ACTIVAULT_BANNER = """ 34 | ╔════════════════════════════════════════════════════════╗ 35 | ║ ACTIVAULT v0.1.0 ║ 36 | ║ ║ 37 | ║ A pipeline for collecting LLM activations and ║ 38 | ║ storing them for efficient retrieval ║ 39 | ╚════════════════════════════════════════════════════════╝ 40 | """ 41 | 42 | 43 | def show_welcome(): 44 | """Display the welcome message with ASCII art.""" 45 | os.system("cls" if os.name == "nt" else "clear") 46 | print("\033[94m" + TILDE_LOGO + "\033[0m") 47 | print("\033[92m" + ACTIVAULT_BANNER + "\033[0m") 48 | print("\033[1mAvailable Commands:\033[0m") 49 | print(" \033[96mactivault collect\033[0m - Run activation collection") 50 | print(" \033[96mactivault s3\033[0m - Launch the S3 shell") 51 | print("\nFor more information on a command, run: activault --help") 52 | print("\nVisit: \033[4mhttps://github.com/tilde-research/activault\033[0m for documentation") 53 | print() 54 | 55 | 56 | def s3_command(): 57 | """Launch the S3 shell.""" 58 | from s3.shell.shell import S3Shell 59 | 60 | shell = S3Shell() 61 | shell.run() 62 | 63 | 64 | def collect_command(args): 65 | """Run the main activation collection pipeline.""" 66 | try: 67 | import stash 68 | 69 | # Pass command line arguments through to stash.py 70 | sys.argv = ["stash.py"] 71 | if args.config: 72 | sys.argv.extend(["--config", args.config]) 73 | if args.machine is not None: 74 | sys.argv.extend(["--machine", str(args.machine)]) 75 | stash.main() 76 | except ImportError as e: 77 | logger.error(f"Failed to import stash module: {e}") 78 | logger.error("Make sure you're running from the project root directory.") 79 | sys.exit(1) 80 | 81 | 82 | def main(): 83 | """Main entry point for the activault CLI.""" 84 | parser = argparse.ArgumentParser( 85 | description="Activault - A tool for collecting and processing neural network activations." 
86 | ) 87 | subparsers = parser.add_subparsers(dest="command", help="Command to run") 88 | 89 | # Add the 's3' subcommand 90 | s3_parser = subparsers.add_parser("s3", help="Launch the S3 shell") 91 | 92 | # Add the 'collect' subcommand 93 | collect_parser = subparsers.add_parser("collect", help="Run activation collection") 94 | collect_parser.add_argument( 95 | "--config", type=str, help="Path to configuration file", required=False 96 | ) 97 | collect_parser.add_argument( 98 | "--machine", type=int, help="Machine index for distributed processing", required=False 99 | ) 100 | 101 | # Parse arguments 102 | args = parser.parse_args() 103 | 104 | # Handle subcommands 105 | if args.command == "s3": 106 | s3_command() 107 | elif args.command == "collect": 108 | collect_command(args) 109 | elif args.command is None: 110 | show_welcome() 111 | else: 112 | parser.print_help() 113 | sys.exit(1) 114 | 115 | 116 | if __name__ == "__main__": 117 | main() 118 | -------------------------------------------------------------------------------- /pipeline/utils/ray_utils.py: -------------------------------------------------------------------------------- 1 | """Copyright (2025) Tilde Research Inc. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | """ 15 | 16 | import os 17 | import ray 18 | import logging 19 | import time 20 | import yaml 21 | import tempfile 22 | from typing import Dict, Any, Optional, List, Tuple 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | 27 | def init_ray(address: Optional[str] = None) -> None: 28 | """Initialize Ray with the given address or environment configuration. 29 | 30 | Args: 31 | address: Ray cluster address, or None to use environment variables 32 | """ 33 | if not ray.is_initialized(): 34 | try: 35 | # Use address if provided, otherwise Ray will use RAY_ADDRESS env var 36 | ray.init(address=address, ignore_reinit_error=True) 37 | logger.info(f"Ray initialized: {ray.cluster_resources()}") 38 | except Exception as e: 39 | logger.error(f"Failed to initialize Ray: {e}") 40 | raise 41 | 42 | 43 | @ray.remote 44 | class RayActivationWorker: 45 | """Ray actor for distributed activation collection. 46 | 47 | This worker encapsulates the main processing logic in a Ray-friendly way 48 | without modifying the core processing functions. 49 | """ 50 | 51 | def __init__(self, machine_index: int, config_path: str): 52 | """Initialize the worker with its machine index and config path. 53 | 54 | Args: 55 | machine_index: Index of this worker (equivalent to Slurm machine_index) 56 | config_path: Path to the configuration file 57 | """ 58 | self.machine_index = machine_index 59 | self.config_path = config_path 60 | self.results = {"status": "initialized"} 61 | 62 | def run(self) -> Dict[str, Any]: 63 | """Run the activation collection process. 64 | 65 | This method imports and calls the main function to run the collection process, 66 | capturing logs and results. 
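        A sketch of how the actor is typically driven after init_ray() (the config path is
        illustrative; launch_ray_jobs below wraps this pattern):
        ```python
        worker = RayActivationWorker.remote(machine_index=0, config_path="configs/model.yaml")
        result = ray.get(worker.run.remote())
        # e.g. {"status": "completed", "machine_index": 0, "error": None}
        ```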
67 | 68 | Returns: 69 | Dict containing status and results 70 | """ 71 | import sys 72 | import os 73 | import json 74 | from importlib import import_module 75 | 76 | # Set machine index environment variable for compatibility 77 | os.environ["MACHINE_INDEX"] = str(self.machine_index) 78 | 79 | try: 80 | # Handle config processing - ray_config section isn't supported by Config class 81 | if self.config_path.endswith(".yaml"): 82 | # Create a clean config without ray_config 83 | with open(self.config_path, "r") as f: 84 | config_dict = yaml.safe_load(f) 85 | 86 | # Remove ray_config section if present 87 | if "ray_config" in config_dict: 88 | del config_dict["ray_config"] 89 | 90 | # Create temporary config file 91 | with tempfile.NamedTemporaryFile( 92 | suffix=".yaml", mode="w", delete=False 93 | ) as temp_file: 94 | temp_config_path = temp_file.name 95 | yaml.dump(config_dict, temp_file) 96 | else: 97 | temp_config_path = self.config_path 98 | 99 | # Use the existing main function from stash.py with temporary config 100 | sys.argv = [ 101 | "stash.py", 102 | "--config", 103 | temp_config_path, 104 | "--machine", 105 | str(self.machine_index), 106 | ] 107 | 108 | # Import and run the main function 109 | stash = import_module("stash") 110 | stash.main() 111 | 112 | # Clean up temp file 113 | if temp_config_path != self.config_path: 114 | os.remove(temp_config_path) 115 | 116 | self.results = { 117 | "status": "completed", 118 | "machine_index": self.machine_index, 119 | "error": None, 120 | } 121 | 122 | except Exception as e: 123 | import traceback 124 | 125 | error_trace = traceback.format_exc() 126 | logger.error(f"Worker {self.machine_index} failed: {e}\n{error_trace}") 127 | self.results = { 128 | "status": "failed", 129 | "machine_index": self.machine_index, 130 | "error": str(e), 131 | "traceback": error_trace, 132 | } 133 | 134 | return self.results 135 | 136 | 137 | def calculate_resources_per_worker(config: Dict[str, Any]) -> Dict[str, float]: 138 | """Calculate Ray resources required per worker. 139 | 140 | Args: 141 | config: Configuration dictionary with resource requirements 142 | 143 | Returns: 144 | Dict of Ray resources (e.g., {"GPU": 1, "CPU": 4}) 145 | """ 146 | resources = {} 147 | 148 | # Get resources from config or use defaults 149 | if "resources" in config: 150 | resources = config["resources"] 151 | else: 152 | # Default resource allocation 153 | resources = { 154 | "CPU": 4, 155 | "GPU": 1, 156 | } 157 | 158 | return resources 159 | 160 | 161 | def launch_ray_jobs( 162 | config_path: str, 163 | num_total_runs: int, 164 | start_idx: int = 0, 165 | end_idx: Optional[int] = None, 166 | ray_address: Optional[str] = None, 167 | resources_per_worker: Optional[Dict[str, float]] = None, 168 | ) -> List[ray.ObjectRef]: 169 | """Launch Ray workers for activation collection. 
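    A minimal sketch of the intended flow (config path and worker counts are illustrative):
    ```python
    refs = launch_ray_jobs(
        "configs/model.yaml",
        num_total_runs=4,
        resources_per_worker={"CPU": 4, "GPU": 1},
    )
    results = monitor_ray_jobs(refs)  # polls until every worker reports a status
    ```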
170 | 171 | Args: 172 | config_path: Path to configuration file 173 | num_total_runs: Total number of runs to execute 174 | start_idx: Starting machine index 175 | end_idx: Ending machine index (inclusive) 176 | ray_address: Ray cluster address 177 | resources_per_worker: Dict of resources per worker 178 | 179 | Returns: 180 | List of Ray ObjectRefs for the running jobs 181 | """ 182 | # Initialize Ray 183 | init_ray(ray_address) 184 | 185 | # Set end_idx to (num_total_runs - 1) if not specified 186 | if end_idx is None: 187 | end_idx = num_total_runs - 1 188 | 189 | # Validate indices 190 | if start_idx >= num_total_runs or end_idx >= num_total_runs or start_idx > end_idx: 191 | raise ValueError( 192 | f"Invalid index range: {start_idx} to {end_idx} (total runs: {num_total_runs})" 193 | ) 194 | 195 | # Create actor options with resource requirements 196 | actor_options = {} 197 | if resources_per_worker: 198 | actor_options["num_cpus"] = resources_per_worker.get("CPU", 4) 199 | actor_options["num_gpus"] = resources_per_worker.get("GPU", 1) 200 | 201 | # Add custom resources if any 202 | for k, v in resources_per_worker.items(): 203 | if k not in ["CPU", "GPU"]: 204 | actor_options[f"resources"] = {k: v} 205 | 206 | # Launch jobs for the specified range 207 | job_refs = [] 208 | for i in range(start_idx, end_idx + 1): 209 | logger.info(f"Launching Ray worker {i+1}/{num_total_runs}") 210 | worker = RayActivationWorker.options(**actor_options).remote(i, config_path) 211 | job_ref = worker.run.remote() 212 | job_refs.append(job_ref) 213 | 214 | return job_refs 215 | 216 | 217 | def monitor_ray_jobs( 218 | job_refs: List[ray.ObjectRef], poll_interval: int = 10 219 | ) -> List[Dict[str, Any]]: 220 | """Monitor Ray jobs until completion. 221 | 222 | Args: 223 | job_refs: List of Ray ObjectRefs to monitor 224 | poll_interval: How often to check job status in seconds 225 | 226 | Returns: 227 | List of results from completed jobs 228 | """ 229 | all_results = [] 230 | remaining_refs = list(job_refs) 231 | 232 | while remaining_refs: 233 | # Check for any completed jobs 234 | done_refs, remaining_refs = ray.wait(remaining_refs, timeout=poll_interval) 235 | 236 | # Process completed jobs 237 | for job_ref in done_refs: 238 | try: 239 | result = ray.get(job_ref) 240 | all_results.append(result) 241 | logger.info( 242 | f"Job completed: machine_index={result.get('machine_index')}, status={result.get('status')}" 243 | ) 244 | except Exception as e: 245 | logger.error(f"Error getting job result: {e}") 246 | all_results.append({"status": "error", "error": str(e)}) 247 | 248 | # Log progress 249 | logger.info(f"Progress: {len(all_results)}/{len(job_refs)} jobs completed") 250 | 251 | return all_results 252 | -------------------------------------------------------------------------------- /pipeline/vault/__init__.py: -------------------------------------------------------------------------------- 1 | from .hook_uploader import HookUploader 2 | from .rcache import S3RCache 3 | 4 | __all__ = ["HookUploader", "S3RCache"] 5 | -------------------------------------------------------------------------------- /pipeline/vault/hook_uploader.py: -------------------------------------------------------------------------------- 1 | """Copyright (2025) Tilde Research Inc. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 
5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | """ 15 | 16 | import json 17 | import threading 18 | import queue 19 | import os 20 | import torch 21 | from uuid import uuid4 22 | import time 23 | import multiprocessing as mp 24 | import os 25 | import io 26 | import atexit 27 | from s3.utils import create_s3_client 28 | import math 29 | from typing import Optional 30 | import logging 31 | 32 | logger = logging.getLogger(__name__) 33 | 34 | 35 | def _metadata_path(prefix_path: str) -> str: 36 | """Generate the metadata file path for a given prefix path. 37 | 38 | Args: 39 | prefix_path: Base path for the hook's data 40 | 41 | Returns: 42 | str: Path in format "{prefix_path}/metadata.json" 43 | """ 44 | return f"{prefix_path}/metadata.json" 45 | 46 | 47 | def _statistics_path(prefix_path: str) -> str: 48 | """Generate the statistics file path for a given prefix path. 49 | 50 | Args: 51 | prefix_path: Base path for the hook's data 52 | 53 | Returns: 54 | str: Path in format "{prefix_path}/statistics.json" 55 | """ 56 | return f"{prefix_path}/statistics.json" 57 | 58 | 59 | def _strip_run_name(prefix_path): 60 | """Strip the run name from the prefix path.""" 61 | return prefix_path.split("/", 1)[1] 62 | 63 | 64 | class HookUploader: 65 | """Asynchronous uploader for neural network activation data to S3-compatible storage. 66 | 67 | This class manages the collection, batching, and uploading of activation data from 68 | neural network hooks. It uses a multi-process architecture with threaded workers 69 | to handle concurrent uploads efficiently. 70 | 71 | Architecture: 72 | - Main Process: Collects activations and queues them for upload 73 | - Upload Process: Manages upload threads and data transfer 74 | - Upload Threads: Handle actual S3 upload operations 75 | 76 | Attributes: 77 | batches_per_upload (int): Number of batches to accumulate before upload 78 | prefix_path (str): Base path for storing data in S3 79 | bucket_name (str): S3 bucket name 80 | num_upload_threads (int): Number of concurrent upload threads 81 | pending_uploads (mp.Value): Counter for pending upload operations 82 | upload_queue_size (mp.Value): Size of the upload queue 83 | running_mean (Optional[torch.Tensor]): Running mean of activations 84 | running_std (Optional[torch.Tensor]): Running standard deviation 85 | running_norm (float): Running norm of activations 86 | n_batches_processed (int): Total number of processed batches 87 | 88 | Example: 89 | ```python 90 | uploader = HookUploader.from_credentials( 91 | access_key_id="your_key", 92 | secret="your_secret", 93 | prefix_path="run_name/hook_name", 94 | batches_per_upload=32 95 | ) 96 | 97 | # Add activations 98 | uploader.append({ 99 | 'states': activation_tensor, 100 | 'input_ids': input_ids_tensor 101 | }) 102 | 103 | # Finalize and cleanup 104 | uploader.finalize() 105 | ``` 106 | """ 107 | 108 | @classmethod 109 | def from_credentials(cls, access_key_id: str, secret: str, *args, **kwargs) -> "HookUploader": 110 | """Create an HookUploader instance using storage credentials. 
111 | 112 | Args: 113 | access_key_id: Storage access key ID 114 | secret: Storage secret access key 115 | *args: Additional positional arguments for HookUploader 116 | **kwargs: Additional keyword arguments for HookUploader 117 | 118 | Returns: 119 | HookUploader: Configured uploader instance 120 | 121 | Raises: 122 | Exception: If bucket verification fails 123 | """ 124 | s3_client = create_s3_client(access_key_id, secret) 125 | logger.debug("Created S3 client with endpoint: %s", s3_client._endpoint) 126 | return cls(s3_client, *args, **kwargs) 127 | 128 | def __init__( 129 | self, 130 | s3_client, 131 | prefix_path: str, 132 | batches_per_upload: int = 32, 133 | bucket_name: str = "renes-bucket", 134 | num_upload_threads: int = 1, 135 | ): 136 | """Initialize the HookUploader.""" 137 | self.batches_per_upload = batches_per_upload 138 | self.prefix_path = prefix_path 139 | self._in_mem = [] # Back to being a simple list 140 | self.current_group_uuid = None # Track current UUID 141 | self.s3_client = s3_client 142 | self.metadata = None 143 | self.bucket_name = bucket_name 144 | self.num_upload_threads = num_upload_threads 145 | self.upload_attempt_count = 0 146 | self.pending_uploads = 0 147 | 148 | logger.debug("Initializing HookUploader with bucket: %s", self.bucket_name) 149 | 150 | try: 151 | self.s3_client.head_bucket(Bucket=self.bucket_name) 152 | logger.debug("Successfully verified bucket %s exists", self.bucket_name) 153 | except Exception as e: 154 | logger.error("Error verifying bucket %s: %s", self.bucket_name, str(e), exc_info=True) 155 | raise 156 | 157 | self.running_mean = None 158 | self.running_std = None 159 | self.running_norm = 0.0 160 | self.n_batches_processed = 0 161 | 162 | self.mp_upload_queue = mp.Queue(2) 163 | self.stop_event = mp.Event() 164 | 165 | self.upload_process = mp.Process(target=self._upload_worker) 166 | self.upload_process.start() 167 | 168 | atexit.register(self.cleanup) 169 | 170 | def cleanup(self) -> None: 171 | """Clean up resources and ensure all uploads are complete. 172 | 173 | Waits for pending uploads to complete (with timeout) and terminates 174 | the upload process. Automatically called on program exit. 175 | """ 176 | logger.info("Cleaning up HookUploader for %s...", _strip_run_name(self.prefix_path)) 177 | self.stop_event.set() 178 | if self.upload_process.is_alive(): 179 | minutes_to_wait = 10 180 | self.upload_process.join(timeout=minutes_to_wait * 60) 181 | if self.upload_process.is_alive(): 182 | logger.warning( 183 | "HookUploader for %s upload process is still alive. TERMINATING...", 184 | _strip_run_name(self.prefix_path), 185 | ) 186 | self.upload_process.terminate() 187 | logger.debug("HookUploader for %s cleanup complete.", _strip_run_name(self.prefix_path)) 188 | 189 | def _upload_worker(self) -> None: 190 | """Worker process for handling S3 uploads. 191 | 192 | Manages a pool of upload threads and coordinates data transfer from 193 | the multiprocessing queue to the thread queue. 194 | 195 | Note: 196 | Runs in a separate process and manages its own thread pool. 
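            The overall hand-off is: append() buffers batches in the main process,
            _queue_save_in_mem() places the combined tensors on mp_upload_queue, this worker moves
            them to a thread-local upload_queue, and _upload_thread performs the actual _save()
            call to S3.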
197 | """ 198 | upload_queue = queue.Queue(2) 199 | thread_stop_event = threading.Event() 200 | logger.debug("%s PID: %s", _strip_run_name(self.prefix_path), os.getpid()) 201 | 202 | threads = [] 203 | for _ in range(self.num_upload_threads): 204 | t = threading.Thread(target=self._upload_thread, args=(upload_queue, thread_stop_event)) 205 | t.start() 206 | threads.append(t) 207 | 208 | last_log_time = time.time() 209 | log_interval = 10 # Log every 10 seconds 210 | 211 | while True: 212 | if self.stop_event.is_set() and self.mp_upload_queue.empty(): 213 | logger.debug( 214 | "stop_event reached, mp_upload_queue size at moment: %d", 215 | self.mp_upload_queue.qsize(), 216 | ) 217 | break 218 | 219 | try: 220 | item = self.mp_upload_queue.get(timeout=10) 221 | self.pending_uploads += 1 222 | upload_queue.put(item) 223 | 224 | # Log queue sizes periodically 225 | current_time = time.time() 226 | if current_time - last_log_time > log_interval: 227 | logger.debug( 228 | "[%s QUEUE STATUS] MP Queue: %d, Thread Queue: %d, Pending Uploads: %d", 229 | _strip_run_name(self.prefix_path), 230 | self.mp_upload_queue.qsize(), 231 | upload_queue.qsize(), 232 | self.pending_uploads, 233 | ) 234 | last_log_time = current_time 235 | 236 | except mp.queues.Empty: 237 | continue 238 | except Exception as e: 239 | logger.error("Exception in _upload_worker: %s", str(e), exc_info=True) 240 | raise e 241 | 242 | thread_stop_event.set() 243 | 244 | for t in threads: 245 | t.join() 246 | 247 | def _upload_thread(self, upload_queue: queue.Queue, stop_event: threading.Event) -> None: 248 | """Thread for uploading data to S3. 249 | 250 | Args: 251 | upload_queue: Queue containing data to upload 252 | stop_event: Event signaling thread should stop 253 | 254 | Note: 255 | Runs in upload worker process and handles actual S3 uploads. 256 | """ 257 | while True: 258 | if stop_event.is_set() and upload_queue.empty(): 259 | logger.debug( 260 | "stop_event reached, upload_queue size at moment: %d", upload_queue.qsize() 261 | ) 262 | break 263 | 264 | try: 265 | activations, group_uuid = upload_queue.get(timeout=1) 266 | self._save(activations, group_uuid) 267 | self.pending_uploads -= 1 # Decrement pending uploads after successful save 268 | upload_queue.task_done() 269 | except queue.Empty: 270 | if stop_event.is_set(): 271 | continue 272 | time.sleep(0.25) 273 | continue 274 | except Exception as e: 275 | logger.error("Exception in _upload_thread: %s", str(e), exc_info=True) 276 | raise e 277 | 278 | def append(self, activations: dict, group_uuid: str) -> Optional[str]: 279 | """Append activations to the cache, queueing for S3 upload when batch is full. 280 | 281 | Args: 282 | activations: Dictionary containing activation data 283 | group_uuid: UUID for this batch group (shared across layers) 284 | 285 | Returns: 286 | Optional[str]: Upload ID if batch was queued, None otherwise 287 | """ 288 | # Update current group UUID 289 | self.current_group_uuid = group_uuid 290 | 291 | if self.metadata is None: 292 | self.metadata = self._get_metadata(activations, self.batches_per_upload) 293 | self._save_metadata() 294 | else: 295 | if not self._validate_activations(activations): 296 | return None 297 | 298 | self._in_mem.append(activations) 299 | 300 | if len(self._in_mem) == self.batches_per_upload: 301 | return self._queue_save_in_mem() 302 | 303 | return None 304 | 305 | def _queue_save_in_mem(self) -> str: 306 | """Queue the in-memory activations for S3 upload. 
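        The buffered batches are concatenated along the batch dimension and placed on the
        multiprocessing queue together with the current group UUID; the in-memory buffer is then
        cleared.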
307 | 308 | Returns: 309 | str: The group UUID used for this upload 310 | """ 311 | # Combine all states and input_ids from the group 312 | combined_states = torch.cat([item["states"] for item in self._in_mem]) 313 | combined_input_ids = torch.cat([item["input_ids"] for item in self._in_mem]) 314 | 315 | states_bytes = combined_states.numel() * combined_states.element_size() 316 | input_ids_bytes = combined_input_ids.numel() * combined_input_ids.element_size() 317 | 318 | combined_dict = { 319 | "states": combined_states, 320 | "input_ids": combined_input_ids, 321 | "tensor_bytes": states_bytes + input_ids_bytes, 322 | } 323 | 324 | self.mp_upload_queue.put((combined_dict, self.current_group_uuid)) 325 | self._in_mem = [] 326 | return self.current_group_uuid 327 | 328 | def _save(self, activations_dict: dict, group_uuid: str) -> None: 329 | """Save the activations dictionary to S3 using multipart upload. 330 | 331 | Args: 332 | activations_dict: Dictionary containing activation data 333 | group_uuid: UUID for this batch group 334 | 335 | Raises: 336 | Exception: If upload fails 337 | 338 | Note: 339 | Uses multipart upload for large files and handles failures. 340 | """ 341 | 342 | filename = self._filename(group_uuid) 343 | logger.debug("Starting upload of %s", filename) 344 | 345 | serialization_start = time.time() 346 | buffer = io.BytesIO() 347 | torch.save(activations_dict, buffer) 348 | buffer.seek(0) 349 | tensor_bytes = buffer.getvalue() 350 | serialization_end = time.time() 351 | serialization_time = serialization_end - serialization_start 352 | 353 | # Start multipart upload 354 | upload_start = time.time() 355 | multipart_upload = self.s3_client.create_multipart_upload( 356 | Bucket=self.bucket_name, Key=filename, ContentType="application/octet-stream" 357 | ) 358 | upload_id = multipart_upload["UploadId"] 359 | 360 | try: 361 | # Upload parts 362 | part_size = 100 * 1024 * 1024 # 100 MB 363 | num_parts = math.ceil(len(tensor_bytes) / part_size) 364 | parts = [] 365 | 366 | for part_number in range(1, num_parts + 1): 367 | start_byte = (part_number - 1) * part_size 368 | end_byte = min(part_number * part_size, len(tensor_bytes)) 369 | 370 | part_response = self.s3_client.upload_part( 371 | Bucket=self.bucket_name, 372 | Key=filename, 373 | PartNumber=part_number, 374 | UploadId=upload_id, 375 | Body=tensor_bytes[start_byte:end_byte], 376 | ) 377 | parts.append({"PartNumber": part_number, "ETag": part_response["ETag"]}) 378 | 379 | # Complete multipart upload 380 | self.s3_client.complete_multipart_upload( 381 | Bucket=self.bucket_name, 382 | Key=filename, 383 | MultipartUpload={"Parts": parts}, 384 | UploadId=upload_id, 385 | ) 386 | 387 | except Exception as e: 388 | logger.error("Upload failed: %s", str(e)) 389 | # Abort the multipart upload on any error 390 | self.s3_client.abort_multipart_upload( 391 | Bucket=self.bucket_name, Key=filename, UploadId=upload_id 392 | ) 393 | raise e 394 | 395 | upload_end = time.time() 396 | upload_time = upload_end - upload_start 397 | total_time = upload_end - serialization_start 398 | 399 | # Save metadata 400 | self.metadata = self._get_metadata( 401 | activations_dict, self.batches_per_upload, len(tensor_bytes) 402 | ) 403 | self.metadata["bytes_per_file"] = len(tensor_bytes) 404 | self._save_metadata() 405 | 406 | logger.debug( 407 | "Successfully saved %s to S3! 
PID: %d, Serialization time: %.2fs, S3 upload time: %.2fs, Total time: %.2fs", 408 | _strip_run_name(filename), 409 | os.getpid(), 410 | serialization_time, 411 | upload_time, 412 | total_time, 413 | ) 414 | 415 | def finalize(self) -> None: 416 | """Finalize the cache, ensuring all data is saved and processes are stopped.""" 417 | if self.metadata is None: 418 | raise ValueError("Cannot finalize cache without any data") 419 | 420 | # Handle any remaining batches 421 | if len(self._in_mem) > 0: 422 | if len(self._in_mem) == self.batches_per_upload: 423 | logger.debug("Queueing save for final full batch group") 424 | self._queue_save_in_mem() 425 | else: 426 | logger.warning("Discarding %d incomplete batches", len(self._in_mem)) 427 | self._in_mem = [] 428 | 429 | # Wait for all pending uploads to complete 430 | logger.info( 431 | "Waiting for %d pending uploads to complete for %s...", 432 | self.pending_uploads, 433 | self.prefix_path, 434 | ) 435 | wait_start = time.time() 436 | while self.pending_uploads > 0: 437 | if time.time() - wait_start > 3600: # 1 hour timeout 438 | logger.warning( 439 | "Timeout waiting for uploads to complete. %d uploads still pending.", 440 | self.pending_uploads, 441 | ) 442 | break 443 | time.sleep(5) 444 | 445 | self.cleanup() 446 | 447 | def _validate_activations(self, activations: dict) -> bool: 448 | """Validate the shape and dtype of the activations against the metadata. 449 | 450 | Args: 451 | activations: Dictionary containing activation data 452 | 453 | Returns: 454 | bool: True if validation passes, False otherwise 455 | """ 456 | expected_shape = ( 457 | self.metadata["batch_size"], 458 | self.metadata["sequence_length"], 459 | self.metadata["d_in"], 460 | ) 461 | if activations["states"].shape != expected_shape: 462 | logger.warning( 463 | "NOT SAVING: shape mismatch. Expected %s, got %s", 464 | expected_shape, 465 | activations["states"].shape, 466 | ) 467 | return False 468 | if str(activations["states"].dtype) != self.metadata["dtype"]: 469 | logger.warning( 470 | "NOT SAVING: dtype mismatch. Expected %s, got %s", 471 | self.metadata["dtype"], 472 | activations["states"].dtype, 473 | ) 474 | return False 475 | return True 476 | 477 | def _save_metadata(self) -> None: 478 | """Save the metadata to S3. 479 | 480 | Stores metadata about activation shapes, types, and batch sizes. 481 | """ 482 | logger.debug("Saving metadata for %s", _strip_run_name(self.prefix_path)) 483 | self.s3_client.put_object( 484 | Body=json.dumps(self.metadata), 485 | Bucket=self.bucket_name, 486 | Key=_metadata_path(self.prefix_path), 487 | ) 488 | 489 | def _filename(self, group_uuid: str) -> str: 490 | """Generate the filename for a given group UUID. 491 | 492 | Args: 493 | group_uuid: UUID for this batch group 494 | 495 | Returns: 496 | str: Full path including attempt number 497 | """ 498 | slurm_job = os.getenv("SLURM_JOB_ID", 0) 499 | self.upload_attempt_count += 1 500 | return f"{self.prefix_path}/{group_uuid}--{self.upload_attempt_count}_{slurm_job}.saved.pt" 501 | 502 | def _get_metadata( 503 | self, activations: dict, batches_per_upload: int, bytes_per_file: Optional[int] = None 504 | ) -> dict: 505 | """Create metadata dictionary from activation data. 
506 | 507 | Args: 508 | activations: Dictionary containing activation data 509 | batches_per_upload: Number of batches per upload file 510 | bytes_per_file: Optional size of file in bytes 511 | 512 | Returns: 513 | dict: Metadata including shapes, types, and sizes 514 | """ 515 | metadata = { 516 | "batch_size": activations["states"].shape[0], 517 | "sequence_length": activations["states"].shape[1], 518 | "dtype": str(activations["states"].dtype), 519 | "d_in": activations["states"].shape[2], 520 | "batches_per_file": batches_per_upload, 521 | "shape": list(activations["states"].shape), 522 | "input_ids_shape": list(activations["input_ids"].shape), 523 | } 524 | if bytes_per_file is not None: 525 | metadata["bytes_per_file"] = bytes_per_file 526 | return metadata 527 | 528 | def save_stats(self, mean: torch.Tensor, std: torch.Tensor, norm: float, M2: torch.Tensor): 529 | """ 530 | Save statistics for the hook. 531 | 532 | Args: 533 | mean: Mean tensor 534 | std: Standard deviation tensor 535 | norm: Optional average L2 norm 536 | M2: Optional M2 value from Welford's algorithm for resumable stats 537 | """ 538 | stats = { 539 | "mean": mean.tolist(), 540 | "std": std.tolist(), 541 | } 542 | 543 | if norm is not None: 544 | stats["norm"] = float(norm) # Ensure norm is a Python float 545 | 546 | if M2 is not None: 547 | stats["M2"] = M2.tolist() # Convert M2 tensor to list 548 | 549 | # Save stats to S3 directly using the client 550 | self.s3_client.put_object( 551 | Body=json.dumps(stats), 552 | Bucket=self.bucket_name, 553 | Key=_statistics_path(self.prefix_path), 554 | ) 555 | -------------------------------------------------------------------------------- /pipeline/vault/rcache.py: -------------------------------------------------------------------------------- 1 | """Copyright (2025) Tilde Research Inc. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | """ 15 | 16 | import asyncio 17 | import io 18 | import json 19 | import os 20 | import random 21 | import signal 22 | import sys 23 | import time 24 | import warnings 25 | from multiprocessing import Process, Queue, Value 26 | 27 | import aiohttp 28 | import boto3 29 | import torch 30 | import torch.nn as nn 31 | import multiprocessing as mp 32 | import warnings 33 | import logging 34 | 35 | logger = logging.getLogger(__name__) 36 | 37 | # Constants for file sizes 38 | KB = 1024 39 | MB = KB * KB 40 | 41 | # Cache directory constants 42 | OUTER_CACHE_DIR = "cache" 43 | INNER_CACHE_DIR = "cache" 44 | BUCKET_NAME = os.environ.get("S3_BUCKET_NAME", "main") 45 | 46 | 47 | def _metadata_path(run_name): 48 | """Generate the metadata file path for a given run name.""" 49 | return f"{run_name}/metadata.json" 50 | 51 | 52 | def _statistics_path(run_name): 53 | """Generate the statistics file path for a given run name.""" 54 | return f"{run_name}/statistics.json" 55 | 56 | 57 | async def download_chunks(session, url, total_size, chunk_size): 58 | """Download file chunks asynchronously with retries.""" 59 | tries_left = 5 60 | while tries_left > 0: 61 | chunks = [ 62 | (i, min(i + chunk_size - 1, total_size - 1)) for i in range(0, total_size, chunk_size) 63 | ] 64 | tasks = [ 65 | asyncio.create_task(request_chunk(session, url, start, end)) for start, end in chunks 66 | ] 67 | responses = await asyncio.gather(*tasks, return_exceptions=True) 68 | 69 | results = [] 70 | retry = False 71 | for response in responses: 72 | if isinstance(response, Exception): 73 | logger.error(f"Error occurred: {response}") 74 | logger.error(f"Session: {session}, URL: {url}, Tries left: {tries_left}") 75 | tries_left -= 1 76 | retry = True 77 | break 78 | else: 79 | results.append(response) 80 | 81 | if not retry: 82 | return results 83 | 84 | return None 85 | 86 | 87 | async def request_chunk(session, url, start, end): 88 | """Request a specific chunk of a file.""" 89 | headers = {"Range": f"bytes={start}-{end}"} 90 | try: 91 | async with session.get(url, headers=headers) as response: 92 | response.raise_for_status() 93 | return start, await response.read() 94 | except Exception as e: 95 | return e 96 | 97 | 98 | def download_loop(*args): 99 | """Run the asynchronous download loop.""" 100 | asyncio.run(_async_download(*args)) 101 | 102 | 103 | def compile(byte_buffers, shuffle=True, seed=None, return_ids=False): 104 | """Compile downloaded chunks into a tensor.""" 105 | combined_bytes = b"".join(chunk for _, chunk in sorted(byte_buffers, key=lambda x: x[0])) 106 | 107 | with warnings.catch_warnings(): 108 | warnings.simplefilter("ignore") 109 | # n = np.frombuffer(combined_bytes, dtype=np.float16) 110 | # t = torch.from_numpy(n) 111 | # t = torch.frombuffer(combined_bytes, dtype=dtype) # torch.float32 112 | buffer = io.BytesIO(combined_bytes) 113 | t = torch.load(buffer) 114 | if isinstance(t, dict) and "states" in t and not return_ids: # backward compatibility 115 | t = t["states"] # ignore input_ids 116 | buffer.close() 117 | 118 | if shuffle and not return_ids: 119 | t = shuffle_megabatch_tokens(t, seed) 120 | 121 | return t 122 | 123 | 124 | def shuffle_megabatch_tokens(t, seed=None): 125 | """ 126 | Shuffle within a megabatch (across batches and sequences), using each token as the unit of shuffling. 
127 | 128 | Args: 129 | t (torch.Tensor): Input tensor of shape (batch_size * batches_per_file, sequence_length, d_in + 1) 130 | seed (int): Seed for the random number generator 131 | 132 | Returns: 133 | torch.Tensor: Shuffled tensor of the same shape as input 134 | """ 135 | original_shape = t.shape # (batch_size * batches_per_file, sequence_length, d_in + 1) 136 | 137 | total_tokens = original_shape[0] * original_shape[1] # reshape to (total_tokens, d_in + 1) 138 | t_reshaped = t.reshape(total_tokens, -1) 139 | 140 | rng = torch.Generator() 141 | if seed is not None: 142 | rng.manual_seed(seed) 143 | 144 | shuffled_indices = torch.randperm(total_tokens, generator=rng) 145 | t_shuffled = t_reshaped[shuffled_indices] 146 | 147 | t = t_shuffled.reshape(original_shape) # revert 148 | 149 | return t 150 | 151 | 152 | def write_tensor(t, buffer, writeable_tensors, readable_tensors, ongoing_downloads): 153 | """Write a tensor to the shared buffer.""" 154 | idx = writeable_tensors.get(block=True) 155 | if isinstance(buffer[0], SharedBuffer): 156 | buffer[idx].states.copy_(t["states"]) 157 | buffer[idx].input_ids.copy_(t["input_ids"]) 158 | else: 159 | buffer[idx] = t 160 | 161 | readable_tensors.put(idx, block=True) 162 | with ongoing_downloads.get_lock(): 163 | ongoing_downloads.value -= 1 164 | 165 | 166 | async def _async_download( 167 | buffer, 168 | file_index, 169 | s3_paths, 170 | stop, 171 | readable_tensors, 172 | writeable_tensors, 173 | ongoing_downloads, 174 | concurrency, 175 | bytes_per_file, 176 | chunk_size, 177 | shuffle, 178 | seed, 179 | return_ids, 180 | ): 181 | """Asynchronously download and process files from S3.""" 182 | connector = aiohttp.TCPConnector(limit=concurrency) 183 | async with aiohttp.ClientSession(connector=connector) as session: 184 | while file_index.value < len(s3_paths) and not stop.value: 185 | with ongoing_downloads.get_lock(): 186 | ongoing_downloads.value += 1 187 | with file_index.get_lock(): 188 | url = s3_paths[file_index.value] 189 | file_index.value += 1 190 | bytes_results = await download_chunks(session, url, bytes_per_file, chunk_size) 191 | if bytes_results is not None: 192 | try: 193 | t = compile(bytes_results, shuffle, seed, return_ids) 194 | write_tensor( 195 | t, 196 | buffer, 197 | writeable_tensors, 198 | readable_tensors, 199 | ongoing_downloads, 200 | ) 201 | except Exception as e: 202 | logger.error(f"Exception while downloading: {e}") 203 | logger.error(f"Failed URL: {url}") 204 | stop.value = True # Set stop flag 205 | break # Exit the loop 206 | else: 207 | logger.error(f"Failed to download URL: {url}") 208 | with ongoing_downloads.get_lock(): 209 | ongoing_downloads.value -= 1 210 | 211 | 212 | class S3RCache: 213 | """A cache that reads data from Amazon S3.""" 214 | 215 | @classmethod 216 | def from_credentials(self, aws_access_key_id, aws_secret_access_key, *args, **kwargs): 217 | s3_client = boto3.client( 218 | "s3", 219 | aws_access_key_id=aws_access_key_id, 220 | aws_secret_access_key=aws_secret_access_key, 221 | ) 222 | return S3RCache(s3_client, *args, **kwargs) 223 | 224 | def __init__( 225 | self, 226 | s3_client, 227 | s3_prefix, 228 | bucket_name=BUCKET_NAME, 229 | device="cpu", 230 | concurrency=100, 231 | chunk_size=MB * 16, 232 | buffer_size=2, 233 | shuffle=True, 234 | preserve_file_order=False, 235 | seed=42, 236 | paths=None, 237 | n_workers=1, 238 | return_ids=False, 239 | ) -> None: 240 | """Initialize S3 cache.""" 241 | ensure_spawn_context() 242 | 243 | # Configure S3 client with correct signature version 
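# The fallback client below is only constructed when `s3_client` is None; a
# caller-supplied client is used unchanged. The hard-coded region here is
# illustrative and must match your bucket's region (AWS region codes are
# hyphenated, e.g. "eu-north-1").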
244 | self.s3_client = ( 245 | boto3.client( 246 | "s3", 247 | region_name="eu-north1", # Make sure this matches your bucket region 248 | config=boto3.session.Config(signature_version="s3v4"), 249 | ) 250 | if s3_client is None 251 | else s3_client 252 | ) 253 | 254 | self.s3_prefix = s3_prefix 255 | self.bucket_name = bucket_name 256 | self.device = device 257 | self.concurrency = concurrency 258 | self.chunk_size = chunk_size 259 | self.buffer_size = buffer_size 260 | self.shuffle = shuffle 261 | self.preserve_file_order = preserve_file_order 262 | self.seed = seed 263 | self.return_ids = return_ids 264 | 265 | random.seed(self.seed) 266 | torch.manual_seed(self.seed) # unclear if this has effect 267 | # but we drill down the seed to download loop anyway 268 | 269 | self.paths = paths 270 | self._s3_paths = self._list_s3_files() 271 | if isinstance(self.s3_prefix, list): 272 | target_prefix = self.s3_prefix[0] 273 | else: 274 | target_prefix = self.s3_prefix 275 | response = self.s3_client.get_object(Bucket=bucket_name, Key=_metadata_path(target_prefix)) 276 | content = response["Body"].read() 277 | self.metadata = json.loads(content) 278 | # self.metadata["bytes_per_file"] = 1612711320 279 | self._activation_dtype = eval(self.metadata["dtype"]) 280 | 281 | self._running_processes = [] 282 | self.n_workers = n_workers 283 | 284 | self.readable_tensors = Queue(maxsize=self.buffer_size) 285 | self.writeable_tensors = Queue(maxsize=self.buffer_size) 286 | 287 | for i in range(self.buffer_size): 288 | self.writeable_tensors.put(i) 289 | 290 | if self.return_ids: 291 | self.buffer = [ 292 | SharedBuffer( 293 | self.metadata["shape"], 294 | self.metadata["input_ids_shape"], 295 | self._activation_dtype, 296 | ) 297 | for _ in range(self.buffer_size) 298 | ] 299 | for shared_buffer in self.buffer: 300 | shared_buffer.share_memory() 301 | else: 302 | self.buffer = torch.empty( 303 | (self.buffer_size, *self.metadata["shape"]), 304 | dtype=self._activation_dtype, 305 | ).share_memory_() 306 | 307 | self._stop = Value("b", False) 308 | self._file_index = Value("i", 0) 309 | self._ongoing_downloads = Value("i", 0) 310 | 311 | signal.signal(signal.SIGTERM, self._catch_stop) 312 | signal.signal(signal.SIGINT, self._catch_stop) 313 | 314 | self._initial_file_index = 0 315 | 316 | @property 317 | def current_file_index(self): 318 | return self._file_index.value 319 | 320 | def set_file_index(self, index): 321 | self._initial_file_index = index 322 | 323 | def _catch_stop(self, *args, **kwargs): 324 | logger.info("cleaning up before process is killed") 325 | self._stop_downloading() 326 | sys.exit(0) 327 | 328 | def sync(self): 329 | self._s3_paths = self._list_s3_files() 330 | 331 | def _reset(self): 332 | self._file_index.value = self._initial_file_index 333 | self._ongoing_downloads.value = 0 334 | self._stop.value = False 335 | 336 | while not self.readable_tensors.empty(): 337 | self.readable_tensors.get() 338 | 339 | while not self.writeable_tensors.empty(): 340 | self.writeable_tensors.get() 341 | for i in range(self.buffer_size): 342 | self.writeable_tensors.put(i) 343 | 344 | def _list_s3_files(self): 345 | """List and prepare all data files from one or more S3 prefixes.""" 346 | paths = [] 347 | combined_metadata = None 348 | combined_config = None 349 | 350 | # Handle single prefix case for backward compatibility 351 | prefixes = [self.s3_prefix] if isinstance(self.s3_prefix, str) else self.s3_prefix 352 | 353 | # Process each prefix 354 | for prefix in prefixes: 355 | # Get metadata for this 
prefix 356 | response = self.s3_client.get_object( 357 | Bucket=self.bucket_name, Key=_metadata_path(prefix) 358 | ) 359 | metadata = json.loads(response["Body"].read()) 360 | 361 | # Get config for this prefix 362 | try: 363 | config_response = self.s3_client.get_object( 364 | Bucket=self.bucket_name, Key=f"{'/'.join(prefix.split('/')[:-1])}/cfg.json" 365 | ) 366 | config = json.loads(config_response["Body"].read()) 367 | except Exception as e: 368 | logger.warning(f"Warning: Could not load config for prefix {prefix}: {e}") 369 | config = {} 370 | 371 | # Initialize combined metadata and config from first prefix 372 | if combined_metadata is None: 373 | combined_metadata = metadata.copy() 374 | combined_config = config.copy() 375 | # Initialize accumulation fields 376 | combined_config["total_tokens"] = 0 377 | combined_config["n_total_files"] = 0 378 | combined_config["batches_processed"] = 0 379 | else: 380 | # Verify metadata compatibility 381 | if metadata["shape"][1:] != combined_metadata["shape"][1:]: 382 | raise ValueError( 383 | f"Incompatible shapes between datasets: {metadata['shape']} vs {combined_metadata['shape']}" 384 | ) 385 | if metadata["dtype"] != combined_metadata["dtype"]: 386 | raise ValueError(f"Incompatible dtypes between datasets") 387 | 388 | # Accumulate config fields 389 | combined_config["total_tokens"] += config.get("total_tokens", 0) 390 | combined_config["n_total_files"] += config.get("n_total_files", 0) 391 | combined_config["batches_processed"] += config.get("batches_processed", 0) 392 | 393 | # List files for this prefix 394 | paginator = self.s3_client.get_paginator("list_objects_v2") 395 | page_iterator = paginator.paginate(Bucket=self.bucket_name, Prefix=prefix) 396 | 397 | prefix_paths = [] 398 | for page in page_iterator: 399 | if "Contents" not in page: 400 | continue 401 | 402 | for obj in page["Contents"]: 403 | if ( 404 | obj["Key"] != _metadata_path(prefix) 405 | and obj["Key"] != _statistics_path(prefix) 406 | and not obj["Key"].endswith("cfg.json") 407 | ): 408 | url = self.s3_client.generate_presigned_url( 409 | "get_object", 410 | Params={"Bucket": self.bucket_name, "Key": obj["Key"]}, 411 | ExpiresIn=604700, 412 | ) 413 | prefix_paths.append(url) 414 | 415 | paths.extend(prefix_paths) 416 | 417 | # Store the combined metadata and config 418 | self.metadata = combined_metadata 419 | self.config = combined_config # Store combined config for potential later use 420 | 421 | if self.preserve_file_order: 422 | # chronological upload order 423 | return sorted(paths) 424 | else: 425 | # shuffle the file order 426 | random.shuffle(paths) 427 | return paths 428 | 429 | def __iter__(self): 430 | self._reset() 431 | 432 | if self._running_processes: 433 | raise ValueError("Cannot iterate over cache a second time while it is downloading") 434 | 435 | if len(self._s3_paths) > self._initial_file_index: 436 | while len(self._running_processes) < self.n_workers: 437 | p = Process( 438 | target=download_loop, 439 | args=( 440 | self.buffer, 441 | self._file_index, 442 | self._s3_paths[self._initial_file_index :], # Start from the initial index 443 | self._stop, 444 | self.readable_tensors, 445 | self.writeable_tensors, 446 | self._ongoing_downloads, 447 | self.concurrency, 448 | self.metadata["bytes_per_file"], 449 | self.chunk_size, 450 | self.shuffle, 451 | self.seed, 452 | self.return_ids, 453 | ), 454 | ) 455 | p.start() 456 | self._running_processes.append(p) 457 | time.sleep(0.75) 458 | 459 | return self 460 | 461 | def _next_tensor(self): 462 | try: 
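# Block until a download worker publishes a filled buffer slot, copy its
# contents out of shared memory, then return the slot index to
# `writeable_tensors` so the slot can be refilled.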
463 | idx = self.readable_tensors.get(block=True) 464 | if self.return_ids: 465 | t = { 466 | "states": self.buffer[idx].states.clone().detach(), 467 | "input_ids": self.buffer[idx].input_ids.clone().detach(), 468 | } 469 | else: 470 | t = self.buffer[idx].clone().detach() 471 | 472 | self.writeable_tensors.put(idx, block=True) 473 | return t 474 | except Exception as e: 475 | logger.error(f"exception while iterating: {e}") 476 | self._stop_downloading() 477 | raise StopIteration 478 | 479 | def __next__(self): 480 | while ( 481 | self._file_index.value < len(self._s3_paths) 482 | or not self.readable_tensors.empty() 483 | or self._ongoing_downloads.value > 0 484 | ): 485 | return self._next_tensor() 486 | 487 | if self._running_processes: 488 | self._stop_downloading() 489 | raise StopIteration 490 | 491 | def finalize(self): 492 | self._stop_downloading() 493 | 494 | def _stop_downloading(self): 495 | logger.info("stopping workers...") 496 | self._file_index.value = len(self._s3_paths) 497 | self._stop.value = True 498 | 499 | while not all([not p.is_alive() for p in self._running_processes]): 500 | if not self.readable_tensors.empty(): 501 | self.readable_tensors.get() 502 | 503 | if not self.writeable_tensors.full(): 504 | self.writeable_tensors.put(0) 505 | 506 | time.sleep(0.25) 507 | 508 | for p in self._running_processes: 509 | p.join() # still join to make sure all resources are cleaned up 510 | 511 | self._ongoing_downloads.value = 0 512 | self._running_processes = [] 513 | 514 | 515 | """ 516 | tl;dr of why we need this: 517 | shared memory is handled differently for nested structures -- see buffer intiialization 518 | we can initialize a dict with two tensors with shared memory, and these tensors themselves are shared but NOT the dict 519 | hence writing to buffer[idx] in write_tensor will not actually write to self.buffer[idx], which _next_tensor uses 520 | (possibly a better fix, but for now this works) 521 | """ 522 | 523 | 524 | class SharedBuffer(nn.Module): 525 | def __init__(self, shape, input_ids_shape, dtype): 526 | super().__init__() 527 | self.states = nn.Parameter(torch.ones(shape, dtype=dtype), requires_grad=False) 528 | self.input_ids = nn.Parameter( 529 | torch.ones(input_ids_shape, dtype=torch.int64), requires_grad=False 530 | ) 531 | 532 | def forward(self): 533 | return {"states": self.states, "input_ids": self.input_ids} 534 | 535 | 536 | ### mini-helper for multiprocessing 537 | def ensure_spawn_context(): 538 | """ 539 | Ensures multiprocessing uses 'spawn' context if not already set. 540 | Returns silently if already set to 'spawn'. 541 | Issues warning if unable to set to 'spawn'. 542 | """ 543 | if mp.get_start_method(allow_none=True) != "spawn": 544 | try: 545 | mp.set_start_method("spawn", force=True) 546 | except RuntimeError: 547 | warnings.warn("Multiprocessing start method is not 'spawn'. 
This may cause issues.") 548 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=42", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "activault" 7 | version = "0.1.0" 8 | description = "A pipeline for generating and processing model activations" 9 | readme = "README.md" 10 | requires-python = ">=3.9, <3.12" 11 | license = {text = "Apache 2.0"} 12 | authors = [ 13 | {name = "Mason Wang", email = "mason@tilderesearch.com"}, 14 | {name = "Dhruv Pai", email = "dhruv@tilderesearch.com"}, 15 | {name = "Ben Keigwin", email = "ben@tilderesearch.com"}, 16 | ] 17 | classifiers = [ 18 | "Programming Language :: Python :: 3", 19 | "License :: OSI Approved :: Apache Software License", 20 | "Operating System :: OS Independent", 21 | ] 22 | dependencies = [ 23 | "accelerate>=1.0.1", 24 | "aiohttp>=3.10.11", 25 | "black>=25.1.0", 26 | "boto3>=1.37.13", 27 | "datasets>=3.1.0", 28 | "flash-attn>=2.7.0.post2", 29 | "huggingface-hub>=0.29.3", 30 | "matplotlib>=3.9.4", 31 | "ray>=2.43.0", 32 | "requests>=2.32.3", 33 | "seaborn>=0.13.2", 34 | "setuptools>=75.3.2", 35 | "torch>=2.4.1", 36 | "transformers==4.49.0", 37 | "wheel>=0.45.1", 38 | "xformers>=0.0.28.post1", 39 | ] 40 | 41 | [project.optional-dependencies] 42 | dev = [ 43 | "black", 44 | "isort", 45 | "mypy", 46 | "pytest", 47 | "pytest-cov", 48 | ] 49 | 50 | ray = [ 51 | "ray>=2.9.0", 52 | ] 53 | 54 | [tool.setuptools] 55 | packages = ["pipeline", "configs", "s3",] 56 | 57 | [tool.black] 58 | line-length = 100 59 | target-version = ["py38"] 60 | include = '\.pyi?$' 61 | 62 | [tool.isort] 63 | profile = "black" 64 | line_length = 100 65 | 66 | [tool.mypy] 67 | python_version = "3.8" 68 | warn_return_any = true 69 | warn_unused_configs = true 70 | disallow_untyped_defs = true 71 | disallow_incomplete_defs = true 72 | 73 | [tool.pytest.ini_options] 74 | testpaths = ["tests"] 75 | python_files = "test_*.py" 76 | 77 | [tool.uv.sources] 78 | transformers = { git = "https://github.com/huggingface/transformers", rev = "v4.49.0-Gemma-3" } 79 | 80 | [project.scripts] 81 | activault = "pipeline.utils.cli:main" 82 | -------------------------------------------------------------------------------- /retrieve.py: -------------------------------------------------------------------------------- 1 | """ 2 | RCache efficiently streams transformer activation data from S3 for training interpreter models. 3 | It maintains a small buffer of megabatch files (each containing multiple batches concatenated together during uploads) 4 | and asynchronously downloads the next files while the current ones are being processed. 5 | 6 | After a brief initial load (<30s), training should never bottlenecked by downloads since they happen asynchronously in the background. 7 | 8 | Copyright (2025) Tilde Research Inc. 9 | 10 | Licensed under the Apache License, Version 2.0 (the "License"); 11 | you may not use this file except in compliance with the License. 12 | You may obtain a copy of the License at 13 | 14 | http://www.apache.org/licenses/LICENSE-2.0 15 | 16 | Unless required by applicable law or agreed to in writing, software 17 | distributed under the License is distributed on an "AS IS" BASIS, 18 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 19 | See the License for the specific language governing permissions and 20 | limitations under the License. 
21 | """ 22 | 23 | from transformers import AutoTokenizer 24 | from pipeline.vault import S3RCache 25 | from s3.utils import create_s3_client 26 | import os 27 | import json 28 | import logging 29 | 30 | # Constants 31 | RUN_NAME = "llama3.3_70b" # Base run name without hook 32 | BUCKET_NAME = os.environ.get("S3_BUCKET_NAME", "main") 33 | 34 | logging.basicConfig(level=logging.INFO) 35 | logger = logging.getLogger(__name__) 36 | 37 | 38 | def get_first_hook_prefix(run_name, bucket_name): 39 | """Get the first available hook prefix for the run.""" 40 | s3_client = create_s3_client() 41 | response = s3_client.list_objects_v2(Bucket=bucket_name, Prefix=f"{run_name}/", Delimiter="/") 42 | if "CommonPrefixes" in response: 43 | # Get first hook directory 44 | first_hook = response["CommonPrefixes"][0]["Prefix"].rstrip("/") 45 | return first_hook 46 | return None 47 | 48 | 49 | def get_model_name_from_config(run_name, bucket_name): 50 | """Get model name from the run's config file.""" 51 | s3_client = create_s3_client() 52 | cfg_path = f"/tmp/{run_name}_cfg.json" 53 | s3_client.download_file(bucket_name, f"{run_name}/cfg.json", cfg_path) 54 | with open(cfg_path, "r") as f: 55 | model_name = json.load(f)["transformer_config"]["model_name"] 56 | os.remove(cfg_path) 57 | return model_name 58 | 59 | 60 | def inspect_batch(states, input_ids, tokenizer): 61 | """Helper function to inspect a batch of activations and tokens.""" 62 | logger.info(f"States shape: {states.shape}") 63 | logger.info(f"Input IDs shape: {input_ids.shape}") 64 | logger.info(f"\nStats: mean={states.mean().item():.4f}, std={states.std().item():.4f}") 65 | logger.info(f"Sample text: {tokenizer.decode(input_ids[0])[:100]}...") 66 | 67 | 68 | def main(): 69 | logger.info("Demo: Reading transformer activations from S3 cache") 70 | 71 | # Get first available hook prefix 72 | prefix = get_first_hook_prefix(RUN_NAME, BUCKET_NAME) 73 | if not prefix: 74 | logger.error(f"No hooks found for run {RUN_NAME}") 75 | return 76 | logger.info(f"Using hook prefix: {prefix}") 77 | 78 | # Initialize tokenizer 79 | model_name = get_model_name_from_config(RUN_NAME, BUCKET_NAME) 80 | tokenizer = AutoTokenizer.from_pretrained(model_name) 81 | 82 | # Initialize cache reader 83 | cache = S3RCache.from_credentials( 84 | aws_access_key_id=os.environ.get("AWS_ACCESS_KEY_ID"), 85 | aws_secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY"), 86 | s3_prefix=prefix, 87 | bucket_name=BUCKET_NAME, 88 | device="cpu", 89 | buffer_size=2, 90 | return_ids=True, 91 | ) 92 | 93 | logger.info("\nReading first two megabatch files from S3...") 94 | logger.info("Each file contains n_batches_per_file batches concatenated together") 95 | logger.info("Format: [n_batches_per_file, sequence_length, hidden_dim]\n") 96 | 97 | # Inspect a few batches 98 | for batch_idx, batch in enumerate(cache): 99 | if batch_idx >= 2: 100 | break 101 | inspect_batch(batch["states"], batch["input_ids"], tokenizer) 102 | 103 | cache.finalize() 104 | 105 | 106 | if __name__ == "__main__": 107 | main() 108 | -------------------------------------------------------------------------------- /s3/shell/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tilde-research/activault/db6d1e4e36c2d3eb4fdce79e72be94f387eccee1/s3/shell/__init__.py -------------------------------------------------------------------------------- /s3/shell/history.py: -------------------------------------------------------------------------------- 1 | import 
readline 2 | import os 3 | from typing import Optional 4 | 5 | 6 | class CommandHistory: 7 | def __init__(self, histfile: Optional[str] = None): 8 | """Initialize command history with optional history file.""" 9 | self.histfile = histfile 10 | 11 | # Configure readline 12 | if histfile: 13 | # Create history file directory if it doesn't exist 14 | os.makedirs(os.path.dirname(histfile), exist_ok=True) 15 | 16 | # Set history file 17 | readline.set_history_length(1000) 18 | try: 19 | # Only try to read history file if it exists 20 | if os.path.exists(histfile): 21 | readline.read_history_file(histfile) 22 | except OSError: 23 | # Handle any OS errors when reading history 24 | print(f"Warning: Could not read history file {histfile}") 25 | pass 26 | 27 | # Enable tab completion and better key bindings 28 | readline.parse_and_bind("tab: complete") 29 | 30 | # Set better key bindings for history navigation 31 | readline.parse_and_bind('"\e[A": previous-history') # Up arrow 32 | readline.parse_and_bind('"\e[B": next-history') # Down arrow 33 | 34 | def add_command(self, cmd: str): 35 | """Add a command to history.""" 36 | if cmd.strip(): # Only add non-empty commands 37 | readline.add_history(cmd) 38 | 39 | def save(self): 40 | """Save history to file if histfile was specified.""" 41 | if self.histfile: 42 | try: 43 | readline.write_history_file(self.histfile) 44 | except Exception as e: 45 | print(f"Error saving history: {str(e)}") 46 | -------------------------------------------------------------------------------- /s3/shell/s3_operations.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Tuple, Optional, Any 2 | from dataclasses import dataclass 3 | from concurrent.futures import ThreadPoolExecutor 4 | import threading 5 | import os 6 | from s3.utils import create_s3_client 7 | 8 | BUCKET_NAME = os.environ.get("S3_BUCKET_NAME", "main") 9 | 10 | 11 | @dataclass 12 | class Progress: 13 | current: int = 0 14 | total: int = 0 15 | done: bool = False 16 | 17 | 18 | class S3Operations: 19 | def __init__(self, bucket: str = BUCKET_NAME): 20 | self.s3_client = create_s3_client() 21 | self.bucket = bucket 22 | 23 | def format_size(self, size_bytes: int) -> str: 24 | """Convert bytes to human readable format.""" 25 | for unit in ["B", "KB", "MB", "GB"]: 26 | if size_bytes < 1024: 27 | return f"{size_bytes:.2f} {unit}" 28 | size_bytes /= 1024 29 | return f"{size_bytes:.2f} TB" 30 | 31 | def list_objects(self, prefix: str) -> Tuple[List[Dict], List[str]]: 32 | """List objects and prefixes at a path.""" 33 | paginator = self.s3_client.get_paginator("list_objects_v2") 34 | files = [] 35 | folders = set() 36 | 37 | try: 38 | for page in paginator.paginate(Bucket=self.bucket, Prefix=prefix, Delimiter="/"): 39 | # Get folders 40 | if "CommonPrefixes" in page: 41 | for p in page["CommonPrefixes"]: 42 | folder = p["Prefix"].rstrip("/").split("/")[-1] 43 | folders.add(folder) 44 | 45 | # Get files 46 | if "Contents" in page: 47 | for obj in page["Contents"]: 48 | key = obj["Key"] 49 | if key == prefix: # Skip the prefix itself 50 | continue 51 | name = key[len(prefix) :].split("/")[0] 52 | if "/" not in name: # Only direct files 53 | files.append({"name": name, "size": obj["Size"], "key": key}) 54 | except Exception as e: 55 | print(f"\nError listing objects: {str(e)}") 56 | 57 | return files, sorted(list(folders)) 58 | 59 | def list_all_objects(self, prefix: str) -> List[Dict[str, Any]]: 60 | """List all objects under a prefix without delimiter 
(recursive).""" 61 | objects = [] 62 | paginator = self.s3_client.get_paginator("list_objects_v2") 63 | for page in paginator.paginate(Bucket=self.bucket, Prefix=prefix): 64 | if "Contents" in page: 65 | objects.extend(page["Contents"]) 66 | return objects 67 | 68 | def delete_batch(self, batch: List[Dict[str, str]], progress: Optional[Progress] = None) -> int: 69 | """Delete a batch of objects.""" 70 | if not batch: 71 | return 0 72 | 73 | try: 74 | self.s3_client.delete_objects( 75 | Bucket=self.bucket, Delete={"Objects": batch, "Quiet": True} 76 | ) 77 | if progress: 78 | progress.current += len(batch) 79 | return len(batch) 80 | except Exception as e: 81 | print(f"\nError deleting batch: {str(e)}") 82 | return 0 83 | 84 | def count_objects(self, prefix: str) -> int: 85 | """Count objects with a prefix.""" 86 | count = 0 87 | paginator = self.s3_client.get_paginator("list_objects_v2") 88 | for page in paginator.paginate(Bucket=self.bucket, Prefix=prefix): 89 | if "Contents" in page: 90 | count += len(page["Contents"]) 91 | return count 92 | 93 | def delete_objects(self, prefix: str, progress_callback=None) -> Tuple[int, int]: 94 | """Delete objects with a prefix recursively.""" 95 | # List all objects first 96 | all_objects = self.list_all_objects(prefix) 97 | total_objects = len(all_objects) 98 | 99 | if total_objects == 0: 100 | return 0, 0 101 | 102 | # Setup progress 103 | progress = Progress(total=total_objects) if progress_callback else None 104 | if progress_callback: 105 | progress_thread = threading.Thread(target=progress_callback, args=(progress,)) 106 | progress_thread.start() 107 | 108 | try: 109 | # Delete in batches 110 | batch_size = 1000 111 | deleted_count = 0 112 | current_batch = [] 113 | 114 | with ThreadPoolExecutor(max_workers=10) as executor: 115 | futures = [] 116 | 117 | # Create batches of objects to delete 118 | for obj in all_objects: 119 | current_batch.append({"Key": obj["Key"]}) 120 | 121 | if len(current_batch) >= batch_size: 122 | futures.append( 123 | executor.submit(self.delete_batch, current_batch.copy(), progress) 124 | ) 125 | current_batch = [] 126 | 127 | # Handle remaining objects 128 | if current_batch: 129 | futures.append(executor.submit(self.delete_batch, current_batch, progress)) 130 | 131 | # Wait for all deletions to complete 132 | for future in futures: 133 | deleted_count += future.result() 134 | 135 | return deleted_count, total_objects 136 | 137 | finally: 138 | if progress: 139 | progress.done = True 140 | progress_thread.join() 141 | 142 | def read_file(self, key: str) -> str: 143 | """Read a file's contents.""" 144 | response = self.s3_client.get_object(Bucket=self.bucket, Key=key) 145 | return response["Body"].read().decode("utf-8") 146 | 147 | def download_file(self, key: str, local_path: str): 148 | """Download a file from S3 to a local path.""" 149 | try: 150 | self.s3_client.download_file(self.bucket, key, local_path) 151 | except Exception as e: 152 | print(f"Error downloading file {key}: {str(e)}") 153 | -------------------------------------------------------------------------------- /s3/shell/sanity.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import json 3 | import os 4 | from typing import Dict, Tuple, Any, List 5 | from transformers import AutoTokenizer 6 | 7 | 8 | def load_pt_file(file_path: str) -> Dict[str, torch.Tensor]: 9 | """Load a .pt file and return its contents.""" 10 | return torch.load(file_path) 11 | 12 | 13 | def check_tensor_validity(tensor: 
torch.Tensor) -> Tuple[bool, bool, float, float]: 14 | """Check tensor for NaNs and Infs.""" 15 | has_nan = torch.isnan(tensor).any().item() 16 | has_inf = torch.isinf(tensor).any().item() 17 | min_val = float(tensor.min()) 18 | max_val = float(tensor.max()) 19 | return has_nan, has_inf, min_val, max_val 20 | 21 | 22 | def format_tensor_preview(tensor: torch.Tensor, batch_idx: int) -> str: 23 | """Format a preview of the states tensor for a specific batch.""" 24 | # Get the batch 25 | batch = tensor[batch_idx] 26 | # Convert to float to ensure consistent formatting 27 | batch = batch.float() 28 | # Get a small preview of actual values (first few elements) 29 | preview_size = 5 30 | preview_values = batch.flatten()[:preview_size].tolist() 31 | preview_str = ", ".join(f"{x:.3f}" for x in preview_values) 32 | 33 | return f"First {preview_size} values: [{preview_str}, ...]" 34 | 35 | 36 | def decode_input_ids(input_ids: torch.Tensor, model_name: str) -> List[str]: 37 | """Decode input_ids tensor into text using the correct tokenizer.""" 38 | try: 39 | tokenizer = AutoTokenizer.from_pretrained(model_name) 40 | # Decode each batch separately 41 | texts = [] 42 | for batch_idx in range(input_ids.shape[0]): 43 | batch_ids = input_ids[batch_idx] 44 | text = tokenizer.decode(batch_ids) 45 | texts.append(text) 46 | return texts 47 | except Exception as e: 48 | return [f"Error decoding text: {str(e)}"] * input_ids.shape[0] 49 | 50 | 51 | def inspect_tensor_file( 52 | bucket: str, file_path: str, s3_path: str 53 | ) -> Tuple[Dict[str, Any], str, List[str], Dict[str, Any]]: 54 | """ 55 | Inspect a .pt file and return shapes, model info, and decoded texts. 56 | Args: 57 | file_path: Local path to the .pt file 58 | s3_path: Full S3 path to the file (for finding cfg.json) 59 | Returns: 60 | (shapes_dict, tokenizer_name, decoded_texts, tensor_info) 61 | """ 62 | print(f"Inspecting file: {file_path}") 63 | # Load the tensor file 64 | data = load_pt_file(file_path) 65 | 66 | # Get shapes and check tensors 67 | shapes = {} 68 | tensor_info = {} 69 | if "states" in data: 70 | states = data["states"] 71 | shapes["states"] = list(states.shape) 72 | has_nan, has_inf, min_val, max_val = check_tensor_validity(states) 73 | tensor_info["states"] = { 74 | "has_nan": has_nan, 75 | "has_inf": has_inf, 76 | "min_val": min_val, 77 | "max_val": max_val, 78 | "tensor": states, # Store for later use in previews 79 | } 80 | if "input_ids" in data: 81 | shapes["input_ids"] = list(data["input_ids"].shape) 82 | 83 | # Get model info from cfg.json in activault root 84 | try: 85 | # Get the activault job name (first component of the path) 86 | activault_job = s3_path.split("/")[0] 87 | cfg_path = f"/tmp/{activault_job}_cfg.json" 88 | 89 | # Download and read cfg.json 90 | from s3.utils import create_s3_client 91 | 92 | s3_client = create_s3_client() 93 | s3_client.download_file(bucket, f"{activault_job}/cfg.json", cfg_path) 94 | 95 | with open(cfg_path, "r") as f: 96 | cfg = json.load(f) 97 | model_name = cfg["transformer_config"]["model_name"] 98 | 99 | # Clean up cfg file 100 | import os 101 | 102 | if os.path.exists(cfg_path): 103 | os.remove(cfg_path) 104 | except: 105 | model_name = "Unknown" 106 | 107 | # Decode input_ids if available 108 | decoded_texts = [] 109 | if "input_ids" in data and model_name != "Unknown": 110 | decoded_texts = decode_input_ids(data["input_ids"], model_name) 111 | 112 | return shapes, model_name, decoded_texts, tensor_info 113 | 
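The helpers above can also be driven outside the interactive shell. A minimal sketch follows; the bucket name, S3 key, and local path are assumptions for illustration, the `.pt` file is presumed to be downloaded already, and the usual S3 credential/endpoint environment variables must be set because `inspect_tensor_file` fetches the run's `cfg.json` from the bucket to resolve the tokenizer.

```python
# Minimal sketch with assumed paths: inspect a locally downloaded .pt activations file.
from s3.shell.sanity import inspect_tensor_file, format_tensor_preview

bucket = "main"                                      # assumed bucket name
s3_key = "my_run/model.layers.20/abc--1_0.saved.pt"  # hypothetical S3 key
local_path = "/tmp/abc--1_0.saved.pt"                # already-downloaded copy of that key

shapes, model_name, decoded_texts, tensor_info = inspect_tensor_file(
    bucket, local_path, s3_key
)
print(f"model={model_name}, shapes={shapes}")

if "states" in tensor_info:
    info = tensor_info["states"]
    print(f"has_nan={info['has_nan']}, has_inf={info['has_inf']}, "
          f"range=[{info['min_val']:.3f}, {info['max_val']:.3f}]")
    print(format_tensor_preview(info["tensor"], batch_idx=0))

for i, text in enumerate(decoded_texts[:2]):
    print(f"Batch {i}: {text[:120]}...")
```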
-------------------------------------------------------------------------------- /s3/shell/shell.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional, Any 2 | from dataclasses import dataclass 3 | import sys 4 | import os 5 | from pathlib import Path 6 | from .s3_operations import S3Operations, Progress 7 | from .sanity import inspect_tensor_file, format_tensor_preview 8 | from .history import CommandHistory 9 | 10 | 11 | @dataclass 12 | class IndexedItem: 13 | index: int 14 | name: str 15 | is_folder: bool 16 | stats: Dict[str, Any] 17 | 18 | 19 | class S3Shell: 20 | def __init__(self): 21 | self.s3 = S3Operations() 22 | self.current_path: List[str] = [] 23 | self.indexed_items: List[IndexedItem] = [] 24 | self.last_output: str = "" 25 | 26 | # Setup command history 27 | histfile = os.path.expanduser("~/.s3shell_history") 28 | self.history = CommandHistory(histfile) 29 | 30 | # ANSI color codes 31 | self.BLUE = "\033[34m" 32 | self.GREEN = "\033[32m" 33 | self.RED = "\033[31m" 34 | self.RESET = "\033[0m" 35 | 36 | def check_mark(self, condition: bool) -> str: 37 | """Return emoji check mark or X based on condition.""" 38 | if condition: 39 | return f"{self.GREEN}✅{self.RESET}" 40 | return f"{self.RED}❌{self.RESET}" 41 | 42 | def get_current_prefix(self) -> str: 43 | """Get the current S3 prefix based on path.""" 44 | return "/".join(self.current_path + [""] if self.current_path else []) 45 | 46 | def list_items(self) -> List[IndexedItem]: 47 | """List items in current path.""" 48 | prefix = self.get_current_prefix() 49 | files, folders = self.s3.list_objects(prefix) 50 | 51 | # Create indexed items 52 | items = [] 53 | self._index_map = {} # For O(1) index lookup 54 | self._folder_map = {} # For O(1) folder name lookup 55 | idx = 1 56 | 57 | # Add folders first 58 | for folder in folders: 59 | item = IndexedItem(index=idx, name=folder, is_folder=True, stats={}) 60 | items.append(item) 61 | self._index_map[idx] = item 62 | self._folder_map[folder] = item 63 | idx += 1 64 | 65 | # Add files 66 | for file in files: 67 | item = IndexedItem( 68 | index=idx, name=file["name"], is_folder=False, stats={"size": file["size"]} 69 | ) 70 | items.append(item) 71 | self._index_map[idx] = item 72 | idx += 1 73 | 74 | self.indexed_items = items 75 | return items 76 | 77 | def format_listing(self, items: List[IndexedItem]) -> str: 78 | """Format items listing.""" 79 | output = [] 80 | 81 | # Current path 82 | path = "/" + "/".join(self.current_path) 83 | output.append(f"\nCurrent path: {path}") 84 | output.append("=" * 80) 85 | 86 | # List items or show empty directory message 87 | if not items: 88 | output.append("(empty directory)") 89 | else: 90 | for item in items: 91 | prefix = "📁" if item.is_folder else "📄" 92 | if item.is_folder: 93 | output.append(f"{item.index:3d}. {prefix} {item.name}") 94 | else: 95 | output.append( 96 | f"{item.index:3d}. 
{prefix} {item.name} ({self.s3.format_size(item.stats['size'])})" 97 | ) 98 | 99 | return "\n".join(output) 100 | 101 | def cmd_filecount(self, args: List[str]) -> str: 102 | """Count files in folders at current level.""" 103 | items = self.indexed_items 104 | folders = [item for item in items if item.is_folder] 105 | 106 | if not folders: 107 | return "No folders in current directory" 108 | 109 | output = ["\nFile Count Check:", "-" * 40] 110 | 111 | # Count files in each folder 112 | folder_counts = {} 113 | for folder in folders: 114 | folder_prefix = self.get_current_prefix() + folder.name + "/" 115 | count = self.s3.count_objects(folder_prefix) 116 | folder_counts[folder.name] = count 117 | output.append(f"{folder.name}: {count} files") 118 | 119 | # Check if all counts are the same 120 | counts = list(folder_counts.values()) 121 | same_count = len(set(counts)) == 1 122 | output.append(f"\nAll folders have same file count? {self.check_mark(same_count)}") 123 | 124 | return "\n".join(output) 125 | 126 | def cmd_sizecheck(self, args: List[str]) -> str: 127 | """Check .pt file sizes in current directory and one level down.""" 128 | items = self.indexed_items 129 | output = ["\nPT File Size Check:", "-" * 40] 130 | 131 | # Check current directory first 132 | pt_files = [ 133 | (item.name, item.stats["size"]) 134 | for item in items 135 | if not item.is_folder and item.name.endswith(".pt") 136 | ] 137 | 138 | if not pt_files: 139 | output.append("No .pt files in current directory") 140 | output.append("Checking subdirectories...") 141 | output.append("") 142 | else: 143 | # Check current directory files 144 | sizes = {size for _, size in pt_files} 145 | same_size = len(sizes) == 1 146 | 147 | if same_size: 148 | size = next(iter(sizes)) 149 | output.append( 150 | f"Current directory - All .pt files same size: {self.check_mark(True)} ({self.s3.format_size(size)})" 151 | ) 152 | else: 153 | output.append( 154 | f"Current directory - All .pt files same size: {self.check_mark(False)}" 155 | ) 156 | output.append("Size groups:") 157 | size_groups = {} 158 | for name, size in pt_files: 159 | if size not in size_groups: 160 | size_groups[size] = [] 161 | size_groups[size].append(name) 162 | for size, files in sorted(size_groups.items()): 163 | output.append(f" {self.s3.format_size(size)}: {', '.join(files)}") 164 | output.append("") 165 | 166 | # Check subdirectories 167 | folders = [item for item in items if item.is_folder] 168 | if folders: 169 | for folder in folders: 170 | folder_prefix = self.get_current_prefix() + folder.name + "/" 171 | files, _ = self.s3.list_objects(folder_prefix) 172 | pt_files = [(f["name"], f["size"]) for f in files if f["name"].endswith(".pt")] 173 | 174 | if pt_files: 175 | sizes = {size for _, size in pt_files} 176 | same_size = len(sizes) == 1 177 | 178 | if same_size: 179 | size = next(iter(sizes)) 180 | output.append( 181 | f"{folder.name} - All .pt files same size: {self.check_mark(True)} ({self.s3.format_size(size)})" 182 | ) 183 | else: 184 | output.append( 185 | f"{folder.name} - All .pt files same size: {self.check_mark(False)}" 186 | ) 187 | output.append("Size groups:") 188 | size_groups = {} 189 | for name, size in pt_files: 190 | if size not in size_groups: 191 | size_groups[size] = [] 192 | size_groups[size].append(name) 193 | for size, files in sorted(size_groups.items()): 194 | output.append(f" {self.s3.format_size(size)}: {', '.join(files)}") 195 | else: 196 | output.append(f"{folder.name} - No .pt files") 197 | else: 198 | output.append("No 
subdirectories to check") 199 | 200 | return "\n".join(output) 201 | 202 | def cmd_help(self, args: List[str]) -> str: 203 | """Show help message.""" 204 | return """Available commands: 205 | ls List current directory contents 206 | cd Change directory (cd .. for parent, cd for root) 207 | cat View file contents (> file.txt to save) 208 | rm Remove files/folders 209 | filecount Compare file counts across folders 210 | sizecheck Check .pt file sizes in current and child dirs 211 | inspect View tensor shapes and decode text from .pt file 212 | help Show this help 213 | exit Exit shell""" 214 | 215 | def handle_redirect(self, output: str, args: List[str]) -> str: 216 | """Handle output redirection.""" 217 | if len(args) >= 2 and args[-2] == ">": 218 | filename = args[-1] 219 | with open(filename, "w") as f: 220 | f.write(output) 221 | return f"Output saved to {filename}" 222 | return output 223 | 224 | def show_progress(self, progress: Progress): 225 | """Show a spinning progress indicator.""" 226 | spinner = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"] 227 | idx = 0 228 | while not progress.done: 229 | sys.stdout.write( 230 | f"\r{spinner[idx]} Deleted {progress.current}/{progress.total} objects..." 231 | ) 232 | sys.stdout.flush() 233 | idx = (idx + 1) % len(spinner) 234 | import time 235 | 236 | time.sleep(0.1) 237 | sys.stdout.write(f"\rDeleted {progress.current}/{progress.total} objects. \n") 238 | sys.stdout.flush() 239 | 240 | def get_item_by_index(self, idx: int) -> Optional[IndexedItem]: 241 | """Get item by its index in O(1) time.""" 242 | return self._index_map.get(idx) 243 | 244 | def get_folder_by_name(self, name: str) -> Optional[IndexedItem]: 245 | """Get folder by exact name match in O(1) time.""" 246 | return self._folder_map.get(name) 247 | 248 | def get_prompt(self) -> str: 249 | """Get the current prompt with path.""" 250 | path = f"{self.BLUE}s3://{self.s3.bucket}" 251 | if self.current_path: 252 | path += "/" + "/".join(self.current_path) 253 | return f"{path}{self.RESET}> " 254 | 255 | def cmd_ls(self, args: List[str]) -> str: 256 | """Handle ls command.""" 257 | items = self.list_items() 258 | return self.format_listing(items) 259 | 260 | def cmd_cd(self, args: List[str]) -> str: 261 | """Handle cd command.""" 262 | if not args: 263 | self.current_path = [] 264 | self.list_items() 265 | return "Changed to root directory" 266 | 267 | if args[0] == "..": 268 | if self.current_path: 269 | self.current_path.pop() 270 | self.list_items() 271 | return "Changed to parent directory" 272 | return "Already at root" 273 | 274 | # Try to convert to index first - fast path 275 | try: 276 | idx = int(args[0]) 277 | item = self.get_item_by_index(idx) 278 | if not item: 279 | return f"No item with index {idx}" 280 | if not item.is_folder: 281 | return "Not a folder" 282 | self.current_path.append(item.name) 283 | self.list_items() 284 | return f"Changed to {item.name}" 285 | except ValueError: 286 | # Only try folder name lookup if index conversion fails 287 | item = self.get_folder_by_name(args[0]) 288 | if not item: 289 | return f"No folder named '{args[0]}' found" 290 | self.current_path.append(item.name) 291 | self.list_items() 292 | return f"Changed to {item.name}" 293 | 294 | def cmd_cat(self, args: List[str]) -> str: 295 | """Handle cat command.""" 296 | if not args: 297 | return "Usage: cat [> filename]" 298 | 299 | try: 300 | idx = int(args[0]) 301 | item = self.get_item_by_index(idx) 302 | if not item or item.is_folder: 303 | return "Not a file" 304 | 305 | key = 
self.get_current_prefix() + item.name 306 | content = self.s3.read_file(key) 307 | self.last_output = content 308 | return content 309 | 310 | except ValueError: 311 | return "Invalid index" 312 | except Exception as e: 313 | return f"Error: {str(e)}" 314 | 315 | def cmd_rm(self, args: List[str]) -> str: 316 | """Handle rm command - removes files or folders.""" 317 | if not args: 318 | return "Usage: rm [ ...]" 319 | 320 | results = [] 321 | total_deleted = 0 322 | total_objects = 0 323 | 324 | # Pre-validate all indices first to avoid partial deletions 325 | items_to_delete = [] 326 | for arg in args: 327 | try: 328 | idx = int(arg) 329 | item = self.get_item_by_index(idx) 330 | if not item: 331 | return f"Invalid index: {idx}" 332 | items_to_delete.append(item) 333 | except ValueError: 334 | return f"Invalid index: {arg}" 335 | 336 | # All indices are valid, proceed with deletion 337 | for item in items_to_delete: 338 | try: 339 | prefix = self.get_current_prefix() + item.name 340 | if item.is_folder: 341 | prefix += "/" 342 | 343 | deleted_count, objects_count = self.s3.delete_objects( 344 | prefix, progress_callback=self.show_progress 345 | ) 346 | 347 | total_deleted += deleted_count 348 | total_objects += objects_count 349 | 350 | if objects_count == 0: 351 | results.append(f"No objects found with prefix {prefix}") 352 | else: 353 | what = "folder" if item.is_folder else "file" 354 | results.append(f"Deleted {what} '{item.name}'") 355 | 356 | except Exception as e: 357 | results.append(f"Error deleting {item.name}: {str(e)}") 358 | 359 | # Add summary if multiple items were processed 360 | if len(args) > 1: 361 | results.append(f"\nSummary: Deleted {total_deleted} objects across {len(args)} items") 362 | 363 | # Refresh the directory listing after deletion 364 | self.list_items() 365 | 366 | return "\n".join(results) 367 | 368 | def cmd_inspect(self, args: List[str]) -> str: 369 | """Handle inspect command for .pt files.""" 370 | if not args: 371 | return "Usage: inspect " 372 | 373 | try: 374 | idx = int(args[0]) 375 | item = self.get_item_by_index(idx) 376 | if not item or item.is_folder: 377 | return "Not a file" 378 | 379 | if not item.name.endswith(".pt"): 380 | return "Not a .pt file" 381 | 382 | # Download the file to a temporary location 383 | key = self.get_current_prefix() + item.name 384 | local_path = f"/tmp/{item.name}" 385 | self.s3.download_file(key, local_path) 386 | 387 | try: 388 | # Get the S3 path relative to bucket root 389 | s3_path = "/".join(self.current_path + [item.name]) 390 | 391 | # Inspect the file 392 | shapes, model_name, decoded_texts, tensor_info = inspect_tensor_file( 393 | self.s3.bucket, local_path, s3_path 394 | ) 395 | 396 | # Format initial output 397 | output = ["\nPT File Inspection:", "-" * 40] 398 | output.append(f"Model: {model_name}") 399 | output.append("\nTensor Shapes:") 400 | for key, shape in shapes.items(): 401 | output.append(f" {key}: {shape}") 402 | 403 | # Add tensor validity checks 404 | if "states" in tensor_info: 405 | states_info = tensor_info["states"] 406 | output.append("\nStates Tensor Check:") 407 | output.append(f" No NaNs: {self.check_mark(not states_info['has_nan'])}") 408 | output.append(f" No Infs: {self.check_mark(not states_info['has_inf'])}") 409 | output.append( 410 | f" Value range: [{states_info['min_val']:.3f}, {states_info['max_val']:.3f}]" 411 | ) 412 | 413 | if decoded_texts: 414 | output.append("\nFirst 4 batches (first 250 chars each):") 415 | output.append("-" * 40) 416 | for i, text in 
enumerate(decoded_texts[:4]): 417 | # Clean up text: remove multiple newlines and leading/trailing whitespace 418 | text = "\n".join(line for line in text.splitlines() if line.strip()) 419 | preview = text[:250] + "..." if len(text) > 250 else text 420 | output.append(f"Batch {i}: {preview}") 421 | 422 | # Print initial output 423 | print("\n".join(output)) 424 | 425 | # Interactive mode for viewing full batches 426 | num_batches = len(decoded_texts) 427 | while True: 428 | print( 429 | f"\nEnter batch number (0-{num_batches-1}) to view full text, or 'q' to quit:" 430 | ) 431 | try: 432 | choice = input().strip() 433 | if choice.lower() == "q": 434 | return "" 435 | 436 | batch_idx = int(choice) 437 | if 0 <= batch_idx < num_batches: 438 | text = "\n".join( 439 | line 440 | for line in decoded_texts[batch_idx].splitlines() 441 | if line.strip() 442 | ) 443 | print(f"\nFull text of batch {batch_idx}:\n{'-'*40}\n{text}") 444 | 445 | # Show states tensor preview for this batch 446 | if "states" in tensor_info: 447 | states_preview = format_tensor_preview( 448 | tensor_info["states"]["tensor"], batch_idx 449 | ) 450 | print( 451 | f"\nStates tensor preview for batch {batch_idx}:\n{'-'*40}\n{states_preview}" 452 | ) 453 | else: 454 | print( 455 | f"Invalid batch number. Please enter a number between 0 and {num_batches-1}" 456 | ) 457 | except ValueError: 458 | print( 459 | f"Invalid input. Please enter a number between 0 and {num_batches-1}" 460 | ) 461 | 462 | return "" 463 | finally: 464 | # Clean up 465 | if os.path.exists(local_path): 466 | os.remove(local_path) 467 | 468 | except ValueError: 469 | return "Invalid index" 470 | except Exception as e: 471 | return f"Error: {str(e)}" 472 | 473 | def run(self): 474 | """Run the shell loop.""" 475 | print("S3 Shell - Type 'help' for commands") 476 | print(self.cmd_ls([])) # List items on startup 477 | 478 | while True: 479 | try: 480 | command = input(self.get_prompt()).strip() 481 | if not command: 482 | continue 483 | 484 | # Add command to history 485 | self.history.add_command(command) 486 | 487 | parts = command.split() 488 | cmd, args = parts[0], parts[1:] 489 | 490 | if cmd == "exit": 491 | self.history.save() # Save history before exiting 492 | break 493 | 494 | # Handle commands 495 | output = "" 496 | if cmd == "ls": 497 | output = self.cmd_ls(args) 498 | elif cmd == "cd": 499 | output = self.cmd_cd(args) 500 | elif cmd == "cat": 501 | output = self.cmd_cat(args[0:1]) # Only pass the index 502 | output = self.handle_redirect(output, args) 503 | elif cmd == "rm": 504 | output = self.cmd_rm(args) 505 | elif cmd == "filecount": 506 | output = self.cmd_filecount(args) 507 | elif cmd == "sizecheck": 508 | output = self.cmd_sizecheck(args) 509 | elif cmd == "inspect": 510 | output = self.cmd_inspect(args) 511 | elif cmd == "help": 512 | output = self.cmd_help(args) 513 | else: 514 | output = f"Unknown command: {cmd}" 515 | 516 | if output: 517 | print(output) 518 | 519 | except KeyboardInterrupt: 520 | print("\nUse 'exit' to quit") 521 | except Exception as e: 522 | print(f"Error: {str(e)}") 523 | 524 | # Final save of history 525 | self.history.save() 526 | -------------------------------------------------------------------------------- /s3/utils.py: -------------------------------------------------------------------------------- 1 | import boto3 2 | import os 3 | from typing import Optional 4 | 5 | 6 | def create_s3_client( 7 | access_key_id: Optional[str] = None, 8 | secret_access_key: Optional[str] = None, 9 | endpoint_url: Optional[str] 
= None, 10 | ) -> boto3.client: 11 | """Create an S3 client configured for S3-compatible storage services. 12 | 13 | This function creates a boto3 S3 client with optimized settings for reliable 14 | data transfer. It supports both direct credential passing and environment 15 | variable configuration. 16 | 17 | Args: 18 | access_key_id: S3 access key ID. If None, reads from AWS_ACCESS_KEY_ID env var 19 | secret_access_key: S3 secret key. If None, reads from AWS_SECRET_ACCESS_KEY env var 20 | endpoint_url: S3-compatible storage service endpoint URL 21 | 22 | Returns: 23 | boto3.client: Configured S3 client with optimized settings 24 | 25 | Environment Variables: 26 | - AWS_ACCESS_KEY_ID: S3 access key ID (if not provided as argument) 27 | - AWS_SECRET_ACCESS_KEY: S3 secret key (if not provided as argument) 28 | 29 | Example: 30 | ```python 31 | # Using environment variables 32 | s3_client = create_s3_client() 33 | 34 | # Using explicit credentials 35 | s3_client = create_s3_client( 36 | access_key_id="your_key", 37 | secret_access_key="your_secret", 38 | endpoint_url="your_endpoint_url" 39 | ) 40 | ``` 41 | 42 | Note: 43 | The client is configured with path-style addressing and S3v4 signatures 44 | for maximum compatibility with S3-compatible storage services. 45 | """ 46 | access_key_id = access_key_id or os.environ.get("AWS_ACCESS_KEY_ID") 47 | secret_access_key = secret_access_key or os.environ.get("AWS_SECRET_ACCESS_KEY") 48 | endpoint_url = endpoint_url or os.environ.get("S3_ENDPOINT_URL") 49 | 50 | if not access_key_id or not secret_access_key: 51 | raise ValueError( 52 | "S3 credentials must be provided either through arguments or " 53 | "AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables" 54 | ) 55 | 56 | if not endpoint_url: 57 | raise ValueError( 58 | "S3 endpoint URL must be provided either through arguments or " 59 | "S3_ENDPOINT_URL environment variable" 60 | ) 61 | 62 | session = boto3.session.Session() 63 | return session.client( 64 | service_name="s3", 65 | aws_access_key_id=access_key_id, 66 | aws_secret_access_key=secret_access_key, 67 | endpoint_url=endpoint_url, 68 | use_ssl=True, 69 | verify=True, 70 | config=boto3.session.Config( 71 | s3={"addressing_style": "path"}, 72 | signature_version="s3v4", 73 | # Advanced configuration options (currently commented out): 74 | # retries=dict( 75 | # max_attempts=3, # Number of retry attempts 76 | # mode='adaptive' # Adds exponential backoff 77 | # ), 78 | # max_pool_connections=20, # Limits concurrent connections 79 | # connect_timeout=60, # Connection timeout in seconds 80 | # read_timeout=300, # Read timeout in seconds 81 | # tcp_keepalive=True, # Enable TCP keepalive 82 | ), 83 | ) 84 | -------------------------------------------------------------------------------- /scripts/collect.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=activault # Job name 3 | #SBATCH --output=logs/activault_%j.out # Output file (%j = job ID) 4 | #SBATCH --error=logs/activault_%j.err # Error file (%j = job ID) 5 | #SBATCH --time=200:00:00 # Time limit (200 hours) 6 | #SBATCH --nodes=1 # Number of nodes 7 | #SBATCH --ntasks-per-node=1 # Number of tasks per node 8 | #SBATCH --cpus-per-task=32 # Number of CPU cores per task 9 | #SBATCH --gres=gpu:2 # Number of GPUs (2 in this case) 10 | #SBATCH --mem=250G # Memory per node 11 | #SBATCH --partition=main # Partition/queue name 12 | #SBATCH --mail-type=BEGIN,END,FAIL # Email notifications 13 | 14 | # Activate 
virtual environment 15 | source .venv/bin/activate 16 | 17 | # Create cache directory 18 | mkdir -p /tmp/.cache/huggingface 19 | 20 | # Create logs directory if it doesn't exist 21 | mkdir -p logs 22 | 23 | # Print some debug information 24 | echo "Job ID: $SLURM_JOB_ID" 25 | echo "Running on node: $SLURMD_NODENAME" 26 | echo "Available GPUs: $CUDA_VISIBLE_DEVICES" 27 | nvidia-smi 28 | 29 | # Default to default.yaml and machine index 0 30 | CONFIG_PATH=${1:-configs/default.yaml} 31 | MACHINE_INDEX=${2:-0} 32 | 33 | # Run the data generation script with machine index 34 | srun python stash.py --config $CONFIG_PATH --machine $MACHINE_INDEX 35 | 36 | # Optional: Copy output files to a backup location 37 | # rsync -av outputs/ /scratch/$USER/activault_outputs/ 38 | -------------------------------------------------------------------------------- /scripts/run_ray_jobs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | Run activation collection jobs on a Ray cluster. 4 | 5 | This script provides a parallel alternative to the Slurm-based run_slurm_jobs.sh, 6 | allowing Activault to be deployed on Ray clusters without modifying the core functionality. 7 | 8 | Usage: 9 | python scripts/run_ray_jobs.py configs/llama3.3_70b.yaml 8 0 7 --address ray://ray-head:10001 --resources '{"CPU": 32, "GPU": 2}' --wait 10 | 11 | Arguments: 12 | config_path: Path to configuration file 13 | num_total_runs: Total number of workers to run 14 | start_idx: Index of first worker (default: 0) 15 | end_idx: Index of last worker (default: num_total_runs-1) 16 | 17 | Options: 18 | --address: Ray cluster address (default: auto-detect) 19 | --resources: JSON string with resource requirements per worker 20 | --wait: Wait for jobs to complete before exiting 21 | """ 22 | 23 | import os 24 | import sys 25 | import json 26 | import argparse 27 | import logging 28 | from pathlib import Path 29 | import yaml 30 | from datetime import datetime 31 | 32 | # Add the project root to the path 33 | sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) 34 | 35 | # Import Ray utilities 36 | from pipeline.utils.ray_utils import ( 37 | launch_ray_jobs, 38 | monitor_ray_jobs, 39 | calculate_resources_per_worker, 40 | ) 41 | 42 | # Configure logging 43 | logging.basicConfig( 44 | level=logging.INFO, 45 | format="[%(asctime)s] %(levelname)s: %(message)s", 46 | handlers=[ 47 | logging.StreamHandler(), 48 | ], 49 | ) 50 | logger = logging.getLogger(__name__) 51 | 52 | 53 | def parse_args(): 54 | """Parse command line arguments.""" 55 | parser = argparse.ArgumentParser(description="Run activation collection jobs on a Ray cluster") 56 | 57 | parser.add_argument("config_path", type=str, help="Path to configuration file") 58 | parser.add_argument("num_total_runs", type=int, help="Total number of workers to run") 59 | parser.add_argument( 60 | "start_idx", type=int, default=0, nargs="?", help="Index of first worker (default: 0)" 61 | ) 62 | parser.add_argument( 63 | "end_idx", 64 | type=int, 65 | default=None, 66 | nargs="?", 67 | help="Index of last worker (default: num_total_runs-1)", 68 | ) 69 | 70 | parser.add_argument("--address", type=str, default=None, help="Ray cluster address") 71 | parser.add_argument( 72 | "--resources", 73 | type=str, 74 | default=None, 75 | help='JSON string with resource requirements (e.g. 
\'{"CPU": 4, "GPU": 1}\')', 76 | ) 77 | parser.add_argument( 78 | "--wait", action="store_true", help="Wait for jobs to complete before exiting" 79 | ) 80 | 81 | return parser.parse_args() 82 | 83 | 84 | def load_config(config_path): 85 | """Load configuration from YAML file.""" 86 | with open(config_path, "r") as f: 87 | if config_path.endswith(".yaml") or config_path.endswith(".yml"): 88 | return yaml.safe_load(f) 89 | elif config_path.endswith(".json"): 90 | return json.load(f) 91 | else: 92 | raise ValueError(f"Unsupported config file format: {config_path}") 93 | 94 | 95 | def main(): 96 | """Main entry point for Ray job submission.""" 97 | args = parse_args() 98 | 99 | # Create logs directory if needed 100 | os.makedirs("logs", exist_ok=True) 101 | 102 | # Configure file logging 103 | timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") 104 | log_file = f"logs/ray_jobs_{timestamp}.log" 105 | file_handler = logging.FileHandler(log_file) 106 | file_handler.setFormatter(logging.Formatter("[%(asctime)s] %(levelname)s: %(message)s")) 107 | logger.addHandler(file_handler) 108 | 109 | # Set end_idx to (num_total_runs - 1) if not specified 110 | end_idx = args.end_idx if args.end_idx is not None else args.num_total_runs - 1 111 | 112 | # Load config to get resource requirements 113 | config = load_config(args.config_path) 114 | 115 | # Parse resources JSON if provided 116 | resources_per_worker = None 117 | if args.resources: 118 | try: 119 | resources_per_worker = json.loads(args.resources) 120 | except json.JSONDecodeError as e: 121 | logger.error(f"Failed to parse resources JSON: {e}") 122 | sys.exit(1) 123 | else: 124 | # Get resources from config or use defaults 125 | ray_config = config.get("ray_config", {}) 126 | resources_per_worker = calculate_resources_per_worker(ray_config) 127 | 128 | logger.info( 129 | f"Launching jobs {args.start_idx} to {end_idx} (of total {args.num_total_runs} runs)" 130 | ) 131 | logger.info(f"Using config: {args.config_path}") 132 | logger.info(f"Ray address: {args.address or 'auto-detect'}") 133 | logger.info(f"Resources per worker: {resources_per_worker}") 134 | 135 | # Launch jobs 136 | try: 137 | job_refs = launch_ray_jobs( 138 | args.config_path, 139 | args.num_total_runs, 140 | args.start_idx, 141 | end_idx, 142 | args.address, 143 | resources_per_worker, 144 | ) 145 | 146 | logger.info(f"Launched {len(job_refs)} Ray jobs") 147 | 148 | # Write job information to log file 149 | job_log_file = f"logs/ray_jobs_info_{timestamp}.txt" 150 | with open(job_log_file, "w") as f: 151 | f.write("MACHINE | JOB_ID\n") 152 | f.write("---------------\n") 153 | for i, job_ref in enumerate(job_refs, start=args.start_idx): 154 | f.write(f"{i} | {job_ref}\n") 155 | 156 | logger.info(f"Job mapping saved to: {job_log_file}") 157 | 158 | # Wait for jobs to complete if requested 159 | if args.wait: 160 | logger.info("Waiting for jobs to complete...") 161 | results = monitor_ray_jobs(job_refs) 162 | 163 | # Log completion status 164 | success_count = sum(1 for r in results if r.get("status") == "completed") 165 | failed_count = sum(1 for r in results if r.get("status") == "failed") 166 | 167 | logger.info(f"Job completion: {success_count} succeeded, {failed_count} failed") 168 | 169 | # Write detailed results 170 | results_file = f"logs/ray_results_{timestamp}.json" 171 | with open(results_file, "w") as f: 172 | json.dump(results, f, indent=2) 173 | 174 | logger.info(f"Detailed results saved to: {results_file}") 175 | 176 | except Exception as e: 177 | logger.error(f"Error 
launching Ray jobs: {e}") 178 | sys.exit(1) 179 | 180 | 181 | if __name__ == "__main__": 182 | main() 183 | -------------------------------------------------------------------------------- /scripts/run_slurm_jobs.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Usage: ./run_all.sh 4 | # Example: ./run_all.sh configs/default.yaml 8 2 5 5 | # - Will run jobs for indices 2,3,4,5 out of total 8 runs 6 | 7 | # Check if required arguments are provided 8 | if [ "$#" -lt 4 ]; then 9 | echo "Usage: $0 " 10 | echo "Example: $0 configs/llama3.3_70b.yaml 8 2 5" 11 | exit 1 12 | fi 13 | 14 | CONFIG_PATH=$1 15 | NUM_TOTAL_RUNS=$2 16 | START_IDX=$3 17 | END_IDX=$4 18 | 19 | # Validate indices 20 | if [ $START_IDX -ge $NUM_TOTAL_RUNS ] || [ $END_IDX -ge $NUM_TOTAL_RUNS ] || [ $START_IDX -gt $END_IDX ]; then 21 | echo "Error: Invalid index range. Ensure start_idx and end_idx are within [0, num_total_runs-1]" 22 | exit 1 23 | fi 24 | 25 | TIMESTAMP=$(date +"%Y%m%d_%H%M%S") 26 | LOG_FILE="logs/run_all_${TIMESTAMP}.txt" 27 | 28 | # Create logs directory if it doesn't exist 29 | mkdir -p logs 30 | 31 | echo "IMPORTANT: Please ensure `n_runs` in the config file is set to $NUM_TOTAL_RUNS." 32 | 33 | echo "Launching jobs $START_IDX to $END_IDX (of total $NUM_TOTAL_RUNS runs) using config: $CONFIG_PATH" 34 | echo "Logging job mapping to: $LOG_FILE" 35 | 36 | # Create header for log file 37 | echo "MACHINE | JOB" > "$LOG_FILE" 38 | echo "---------------" >> "$LOG_FILE" 39 | 40 | # Launch jobs for the specified range 41 | for ((i=$START_IDX; i<=$END_IDX; i++)); do 42 | echo "Launching job $((i+1))/$NUM_TOTAL_RUNS with machine_index $i" 43 | # Capture the job ID from sbatch output 44 | JOB_ID=$(sbatch scripts/collect.slurm "$CONFIG_PATH" "$i" | grep -o '[0-9]\+') 45 | echo "$i | $JOB_ID" >> "$LOG_FILE" 46 | sleep 1 # Small delay to prevent potential race conditions 47 | done 48 | 49 | echo -e "\nAll jobs submitted. Use 'squeue' to check their status." 50 | echo "Job mapping saved to: $LOG_FILE" -------------------------------------------------------------------------------- /stash.py: -------------------------------------------------------------------------------- 1 | """Copyright (2025) Tilde Research Inc. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | """ 15 | 16 | from pipeline.data.dataset import load_dataset_by_key 17 | from pipeline.data.dataloader import DataLoader 18 | from pipeline.config import Config 19 | from pipeline.setup import ( 20 | setup_model_and_tokenizer, 21 | setup_uploaders, 22 | calculate_machine_params, 23 | display_job_stats, 24 | maybe_add_mlp_attn_hooks, 25 | ) 26 | from pipeline.generate import generate_activations 27 | import logging 28 | import time 29 | 30 | logging.basicConfig( 31 | level=logging.INFO, format="[%(asctime)s] [Worker-%(process)d] %(levelname)s: %(message)s" 32 | ) 33 | logger = logging.getLogger(__name__) 34 | 35 | 36 | def main(): 37 | # Load config from command line args or default 38 | config = Config.from_args() 39 | 40 | # Calculate machine-specific parameters 41 | batches_per_machine, start_batch_skip = calculate_machine_params( 42 | config.machine_index, config.data_config["n_batches"], config.num_runs 43 | ) 44 | 45 | # Add the global start_batch to the machine-specific start_batch 46 | start_batch_skip += config.data_config["start_batch"] 47 | 48 | # Load dataset 49 | dataset = load_dataset_by_key(config.data_config["data_key"]) 50 | 51 | # Setup model and tokenizer with hooks for truncation 52 | model, tokenizer, hooks = setup_model_and_tokenizer( 53 | config.transformer_config, hooks=config.upload_config.get("hooks") 54 | ) 55 | logger.info(f"Model loaded: {model}") 56 | 57 | added_hooks, added_hook_activations = maybe_add_mlp_attn_hooks(model, hooks) 58 | 59 | config.upload_config["hooks"] = hooks 60 | 61 | # Display job statistics 62 | display_job_stats(model, hooks, config, batches_per_machine) 63 | 64 | # Setup uploaders (one per hook) 65 | uploaders = setup_uploaders( 66 | run_name=config.run_name, 67 | hooks=hooks, 68 | batches_per_upload=config.upload_config["batches_per_upload"], 69 | bucket_name=config.data_config["bucket_name"], 70 | ) 71 | logger.info(f"Uploaders loaded: {uploaders}") 72 | 73 | # Create dataloader 74 | loader = DataLoader( 75 | dataset=dataset, 76 | tokenizer=tokenizer, 77 | max_length=config.data_config["seq_length"], 78 | batch_size=config.data_config["batch_size"], 79 | start_batch_skip=start_batch_skip, 80 | batches_per_machine=batches_per_machine, 81 | dataset_key=config.data_config["data_key"], 82 | skip_cache=config.data_config["skip_cache"], 83 | clean_added_tokens=config.data_config["clean_added_tokens"], 84 | ) 85 | 86 | # Run generation 87 | generate_activations(model, loader, config, uploaders, added_hook_activations) 88 | 89 | # just in case... 90 | logger.info( 91 | "DONE generating activations. Waiting 5 minutes before exiting, in case there are any remaining uploads." 92 | ) 93 | time.sleep(5 * 60) 94 | 95 | 96 | if __name__ == "__main__": 97 | main() 98 | --------------------------------------------------------------------------------