├── .github └── overview.png ├── LICENSE.txt ├── NOTICE ├── README.md ├── kaggle_notebooks ├── arc-prize-2024_kaggle.ipynb ├── arc-prize-2024_updated.ipynb └── unsloth-download-2024-9-post4.ipynb ├── params.json ├── the_architects.pdf └── training_code ├── arc_downloader.py ├── arc_loader.py ├── inference_tools.py ├── model_tools.py ├── run_evaluation_Llama-rearc_with_ttt.py ├── run_evaluation_Llama-rearc_without_ttt.py ├── run_finetuning_Llama-rearc.py ├── run_finetuning_Nemo-full.py └── selection.py /.github/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/da-fr/arc-prize-2024/bfc329123b2e0379395852323cf2862d7ae0d108/.github/overview.png -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 
48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 
203 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Kaggle ARC Prize 2024 Submission Code 2 | Copyright 2024 Daniel Franzen and Jan Disselhoff 3 | 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![image](https://raw.githubusercontent.com/da-fr/arc-prize-2024/master/.github/overview.png) 2 | 3 | 4 | This repo contains the code we used for our Kaggle ARC Prize 2024 submission. For an in-depth overview of our method, please take a look at our [paper](https://da-fr.github.io/arc-prize-2024/the_architects.pdf). 5 | 6 | Under `training_code`, you can find our locally executable code that we used to prepare our models. The main entry points are `run_finetuning_[model].py` for the initial finetuning and `run_evaluation_[model].py` for starting an inference run with test-time-training, simulating a kaggle submission. In either case, we first load the model and data, then augment our dataset. Afterwards, a training run starts. In the latter case, the resulting model is evaluated using our augmentation and scoring strategies. Our training code requires the `unsloth` package and its dependencies to be installed. For evaluation, the `diskcache` package is required for caching the results of inference and score calculation. 7 | 8 | To retrain our winning submission's base model, which scored 53.5 points in the Kaggle ARC Prize 2024 contest, run `run_finetuning_Nemo-full.py`. The datasets used in the training process must be placed in the input folder (see the beginning of the run-file itself for details). The trained model is also available for download on huggingface as [Mistral-NeMo-Minitron-8B-ARChitects-Full-bnb-4bit](https://huggingface.co/da-fr/Mistral-NeMo-Minitron-8B-ARChitects-Full-bnb-4bit). 9 | 10 | Under `kaggle_notebooks`, you can find our notebooks for kaggle. The notebook `arc-prize-2024_kaggle.ipynb` contains the original kaggle submission, scoring `53.5` points on the hidden test set. As the competition did not allow internet access, this notebook uses an offline dataset containing various python wheels (which can be created by executing the notebook `unsloth-download-2024-9-post4.ipynb` and creating a dataset from its output). This notebook, including the offline python wheel dataset and the pretrained model, is also available directly [on kaggle](https://www.kaggle.com/code/dfranzen/arc-prize-2024-solution-by-the-architects). The notebook `arc-prize-2024_updated.ipynb` contains an updated version which can download the required packages directly from the internet using pip, and which can also be run locally in jupyter (this requires the `unsloth` package to be installed). 11 | 12 | We trained all our models on a single `Nvidia H100` GPU. If you run into memory problems, we suggest reducing the batch size and/or the `max_tokens` value. Using a batch size of `2` should allow finetuning `Mistral-NeMo-Minitron-8B-Base` on GPUs with 24 GB memory. 13 | 14 | Here is a rough overview of our files and classes: 15 | 16 | ## Files 17 | 18 | #### `arc_loader.py` 19 | - **Purpose**: Handles all data formatting and loading 20 | - **Capabilities**: 21 | - Class `ArcDataset` which handles all dataset-related tasks, e.g.: 22 | - Building datasets from various sources. 23 | - Modifying, shuffling, and augmenting examples.
24 | - Splitting, sorting, and filtering examples. 25 | - Handling dataset keys, challenges and solutions. 26 | - Preparing the data for tokenization. 27 | - Creating and verifying submissions. 28 | 29 | #### `model_tools.py` 30 | - **Purpose**: Contains code for loading, saving and manipulating models 31 | - **Capabilities**: 32 | - Load and Save Model and LoRA adapters 33 | - Shrink Tokenizer and Embedding Layers 34 | - Data Collator for masking the task inputs and the first output 35 | 36 | #### `inference_tools.py` 37 | - **Purpose**: Contains tools for inference and scoring 38 | - **Capabilities**: 39 | - Inference code, including our custom DFS 40 | - Score calculation 41 | 42 | #### `selection.py` 43 | - **Purpose**: Contains functions used to select the best answer from different candidates 44 | - **Capabilities**: 45 | - Various score aggregation methods 46 | - Sorting candidates by their score for later submission generation 47 | - Class `EvalTool` for performing the above tasks on the fly and printing results 48 | 49 | #### `run_finetuning_[model].py` 50 | - **Purpose**: Run the initial finetuning process. 51 | - **Required packages**: `unsloth` 52 | - **Steps**: 53 | - Load the base model and reduce the embedding size. 54 | - Load and augment the training data. 55 | - Create a LoRA adapter and execute training. 56 | - Save the trained LoRA adapter. 57 | - Merge the LoRA adapter into the base model and save it as the final model. 58 | 59 | #### `run_evaluation_[model].py` 60 | - **Purpose**: Run inference (simulating a kaggle submission). 61 | - **Required packages**: `unsloth` and `diskcache` 62 | - **Steps**: 63 | - Load the finetuned model. 64 | - Optionally perform test-time-training on the evaluation set's examples. 65 | - Save the trained LoRA adapter for later use. 66 | - Run inference on the evaluation set. 67 | - Write a `submission.json` file. 68 | - Reload and verify the submission file. 69 | 70 | ## License 71 | 72 | Our code is available under the Apache 2.0 license. See the [LICENSE.txt](LICENSE.txt) file for more info.
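
## Example: Loading the data

As a small, concrete illustration of the helpers described above, here is a minimal sketch that fetches the public ARC data with `arc_downloader.py` and builds an `ArcDataset` from it. Note that this is not one of the official entry points (those are the `run_*.py` scripts), and the `input/arc-data` path is an assumption; adjust it to your setup.

```python
# Minimal usage sketch. Assumes the training_code/ directory is on the Python
# path; the data directory below is a hypothetical choice, not a fixed default.
from arc_downloader import download_arc_data
from arc_loader import ArcDataset

arc_data_path = 'input/arc-data'  # hypothetical location
download_arc_data(arc_data_path)  # downloads the fchollet/ARC-AGI data if not already present

# build a dataset from the kaggle-format challenges file and attach the matching solutions
dataset = ArcDataset.load_from_json(f'{arc_data_path}/arc-agi_evaluation_challenges.json')
dataset = dataset.load_solutions(f'{arc_data_path}/arc-agi_evaluation_solutions.json')
print(f'{len(dataset.keys)} test outputs to predict')
```

For the full pipeline (augmentation, test-time-training and inference), see `run_finetuning_[model].py` and `run_evaluation_[model].py`.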
73 | 74 | -------------------------------------------------------------------------------- /kaggle_notebooks/unsloth-download-2024-9-post4.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "c43edb3e", 7 | "metadata": { 8 | "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19", 9 | "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5", 10 | "execution": { 11 | "iopub.execute_input": "2024-11-22T14:09:32.106178Z", 12 | "iopub.status.busy": "2024-11-22T14:09:32.105740Z", 13 | "iopub.status.idle": "2024-11-22T14:10:11.710847Z", 14 | "shell.execute_reply": "2024-11-22T14:10:11.709681Z" 15 | }, 16 | "papermill": { 17 | "duration": 39.611595, 18 | "end_time": "2024-11-22T14:10:11.713490", 19 | "exception": false, 20 | "start_time": "2024-11-22T14:09:32.101895", 21 | "status": "completed" 22 | }, 23 | "tags": [] 24 | }, 25 | "outputs": [ 26 | { 27 | "name": "stdout", 28 | "output_type": "stream", 29 | "text": [ 30 | "Requirement already satisfied: wheel in /opt/conda/lib/python3.10/site-packages (0.43.0)\r\n", 31 | "Collecting wheel\r\n", 32 | " Downloading wheel-0.45.0-py3-none-any.whl.metadata (2.3 kB)\r\n", 33 | "Requirement already satisfied: pip in /opt/conda/lib/python3.10/site-packages (24.0)\r\n", 34 | "Collecting pip\r\n", 35 | " Downloading pip-24.3.1-py3-none-any.whl.metadata (3.7 kB)\r\n", 36 | "Downloading wheel-0.45.0-py3-none-any.whl (72 kB)\r\n", 37 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m72.5/72.5 kB\u001b[0m \u001b[31m1.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", 38 | "\u001b[?25hDownloading pip-24.3.1-py3-none-any.whl (1.8 MB)\r\n", 39 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.8/1.8 MB\u001b[0m \u001b[31m24.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", 40 | "\u001b[?25hInstalling collected packages: wheel, pip\r\n", 41 | " Attempting uninstall: wheel\r\n", 42 | " Found existing installation: wheel 0.43.0\r\n", 43 | " Uninstalling wheel-0.43.0:\r\n", 44 | " Successfully uninstalled wheel-0.43.0\r\n", 45 | " Attempting uninstall: pip\r\n", 46 | " Found existing installation: pip 24.0\r\n", 47 | " Uninstalling pip-24.0:\r\n", 48 | " Successfully uninstalled pip-24.0\r\n", 49 | "Successfully installed pip-24.3.1 wheel-0.45.0\r\n" 50 | ] 51 | } 52 | ], 53 | "source": [ 54 | "!pip install wheel pip --upgrade" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 2, 60 | "id": "d7aa94e1", 61 | "metadata": { 62 | "execution": { 63 | "iopub.execute_input": "2024-11-22T14:10:11.721390Z", 64 | "iopub.status.busy": "2024-11-22T14:10:11.720982Z", 65 | "iopub.status.idle": "2024-11-22T14:10:34.509555Z", 66 | "shell.execute_reply": "2024-11-22T14:10:34.508131Z" 67 | }, 68 | "papermill": { 69 | "duration": 22.795625, 70 | "end_time": "2024-11-22T14:10:34.512262", 71 | "exception": false, 72 | "start_time": "2024-11-22T14:10:11.716637", 73 | "status": "completed" 74 | }, 75 | "tags": [] 76 | }, 77 | "outputs": [ 78 | { 79 | "name": "stdout", 80 | "output_type": "stream", 81 | "text": [ 82 | "Found existing installation: accelerate 0.33.0\r\n", 83 | "Uninstalling accelerate-0.33.0:\r\n", 84 | " Successfully uninstalled accelerate-0.33.0\r\n", 85 | "Found existing installation: torch 2.4.0+cpu\r\n", 86 | "Uninstalling torch-2.4.0+cpu:\r\n", 87 | " Successfully uninstalled torch-2.4.0+cpu\r\n" 88 | ] 89 | } 90 | ], 91 | "source": [ 92 | "!pip uninstall --yes 
accelerate torch" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 3, 98 | "id": "987c05ab", 99 | "metadata": { 100 | "execution": { 101 | "iopub.execute_input": "2024-11-22T14:10:34.520261Z", 102 | "iopub.status.busy": "2024-11-22T14:10:34.519824Z", 103 | "iopub.status.idle": "2024-11-22T14:11:46.420637Z", 104 | "shell.execute_reply": "2024-11-22T14:11:46.418946Z" 105 | }, 106 | "papermill": { 107 | "duration": 71.907985, 108 | "end_time": "2024-11-22T14:11:46.423269", 109 | "exception": false, 110 | "start_time": "2024-11-22T14:10:34.515284", 111 | "status": "completed" 112 | }, 113 | "tags": [] 114 | }, 115 | "outputs": [ 116 | { 117 | "name": "stdout", 118 | "output_type": "stream", 119 | "text": [ 120 | "Collecting unsloth==2024.9.post4\r\n", 121 | " Downloading unsloth-2024.9.post4-py3-none-any.whl.metadata (56 kB)\r\n", 122 | "Collecting torch==2.4.1\r\n", 123 | " Downloading torch-2.4.1-cp310-cp310-manylinux1_x86_64.whl.metadata (26 kB)\r\n", 124 | "Collecting xformers>=0.0.27.post2 (from unsloth==2024.9.post4)\r\n", 125 | " Downloading xformers-0.0.28.post3-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (1.0 kB)\r\n", 126 | "Collecting bitsandbytes (from unsloth==2024.9.post4)\r\n", 127 | " Downloading bitsandbytes-0.44.1-py3-none-manylinux_2_24_x86_64.whl.metadata (3.5 kB)\r\n", 128 | "Collecting triton>=3.0.0 (from unsloth==2024.9.post4)\r\n", 129 | " Downloading triton-3.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.3 kB)\r\n", 130 | "Collecting packaging (from unsloth==2024.9.post4)\r\n", 131 | " Downloading packaging-24.2-py3-none-any.whl.metadata (3.2 kB)\r\n", 132 | "Collecting tyro (from unsloth==2024.9.post4)\r\n", 133 | " Downloading tyro-0.9.1-py3-none-any.whl.metadata (9.3 kB)\r\n", 134 | "Collecting transformers<4.45.0 (from unsloth==2024.9.post4)\r\n", 135 | " Downloading transformers-4.44.2-py3-none-any.whl.metadata (43 kB)\r\n", 136 | "Collecting datasets>=2.16.0 (from unsloth==2024.9.post4)\r\n", 137 | " Downloading datasets-3.1.0-py3-none-any.whl.metadata (20 kB)\r\n", 138 | "Collecting sentencepiece>=0.2.0 (from unsloth==2024.9.post4)\r\n", 139 | " Downloading sentencepiece-0.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.7 kB)\r\n", 140 | "Collecting tqdm (from unsloth==2024.9.post4)\r\n", 141 | " Downloading tqdm-4.67.0-py3-none-any.whl.metadata (57 kB)\r\n", 142 | "Collecting psutil (from unsloth==2024.9.post4)\r\n", 143 | " Downloading psutil-6.1.0-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (22 kB)\r\n", 144 | "Collecting wheel>=0.42.0 (from unsloth==2024.9.post4)\r\n", 145 | " Using cached wheel-0.45.0-py3-none-any.whl.metadata (2.3 kB)\r\n", 146 | "Collecting numpy (from unsloth==2024.9.post4)\r\n", 147 | " Downloading numpy-2.1.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (62 kB)\r\n", 148 | "Collecting accelerate>=0.34.1 (from unsloth==2024.9.post4)\r\n", 149 | " Downloading accelerate-1.1.1-py3-none-any.whl.metadata (19 kB)\r\n", 150 | "Collecting trl!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,<=0.11.1,>=0.7.9 (from unsloth==2024.9.post4)\r\n", 151 | " Downloading trl-0.11.1-py3-none-any.whl.metadata (12 kB)\r\n", 152 | "Collecting peft!=0.11.0,>=0.7.1 (from unsloth==2024.9.post4)\r\n", 153 | " Downloading peft-0.13.2-py3-none-any.whl.metadata (13 kB)\r\n", 154 | "Collecting protobuf<4.0.0 (from unsloth==2024.9.post4)\r\n", 155 | " Downloading 
protobuf-3.20.3-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl.metadata (679 bytes)\r\n", 156 | "Collecting huggingface-hub (from unsloth==2024.9.post4)\r\n", 157 | " Downloading huggingface_hub-0.26.2-py3-none-any.whl.metadata (13 kB)\r\n", 158 | "Collecting hf-transfer (from unsloth==2024.9.post4)\r\n", 159 | " Downloading hf_transfer-0.1.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.7 kB)\r\n", 160 | "Collecting filelock (from torch==2.4.1)\r\n", 161 | " Downloading filelock-3.16.1-py3-none-any.whl.metadata (2.9 kB)\r\n", 162 | "Collecting typing-extensions>=4.8.0 (from torch==2.4.1)\r\n", 163 | " Downloading typing_extensions-4.12.2-py3-none-any.whl.metadata (3.0 kB)\r\n", 164 | "Collecting sympy (from torch==2.4.1)\r\n", 165 | " Downloading sympy-1.13.3-py3-none-any.whl.metadata (12 kB)\r\n", 166 | "Collecting networkx (from torch==2.4.1)\r\n", 167 | " Downloading networkx-3.4.2-py3-none-any.whl.metadata (6.3 kB)\r\n", 168 | "Collecting jinja2 (from torch==2.4.1)\r\n", 169 | " Downloading jinja2-3.1.4-py3-none-any.whl.metadata (2.6 kB)\r\n", 170 | "Collecting fsspec (from torch==2.4.1)\r\n", 171 | " Downloading fsspec-2024.10.0-py3-none-any.whl.metadata (11 kB)\r\n", 172 | "Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch==2.4.1)\r\n", 173 | " Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\r\n", 174 | "Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch==2.4.1)\r\n", 175 | " Downloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\r\n", 176 | "Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch==2.4.1)\r\n", 177 | " Downloading nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)\r\n", 178 | "Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch==2.4.1)\r\n", 179 | " Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\r\n", 180 | "Collecting nvidia-cublas-cu12==12.1.3.1 (from torch==2.4.1)\r\n", 181 | " Downloading nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\r\n", 182 | "Collecting nvidia-cufft-cu12==11.0.2.54 (from torch==2.4.1)\r\n", 183 | " Downloading nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\r\n", 184 | "Collecting nvidia-curand-cu12==10.3.2.106 (from torch==2.4.1)\r\n", 185 | " Downloading nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\r\n", 186 | "Collecting nvidia-cusolver-cu12==11.4.5.107 (from torch==2.4.1)\r\n", 187 | " Downloading nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)\r\n", 188 | "Collecting nvidia-cusparse-cu12==12.1.0.106 (from torch==2.4.1)\r\n", 189 | " Downloading nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)\r\n", 190 | "Collecting nvidia-nccl-cu12==2.20.5 (from torch==2.4.1)\r\n", 191 | " Downloading nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl.metadata (1.8 kB)\r\n", 192 | "Collecting nvidia-nvtx-cu12==12.1.105 (from torch==2.4.1)\r\n", 193 | " Downloading nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.7 kB)\r\n", 194 | "Collecting triton>=3.0.0 (from unsloth==2024.9.post4)\r\n", 195 | " Downloading triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.3 kB)\r\n", 196 | "Collecting nvidia-nvjitlink-cu12 (from nvidia-cusolver-cu12==11.4.5.107->torch==2.4.1)\r\n", 197 | " Downloading 
nvidia_nvjitlink_cu12-12.6.85-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl.metadata (1.5 kB)\r\n", 198 | "Collecting pyyaml (from accelerate>=0.34.1->unsloth==2024.9.post4)\r\n", 199 | " Downloading PyYAML-6.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.1 kB)\r\n", 200 | "Collecting safetensors>=0.4.3 (from accelerate>=0.34.1->unsloth==2024.9.post4)\r\n", 201 | " Downloading safetensors-0.4.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.8 kB)\r\n", 202 | "Collecting pyarrow>=15.0.0 (from datasets>=2.16.0->unsloth==2024.9.post4)\r\n", 203 | " Downloading pyarrow-18.0.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (3.3 kB)\r\n", 204 | "Collecting dill<0.3.9,>=0.3.0 (from datasets>=2.16.0->unsloth==2024.9.post4)\r\n", 205 | " Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)\r\n", 206 | "Collecting pandas (from datasets>=2.16.0->unsloth==2024.9.post4)\r\n", 207 | " Downloading pandas-2.2.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (89 kB)\r\n", 208 | "Collecting requests>=2.32.2 (from datasets>=2.16.0->unsloth==2024.9.post4)\r\n", 209 | " Downloading requests-2.32.3-py3-none-any.whl.metadata (4.6 kB)\r\n", 210 | "Collecting xxhash (from datasets>=2.16.0->unsloth==2024.9.post4)\r\n", 211 | " Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)\r\n", 212 | "Collecting multiprocess<0.70.17 (from datasets>=2.16.0->unsloth==2024.9.post4)\r\n", 213 | " Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)\r\n", 214 | "Collecting fsspec (from torch==2.4.1)\r\n", 215 | " Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)\r\n", 216 | "Collecting aiohttp (from datasets>=2.16.0->unsloth==2024.9.post4)\r\n", 217 | " Downloading aiohttp-3.11.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.7 kB)\r\n", 218 | "Collecting regex!=2019.12.17 (from transformers<4.45.0->unsloth==2024.9.post4)\r\n", 219 | " Downloading regex-2024.11.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (40 kB)\r\n", 220 | "Collecting tokenizers<0.20,>=0.19 (from transformers<4.45.0->unsloth==2024.9.post4)\r\n", 221 | " Downloading tokenizers-0.19.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)\r\n", 222 | "Collecting docstring-parser>=0.16 (from tyro->unsloth==2024.9.post4)\r\n", 223 | " Downloading docstring_parser-0.16-py3-none-any.whl.metadata (3.0 kB)\r\n", 224 | "Collecting rich>=11.1.0 (from tyro->unsloth==2024.9.post4)\r\n", 225 | " Downloading rich-13.9.4-py3-none-any.whl.metadata (18 kB)\r\n", 226 | "Collecting shtab>=1.5.6 (from tyro->unsloth==2024.9.post4)\r\n", 227 | " Downloading shtab-1.7.1-py3-none-any.whl.metadata (7.3 kB)\r\n", 228 | "INFO: pip is looking at multiple versions of xformers to determine which version is compatible with other requirements. 
This could take a while.\r\n", 229 | "Collecting xformers>=0.0.27.post2 (from unsloth==2024.9.post4)\r\n", 230 | " Downloading xformers-0.0.28.post2-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (1.0 kB)\r\n", 231 | " Downloading xformers-0.0.28.post1-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (1.0 kB)\r\n", 232 | "Collecting MarkupSafe>=2.0 (from jinja2->torch==2.4.1)\r\n", 233 | " Downloading MarkupSafe-3.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.0 kB)\r\n", 234 | "Collecting mpmath<1.4,>=1.1.0 (from sympy->torch==2.4.1)\r\n", 235 | " Downloading mpmath-1.3.0-py3-none-any.whl.metadata (8.6 kB)\r\n", 236 | "Collecting aiohappyeyeballs>=2.3.0 (from aiohttp->datasets>=2.16.0->unsloth==2024.9.post4)\r\n", 237 | " Downloading aiohappyeyeballs-2.4.3-py3-none-any.whl.metadata (6.1 kB)\r\n", 238 | "Collecting aiosignal>=1.1.2 (from aiohttp->datasets>=2.16.0->unsloth==2024.9.post4)\r\n", 239 | " Downloading aiosignal-1.3.1-py3-none-any.whl.metadata (4.0 kB)\r\n", 240 | "Collecting async-timeout<6.0,>=4.0 (from aiohttp->datasets>=2.16.0->unsloth==2024.9.post4)\r\n", 241 | " Downloading async_timeout-5.0.1-py3-none-any.whl.metadata (5.1 kB)\r\n", 242 | "Collecting attrs>=17.3.0 (from aiohttp->datasets>=2.16.0->unsloth==2024.9.post4)\r\n", 243 | " Downloading attrs-24.2.0-py3-none-any.whl.metadata (11 kB)\r\n", 244 | "Collecting frozenlist>=1.1.1 (from aiohttp->datasets>=2.16.0->unsloth==2024.9.post4)\r\n", 245 | " Downloading frozenlist-1.5.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (13 kB)\r\n", 246 | "Collecting multidict<7.0,>=4.5 (from aiohttp->datasets>=2.16.0->unsloth==2024.9.post4)\r\n", 247 | " Downloading multidict-6.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5.0 kB)\r\n", 248 | "Collecting propcache>=0.2.0 (from aiohttp->datasets>=2.16.0->unsloth==2024.9.post4)\r\n", 249 | " Downloading propcache-0.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.7 kB)\r\n", 250 | "Collecting yarl<2.0,>=1.17.0 (from aiohttp->datasets>=2.16.0->unsloth==2024.9.post4)\r\n", 251 | " Downloading yarl-1.18.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (67 kB)\r\n", 252 | "Collecting charset-normalizer<4,>=2 (from requests>=2.32.2->datasets>=2.16.0->unsloth==2024.9.post4)\r\n", 253 | " Downloading charset_normalizer-3.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (34 kB)\r\n", 254 | "Collecting idna<4,>=2.5 (from requests>=2.32.2->datasets>=2.16.0->unsloth==2024.9.post4)\r\n", 255 | " Downloading idna-3.10-py3-none-any.whl.metadata (10 kB)\r\n", 256 | "Collecting urllib3<3,>=1.21.1 (from requests>=2.32.2->datasets>=2.16.0->unsloth==2024.9.post4)\r\n", 257 | " Downloading urllib3-2.2.3-py3-none-any.whl.metadata (6.5 kB)\r\n", 258 | "Collecting certifi>=2017.4.17 (from requests>=2.32.2->datasets>=2.16.0->unsloth==2024.9.post4)\r\n", 259 | " Downloading certifi-2024.8.30-py3-none-any.whl.metadata (2.2 kB)\r\n", 260 | "Collecting markdown-it-py>=2.2.0 (from rich>=11.1.0->tyro->unsloth==2024.9.post4)\r\n", 261 | " Downloading markdown_it_py-3.0.0-py3-none-any.whl.metadata (6.9 kB)\r\n", 262 | "Collecting pygments<3.0.0,>=2.13.0 (from rich>=11.1.0->tyro->unsloth==2024.9.post4)\r\n", 263 | " Downloading pygments-2.18.0-py3-none-any.whl.metadata (2.5 kB)\r\n", 264 | "Collecting python-dateutil>=2.8.2 (from pandas->datasets>=2.16.0->unsloth==2024.9.post4)\r\n", 265 | " Downloading 
python_dateutil-2.9.0.post0-py2.py3-none-any.whl.metadata (8.4 kB)\r\n", 266 | "Collecting pytz>=2020.1 (from pandas->datasets>=2.16.0->unsloth==2024.9.post4)\r\n", 267 | " Downloading pytz-2024.2-py2.py3-none-any.whl.metadata (22 kB)\r\n", 268 | "Collecting tzdata>=2022.7 (from pandas->datasets>=2.16.0->unsloth==2024.9.post4)\r\n", 269 | " Downloading tzdata-2024.2-py2.py3-none-any.whl.metadata (1.4 kB)\r\n", 270 | "Collecting mdurl~=0.1 (from markdown-it-py>=2.2.0->rich>=11.1.0->tyro->unsloth==2024.9.post4)\r\n", 271 | " Downloading mdurl-0.1.2-py3-none-any.whl.metadata (1.6 kB)\r\n", 272 | "Collecting six>=1.5 (from python-dateutil>=2.8.2->pandas->datasets>=2.16.0->unsloth==2024.9.post4)\r\n", 273 | " Downloading six-1.16.0-py2.py3-none-any.whl.metadata (1.8 kB)\r\n", 274 | "Downloading unsloth-2024.9.post4-py3-none-any.whl (165 kB)\r\n", 275 | "Downloading torch-2.4.1-cp310-cp310-manylinux1_x86_64.whl (797.1 MB)\r\n", 276 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m797.1/797.1 MB\u001b[0m \u001b[31m40.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", 277 | "\u001b[?25hDownloading nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)\r\n", 278 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m410.6/410.6 MB\u001b[0m \u001b[31m38.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", 279 | "\u001b[?25hDownloading nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)\r\n", 280 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m14.1/14.1 MB\u001b[0m \u001b[31m104.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", 281 | "\u001b[?25hDownloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)\r\n", 282 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m23.7/23.7 MB\u001b[0m \u001b[31m119.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", 283 | "\u001b[?25hDownloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)\r\n", 284 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m823.6/823.6 kB\u001b[0m \u001b[31m24.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", 285 | "\u001b[?25hDownloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl (664.8 MB)\r\n", 286 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m664.8/664.8 MB\u001b[0m \u001b[31m41.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", 287 | "\u001b[?25hDownloading nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)\r\n", 288 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m121.6/121.6 MB\u001b[0m \u001b[31m81.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", 289 | "\u001b[?25hDownloading nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl (56.5 MB)\r\n", 290 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m56.5/56.5 MB\u001b[0m \u001b[31m19.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", 291 | "\u001b[?25hDownloading nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl (124.2 MB)\r\n", 292 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m124.2/124.2 MB\u001b[0m \u001b[31m46.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", 293 | "\u001b[?25hDownloading nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl (196.0 MB)\r\n", 294 | "\u001b[2K 
\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m196.0/196.0 MB\u001b[0m \u001b[31m98.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", 295 | "\u001b[?25hDownloading nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl (176.2 MB)\r\n", 296 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m176.2/176.2 MB\u001b[0m \u001b[31m101.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", 297 | "\u001b[?25hDownloading nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (99 kB)\r\n", 298 | "Downloading triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (209.4 MB)\r\n", 299 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m209.4/209.4 MB\u001b[0m \u001b[31m101.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", 300 | "\u001b[?25hDownloading accelerate-1.1.1-py3-none-any.whl (333 kB)\r\n", 301 | "Downloading datasets-3.1.0-py3-none-any.whl (480 kB)\r\n", 302 | "Downloading fsspec-2024.9.0-py3-none-any.whl (179 kB)\r\n", 303 | "Downloading huggingface_hub-0.26.2-py3-none-any.whl (447 kB)\r\n", 304 | "Downloading numpy-2.1.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (16.3 MB)\r\n", 305 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m16.3/16.3 MB\u001b[0m \u001b[31m121.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", 306 | "\u001b[?25hDownloading packaging-24.2-py3-none-any.whl (65 kB)\r\n", 307 | "Downloading peft-0.13.2-py3-none-any.whl (320 kB)\r\n", 308 | "Downloading protobuf-3.20.3-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.1 MB)\r\n", 309 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m28.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", 310 | "\u001b[?25hDownloading sentencepiece-0.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)\r\n", 311 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m41.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", 312 | "\u001b[?25hDownloading tqdm-4.67.0-py3-none-any.whl (78 kB)\r\n", 313 | "Downloading transformers-4.44.2-py3-none-any.whl (9.5 MB)\r\n", 314 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m9.5/9.5 MB\u001b[0m \u001b[31m109.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", 315 | "\u001b[?25hDownloading trl-0.11.1-py3-none-any.whl (318 kB)\r\n", 316 | "Downloading typing_extensions-4.12.2-py3-none-any.whl (37 kB)\r\n", 317 | "Downloading tyro-0.9.1-py3-none-any.whl (111 kB)\r\n", 318 | "Using cached wheel-0.45.0-py3-none-any.whl (72 kB)\r\n", 319 | "Downloading xformers-0.0.28.post1-cp310-cp310-manylinux_2_28_x86_64.whl (16.7 MB)\r\n", 320 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m16.7/16.7 MB\u001b[0m \u001b[31m107.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", 321 | "\u001b[?25hDownloading bitsandbytes-0.44.1-py3-none-manylinux_2_24_x86_64.whl (122.4 MB)\r\n", 322 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m122.4/122.4 MB\u001b[0m \u001b[31m100.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", 323 | "\u001b[?25hDownloading filelock-3.16.1-py3-none-any.whl (16 kB)\r\n", 324 | "Downloading hf_transfer-0.1.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.6 MB)\r\n", 325 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.6/3.6 MB\u001b[0m 
\u001b[31m80.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", 326 | "\u001b[?25hDownloading jinja2-3.1.4-py3-none-any.whl (133 kB)\r\n", 327 | "Downloading networkx-3.4.2-py3-none-any.whl (1.7 MB)\r\n", 328 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.7/1.7 MB\u001b[0m \u001b[31m49.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", 329 | "\u001b[?25hDownloading psutil-6.1.0-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (287 kB)\r\n", 330 | "Downloading sympy-1.13.3-py3-none-any.whl (6.2 MB)\r\n", 331 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m6.2/6.2 MB\u001b[0m \u001b[31m97.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", 332 | "\u001b[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)\r\n", 333 | "Downloading docstring_parser-0.16-py3-none-any.whl (36 kB)\r\n", 334 | "Downloading aiohttp-3.11.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.6 MB)\r\n", 335 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.6/1.6 MB\u001b[0m \u001b[31m46.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", 336 | "\u001b[?25hDownloading MarkupSafe-3.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (20 kB)\r\n", 337 | "Downloading mpmath-1.3.0-py3-none-any.whl (536 kB)\r\n", 338 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m536.2/536.2 kB\u001b[0m \u001b[31m12.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", 339 | "\u001b[?25hDownloading multiprocess-0.70.16-py310-none-any.whl (134 kB)\r\n", 340 | "Downloading pyarrow-18.0.0-cp310-cp310-manylinux_2_28_x86_64.whl (40.0 MB)\r\n", 341 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m40.0/40.0 MB\u001b[0m \u001b[31m110.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", 342 | "\u001b[?25hDownloading PyYAML-6.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (751 kB)\r\n", 343 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m751.2/751.2 kB\u001b[0m \u001b[31m22.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", 344 | "\u001b[?25hDownloading regex-2024.11.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (781 kB)\r\n", 345 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m781.7/781.7 kB\u001b[0m \u001b[31m20.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", 346 | "\u001b[?25hDownloading requests-2.32.3-py3-none-any.whl (64 kB)\r\n", 347 | "Downloading rich-13.9.4-py3-none-any.whl (242 kB)\r\n", 348 | "Downloading safetensors-0.4.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (435 kB)\r\n", 349 | "Downloading shtab-1.7.1-py3-none-any.whl (14 kB)\r\n", 350 | "Downloading tokenizers-0.19.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.6 MB)\r\n", 351 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.6/3.6 MB\u001b[0m \u001b[31m82.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", 352 | "\u001b[?25hDownloading nvidia_nvjitlink_cu12-12.6.85-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl (19.7 MB)\r\n", 353 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m19.7/19.7 MB\u001b[0m \u001b[31m123.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", 354 | "\u001b[?25hDownloading pandas-2.2.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (13.1 MB)\r\n", 355 | "\u001b[2K 
\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m13.1/13.1 MB\u001b[0m \u001b[31m110.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", 356 | "\u001b[?25hDownloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)\r\n", 357 | "Downloading aiohappyeyeballs-2.4.3-py3-none-any.whl (14 kB)\r\n", 358 | "Downloading aiosignal-1.3.1-py3-none-any.whl (7.6 kB)\r\n", 359 | "Downloading async_timeout-5.0.1-py3-none-any.whl (6.2 kB)\r\n", 360 | "Downloading attrs-24.2.0-py3-none-any.whl (63 kB)\r\n", 361 | "Downloading certifi-2024.8.30-py3-none-any.whl (167 kB)\r\n", 362 | "Downloading charset_normalizer-3.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (144 kB)\r\n", 363 | "Downloading frozenlist-1.5.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (241 kB)\r\n", 364 | "Downloading idna-3.10-py3-none-any.whl (70 kB)\r\n", 365 | "Downloading markdown_it_py-3.0.0-py3-none-any.whl (87 kB)\r\n", 366 | "Downloading multidict-6.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (124 kB)\r\n", 367 | "Downloading propcache-0.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (208 kB)\r\n", 368 | "Downloading pygments-2.18.0-py3-none-any.whl (1.2 MB)\r\n", 369 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.2/1.2 MB\u001b[0m \u001b[31m26.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", 370 | "\u001b[?25hDownloading python_dateutil-2.9.0.post0-py2.py3-none-any.whl (229 kB)\r\n", 371 | "Downloading pytz-2024.2-py2.py3-none-any.whl (508 kB)\r\n", 372 | "Downloading tzdata-2024.2-py2.py3-none-any.whl (346 kB)\r\n", 373 | "Downloading urllib3-2.2.3-py3-none-any.whl (126 kB)\r\n", 374 | "Downloading yarl-1.18.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (319 kB)\r\n", 375 | "Downloading mdurl-0.1.2-py3-none-any.whl (10.0 kB)\r\n", 376 | "Downloading six-1.16.0-py2.py3-none-any.whl (11 kB)\r\n", 377 | "Saved ./wheelhouse/unsloth-2024.9.post4-py3-none-any.whl\r\n", 378 | "Saved ./wheelhouse/torch-2.4.1-cp310-cp310-manylinux1_x86_64.whl\r\n", 379 | "Saved ./wheelhouse/nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl\r\n", 380 | "Saved ./wheelhouse/nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl\r\n", 381 | "Saved ./wheelhouse/nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl\r\n", 382 | "Saved ./wheelhouse/nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl\r\n", 383 | "Saved ./wheelhouse/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl\r\n", 384 | "Saved ./wheelhouse/nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl\r\n", 385 | "Saved ./wheelhouse/nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl\r\n", 386 | "Saved ./wheelhouse/nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl\r\n", 387 | "Saved ./wheelhouse/nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl\r\n", 388 | "Saved ./wheelhouse/nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl\r\n", 389 | "Saved ./wheelhouse/nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl\r\n", 390 | "Saved ./wheelhouse/triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl\r\n", 391 | "Saved ./wheelhouse/accelerate-1.1.1-py3-none-any.whl\r\n", 392 | "Saved ./wheelhouse/datasets-3.1.0-py3-none-any.whl\r\n", 393 | "Saved ./wheelhouse/fsspec-2024.9.0-py3-none-any.whl\r\n", 394 | "Saved ./wheelhouse/huggingface_hub-0.26.2-py3-none-any.whl\r\n", 395 | 
"Saved ./wheelhouse/numpy-2.1.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl\r\n", 396 | "Saved ./wheelhouse/packaging-24.2-py3-none-any.whl\r\n", 397 | "Saved ./wheelhouse/peft-0.13.2-py3-none-any.whl\r\n", 398 | "Saved ./wheelhouse/protobuf-3.20.3-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl\r\n", 399 | "Saved ./wheelhouse/sentencepiece-0.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl\r\n", 400 | "Saved ./wheelhouse/tqdm-4.67.0-py3-none-any.whl\r\n", 401 | "Saved ./wheelhouse/transformers-4.44.2-py3-none-any.whl\r\n", 402 | "Saved ./wheelhouse/trl-0.11.1-py3-none-any.whl\r\n", 403 | "Saved ./wheelhouse/typing_extensions-4.12.2-py3-none-any.whl\r\n", 404 | "Saved ./wheelhouse/tyro-0.9.1-py3-none-any.whl\r\n", 405 | "Saved ./wheelhouse/wheel-0.45.0-py3-none-any.whl\r\n", 406 | "Saved ./wheelhouse/xformers-0.0.28.post1-cp310-cp310-manylinux_2_28_x86_64.whl\r\n", 407 | "Saved ./wheelhouse/bitsandbytes-0.44.1-py3-none-manylinux_2_24_x86_64.whl\r\n", 408 | "Saved ./wheelhouse/filelock-3.16.1-py3-none-any.whl\r\n", 409 | "Saved ./wheelhouse/hf_transfer-0.1.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl\r\n", 410 | "Saved ./wheelhouse/jinja2-3.1.4-py3-none-any.whl\r\n", 411 | "Saved ./wheelhouse/networkx-3.4.2-py3-none-any.whl\r\n", 412 | "Saved ./wheelhouse/psutil-6.1.0-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl\r\n", 413 | "Saved ./wheelhouse/sympy-1.13.3-py3-none-any.whl\r\n", 414 | "Saved ./wheelhouse/dill-0.3.8-py3-none-any.whl\r\n", 415 | "Saved ./wheelhouse/docstring_parser-0.16-py3-none-any.whl\r\n", 416 | "Saved ./wheelhouse/aiohttp-3.11.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl\r\n", 417 | "Saved ./wheelhouse/MarkupSafe-3.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl\r\n", 418 | "Saved ./wheelhouse/mpmath-1.3.0-py3-none-any.whl\r\n", 419 | "Saved ./wheelhouse/multiprocess-0.70.16-py310-none-any.whl\r\n", 420 | "Saved ./wheelhouse/pyarrow-18.0.0-cp310-cp310-manylinux_2_28_x86_64.whl\r\n", 421 | "Saved ./wheelhouse/PyYAML-6.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl\r\n", 422 | "Saved ./wheelhouse/regex-2024.11.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl\r\n", 423 | "Saved ./wheelhouse/requests-2.32.3-py3-none-any.whl\r\n", 424 | "Saved ./wheelhouse/rich-13.9.4-py3-none-any.whl\r\n", 425 | "Saved ./wheelhouse/safetensors-0.4.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl\r\n", 426 | "Saved ./wheelhouse/shtab-1.7.1-py3-none-any.whl\r\n", 427 | "Saved ./wheelhouse/tokenizers-0.19.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl\r\n", 428 | "Saved ./wheelhouse/nvidia_nvjitlink_cu12-12.6.85-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl\r\n", 429 | "Saved ./wheelhouse/pandas-2.2.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl\r\n", 430 | "Saved ./wheelhouse/xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl\r\n", 431 | "Saved ./wheelhouse/aiohappyeyeballs-2.4.3-py3-none-any.whl\r\n", 432 | "Saved ./wheelhouse/aiosignal-1.3.1-py3-none-any.whl\r\n", 433 | "Saved ./wheelhouse/async_timeout-5.0.1-py3-none-any.whl\r\n", 434 | "Saved ./wheelhouse/attrs-24.2.0-py3-none-any.whl\r\n", 435 | "Saved ./wheelhouse/certifi-2024.8.30-py3-none-any.whl\r\n", 436 | "Saved ./wheelhouse/charset_normalizer-3.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl\r\n", 437 | "Saved 
./wheelhouse/frozenlist-1.5.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl\r\n", 438 | "Saved ./wheelhouse/idna-3.10-py3-none-any.whl\r\n", 439 | "Saved ./wheelhouse/markdown_it_py-3.0.0-py3-none-any.whl\r\n", 440 | "Saved ./wheelhouse/multidict-6.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl\r\n", 441 | "Saved ./wheelhouse/propcache-0.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl\r\n", 442 | "Saved ./wheelhouse/pygments-2.18.0-py3-none-any.whl\r\n", 443 | "Saved ./wheelhouse/python_dateutil-2.9.0.post0-py2.py3-none-any.whl\r\n", 444 | "Saved ./wheelhouse/pytz-2024.2-py2.py3-none-any.whl\r\n", 445 | "Saved ./wheelhouse/tzdata-2024.2-py2.py3-none-any.whl\r\n", 446 | "Saved ./wheelhouse/urllib3-2.2.3-py3-none-any.whl\r\n", 447 | "Saved ./wheelhouse/yarl-1.18.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl\r\n", 448 | "Saved ./wheelhouse/mdurl-0.1.2-py3-none-any.whl\r\n", 449 | "Saved ./wheelhouse/six-1.16.0-py2.py3-none-any.whl\r\n" 450 | ] 451 | } 452 | ], 453 | "source": [ 454 | "!pip wheel \"unsloth==2024.9.post4\" \"torch==2.4.1\" -w wheelhouse" 455 | ] 456 | } 457 | ], 458 | "metadata": { 459 | "kaggle": { 460 | "accelerator": "none", 461 | "dataSources": [], 462 | "dockerImageVersionId": 30761, 463 | "isGpuEnabled": false, 464 | "isInternetEnabled": true, 465 | "language": "python", 466 | "sourceType": "notebook" 467 | }, 468 | "kernelspec": { 469 | "display_name": "Python 3", 470 | "language": "python", 471 | "name": "python3" 472 | }, 473 | "language_info": { 474 | "codemirror_mode": { 475 | "name": "ipython", 476 | "version": 3 477 | }, 478 | "file_extension": ".py", 479 | "mimetype": "text/x-python", 480 | "name": "python", 481 | "nbconvert_exporter": "python", 482 | "pygments_lexer": "ipython3", 483 | "version": "3.10.14" 484 | }, 485 | "papermill": { 486 | "default_parameters": {}, 487 | "duration": 137.912542, 488 | "end_time": "2024-11-22T14:11:46.869896", 489 | "environment_variables": {}, 490 | "exception": null, 491 | "input_path": "__notebook__.ipynb", 492 | "output_path": "__notebook__.ipynb", 493 | "parameters": {}, 494 | "start_time": "2024-11-22T14:09:28.957354", 495 | "version": "2.6.0" 496 | } 497 | }, 498 | "nbformat": 4, 499 | "nbformat_minor": 5 500 | } 501 | -------------------------------------------------------------------------------- /params.json: -------------------------------------------------------------------------------- 1 | { 2 | "$schema": "https://params.com/params.json", 3 | "docs": { 4 | "main": "README.md", 5 | "sidebar": { 6 | "Introduction": "README.md" 7 | } 8 | } 9 | } 10 | 11 | -------------------------------------------------------------------------------- /the_architects.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/da-fr/arc-prize-2024/bfc329123b2e0379395852323cf2862d7ae0d108/the_architects.pdf -------------------------------------------------------------------------------- /training_code/arc_downloader.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import zipfile 3 | import os 4 | import io 5 | import re 6 | import json 7 | 8 | 9 | zip_url = 'https://codeload.github.com/fchollet/ARC-AGI/zip/refs/heads/master' 10 | subset_names = ['training', 'evaluation'] 11 | 12 | 13 | def download_arc_data(arc_data_path): 14 | # check if files are already there 15 | required_files = [] 16 | for subset in subset_names: 17 | 
required_files.append(os.path.join(arc_data_path, f'arc-agi_{subset}_challenges.json')) 18 | required_files.append(os.path.join(arc_data_path, f'arc-agi_{subset}_solutions.json')) 19 | if all(map(os.path.isfile, required_files)): return 20 | 21 | 22 | # download repo 23 | r = requests.get(zip_url) 24 | assert r.status_code == 200 25 | z = zipfile.ZipFile(io.BytesIO(r.content)) 26 | 27 | # extract subsets 28 | extract_id = re.compile('^ARC-AGI-master/data/([a-z]+)/([a-z0-9]+)[.]json') 29 | datasets = {} 30 | for f in z.filelist: 31 | id = extract_id.match(f.filename) 32 | if id: 33 | if id.group(1) not in datasets: datasets[id.group(1)] = {} 34 | datasets[id.group(1)][id.group(2)] = json.loads(z.read(f)) 35 | 36 | # store challenges and solutions separately 37 | os.makedirs(arc_data_path, exist_ok=True) 38 | for subset, challenges in datasets.items(): 39 | solutions = {} 40 | for k, v in challenges.items(): 41 | assert v.pop('name', k) == k # remove name tags that occur inconsistently in the data 42 | solutions[k] = [t.pop('output') for t in v['test']] 43 | with open(os.path.join(arc_data_path, f'arc-agi_{subset}_challenges.json'), 'w') as f: json.dump(challenges, f) 44 | with open(os.path.join(arc_data_path, f'arc-agi_{subset}_solutions.json'), 'w') as f: json.dump(solutions, f) 45 | print(f'Downloaded arc {subset} set.') 46 | -------------------------------------------------------------------------------- /training_code/arc_loader.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Daniel Franzen and Jan Disselhoff 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | import os 16 | import re 17 | import itertools 18 | import json 19 | import hashlib 20 | import numpy as np 21 | from numpy.random import randint 22 | from glob import glob 23 | from tqdm import tqdm 24 | from collections import OrderedDict 25 | 26 | 27 | class ArcDataset(object): 28 | def __init__(self, challenge, solutions={}, keys=None, is_fake=False, is_orig=False): 29 | if keys is None: 30 | self.keys = [] 31 | for k, v in challenge.items(): 32 | reply_num = len(v['test']) 33 | self.keys.extend([f'{k}_{i}' for i in range(reply_num)] if reply_num else [k]) 34 | self.keys = sorted(self.keys) 35 | else: 36 | self.keys = [k for k in keys] 37 | base_keys = set(map(self.get_base_key, self.keys)) 38 | self.challenge = {k: challenge[k] for k in base_keys} 39 | self.solutions = {k: solutions[k] for k in base_keys if k in solutions} 40 | self.is_fake = is_fake 41 | self.is_orig = is_orig 42 | 43 | @classmethod 44 | def load_from_json(cls, challenges_file): # for loading challenges in kaggle json arc dataset format 45 | with open(challenges_file) as f: 46 | challenge = f.read() 47 | return cls( 48 | challenge=json.loads(challenge), 49 | is_fake=hashlib.md5(challenge.encode('utf-8')).hexdigest().lower() == 'a6b7dac3cab03abf2eb333e16610d6dc', 50 | is_orig=True, 51 | ) 52 | 53 | def load_solutions(self, solutions_file): # for loading solutions in kaggle json arc dataset format 54 | with open(solutions_file) as f: solutions = f.read() 55 | data = json.loads(solutions) 56 | solutions = {k: data[k] for k in self.challenge} 57 | return self.__class__(keys=self.keys, challenge=self.challenge, solutions=solutions, is_orig=self.is_orig) 58 | 59 | # loader for Michael Hodel's ReArc https://github.com/michaelhodel/re-arc 60 | @classmethod 61 | def load_from_rearc(cls, path, n, sizes, seed, mix_datasets={}, shuffle=True): # loader for ReArc 62 | np.random.seed(seed) 63 | keys = [[] for _ in range(n)] 64 | challenge = {} 65 | solutions = {} 66 | sizes = list(sizes) 67 | 68 | with open(os.path.join(path, 'metadata.json')) as f: 69 | metadata = json.load(f) 70 | 71 | for key in tqdm(sorted(metadata.keys()), desc="load dataset 're-arc'"): 72 | with open(os.path.join(path, 'tasks', f'{key}.json')) as f: 73 | tasks = np.random.permutation(json.load(f)).tolist() 74 | 75 | next_sizes = [] 76 | for epoch in range(n): 77 | if not len(next_sizes): 78 | next_sizes = np.random.permutation(sizes).tolist() 79 | next_size_with_test = 1 + next_sizes.pop() 80 | base_key = f'rearc-{key}{epoch:02x}' 81 | keys[epoch].append(f'{base_key}_0') 82 | challenge[base_key] = {'train': [], 'test': []} 83 | solutions[base_key] = reply = [] 84 | for _ in range(next_size_with_test): 85 | if not len(tasks): 86 | raise RuntimeError('Not enough examples - generate more re-arc examples or reduce epochs.') 87 | challenge[base_key]['train'].append({k: v for k, v in tasks.pop().items()}) 88 | challenge[base_key]['test'].append(challenge[base_key]['train'].pop()) 89 | solutions[base_key].append(challenge[base_key]['test'][-1].pop('output')) 90 | 91 | for name, ds in mix_datasets.items(): 92 | name = cls.base_key_replace_invalid_chars(name) 93 | for epoch, ds_keys in enumerate(np.array_split(ds.keys, len(keys))): 94 | keys[epoch].extend([f'{name}-{k}' for k in ds_keys]) 95 | challenge.update({f'{name}-{k}': v for k, v in ds.challenge.items()}) 96 | solutions.update({f'{name}-{k}': v for k, v in ds.solutions.items()}) 97 | 98 | if shuffle: 99 | keys = [np.random.permutation(epoch) for epoch in keys] 100 | keys = [k for epoch in
keys for k in epoch] 101 | return cls(keys=keys, challenge=challenge, solutions=solutions, is_orig=True) 102 | 103 | # loader for neoneye's format, as used in https://github.com/neoneye/arc-dataset-collection 104 | @classmethod 105 | def load_from_neoneye(cls, path): 106 | pattern = os.path.join(path, 'data', '*', '*.json') 107 | files = set(glob(pattern)) 108 | for i in itertools.count(): 109 | updated = [fn for fn in files if fn.endswith(f'_v{i + 1}.json')] 110 | if not updated: break 111 | for fn in updated: 112 | files.remove(fn.replace(f'_v{i + 1}.json', ('.json' if i == 0 else f'_v{i}.json'))) 113 | assert len(files), f"No files found for pattern '{pattern}'." 114 | challenge = {} 115 | solutions = {} 116 | 117 | for fn in tqdm(files, desc=f"load dataset '{os.path.split(path)[-1]}'"): 118 | with open(fn) as f: 119 | key = cls.base_key_replace_invalid_chars(os.path.split(fn)[-1].replace('.json', '')) 120 | challenge[key] = json.load(f) 121 | solutions[key] = [test_case.pop('output') for test_case in challenge[key]['test']] 122 | return cls(challenge=challenge, solutions=solutions, is_orig=True) 123 | 124 | def change_keys(self, keys): 125 | return self.__class__(challenge=self.challenge, solutions=self.solutions, keys=keys) 126 | 127 | def split(self, n, split_seed, **kwargs): 128 | assert self.is_orig, 'Must be run on original dataset.' 129 | keys = sorted(self.challenge.keys()) 130 | if split_seed == 'len': 131 | keys = self.sort_keys_by_len(keys=keys, **kwargs) 132 | else: 133 | assert isinstance(split_seed, int) 134 | assert not kwargs 135 | np.random.seed(split_seed) 136 | keys = np.random.permutation(keys) 137 | split_datasets = [] 138 | for new_keys in np.array_split(keys, n): 139 | new_challenge = {k: self.challenge[k] for k in new_keys} 140 | split_datasets.append(self.__class__(challenge=new_challenge, solutions=self.solutions, is_orig=True)) 141 | return split_datasets 142 | 143 | def remove_test_data(self): 144 | assert self.is_orig, 'Must be run on original dataset.' 145 | new_challenge = {k: {'train': v['train'], 'test': []} for k, v in self.challenge.items()} 146 | return self.__class__(challenge=new_challenge) 147 | 148 | @staticmethod 149 | def base_key_replace_invalid_chars(base_key): 150 | return base_key.replace('_', '-').replace('.', '-') 151 | 152 | @staticmethod 153 | def get_base_key_and_reply_num(key): 154 | key_num = key.split('.', 1)[0] 155 | base_key, reply_num = key_num.split('_') if '_' in key_num else (key_num, -1) 156 | return base_key, int(reply_num) 157 | 158 | @classmethod 159 | def get_base_key(cls, key): 160 | return cls.get_base_key_and_reply_num(key)[0] 161 | 162 | def grouped_keys(self): 163 | grouped_keys = OrderedDict() 164 | for key in self.keys: 165 | base_key, reply_num = self.get_base_key_and_reply_num(key) 166 | if base_key not in grouped_keys: 167 | grouped_keys[base_key] = [] 168 | while len(grouped_keys[base_key]) <= reply_num: 169 | grouped_keys[base_key].append([]) 170 | grouped_keys[base_key][reply_num].append(key) 171 | return grouped_keys 172 | 173 | def move_test_to_train(self): 174 | assert self.is_orig, 'Must be run on original dataset.'
175 | new_challenge = {} 176 | for k, v in self.challenge.items(): 177 | new_challenge[k] = { 178 | 'train': v['train'] + [{**t, 'output': self.solutions[k][i]} for i, t in enumerate(v['test'])], 179 | 'test': [] 180 | } 181 | return self.__class__(challenge=new_challenge, is_orig=self.is_orig) 182 | 183 | @staticmethod 184 | def permute_array(a, descriptor, invert=False): 185 | permutation = [int(i) for i in descriptor if str(i).isdigit()] 186 | assert sorted(permutation) == list(range(10)) 187 | a = np.asarray(a) 188 | assert a.ndim == 2 189 | if invert: permutation = np.argsort(permutation) 190 | a = np.asarray(permutation)[a] 191 | return a 192 | 193 | @classmethod 194 | def transform_array(cls, array, transforms, apply_perm=True, invert=False): 195 | if array is None: return None 196 | array = np.asarray(array) 197 | if invert: transforms = transforms[::-1] 198 | for tf in transforms: 199 | if tf == 'tp': 200 | array = np.swapaxes(array, 0, 1) 201 | if tf == 'rt': 202 | array = np.rot90(np.rot90(np.rot90(array)) if invert else array) 203 | if apply_perm and tf.startswith('perm'): 204 | array = cls.permute_array(array, tf, invert=invert) 205 | return array 206 | 207 | @classmethod 208 | def fmt_array(cls, array, lines_sep, tf=None): 209 | if tf is not None: 210 | array = cls.transform_array(array, tf) 211 | return lines_sep.join(''.join(map(str, row)) for row in array) 212 | 213 | @classmethod 214 | def fmt_input(cls, array, query_beg, reply_beg, **kwargs): 215 | return query_beg + cls.fmt_array(array, **kwargs) + reply_beg 216 | 217 | @classmethod 218 | def fmt_output(cls, array, reply_end, **kwargs): 219 | return cls.fmt_array(array, **kwargs) + reply_end 220 | 221 | @classmethod 222 | def fmt_train(cls, train_ex, preprompt, query_beg, reply_beg, reply_end, **kwargs): 223 | examples = [cls.fmt_input(x['input'], query_beg, reply_beg, **kwargs) + 224 | cls.fmt_output(x['output'], reply_end, **kwargs) for x in train_ex] 225 | return preprompt + ''.join(examples) 226 | 227 | def fmt_task(self, key, preprompt, query_beg, reply_beg, reply_end, reply=True, **kwargs): 228 | key_num, *tf = key.split('.') 229 | base_key, reply_num = self.get_base_key_and_reply_num(key_num) 230 | data_train = self.challenge[base_key]['train'] 231 | data_query = self.challenge[base_key]['test'] 232 | if reply is True: 233 | reply = self.solutions[base_key][reply_num] if base_key in self.solutions and reply_num >= 0 else None 234 | elif reply is not None: 235 | assert reply_num >= 0 236 | for t in tf: 237 | if t.startswith('ex'): 238 | data_train = [data_train[int(i)] for i in t[2:].split('-')] 239 | ret = dict(key=key) 240 | ret['train'] = self.fmt_train(data_train, preprompt, query_beg, reply_beg, reply_end, tf=tf, **kwargs) 241 | ret['query'] = self.fmt_input(data_query[reply_num]['input'], query_beg, reply_beg, tf=tf, **kwargs) if reply_num >= 0 else '' 242 | ret['input'] = ret['train'] + ret['query'] if reply_num >= 0 else '' 243 | if reply is not None: 244 | ret['reply'] = self.fmt_output(reply, reply_end, tf=tf, **kwargs) 245 | ret['text'] = ret['train'] + (ret['query'] + ret['reply'] if reply is not None else '') 246 | return ret 247 | 248 | def get_task(self, key, max_tokens=None, len_name=None, **kwargs): 249 | while True: 250 | fmt = self.fmt_task(key, **kwargs) 251 | if max_tokens is None or self.count_tokens(fmt[len_name]) <= max_tokens: 252 | break 253 | if not key.split('.')[-1].startswith('ex'): 254 | base_key = self.get_base_key(key) 255 | key = f"{key}.ex{'-'.join(map(str, 
range(len(self.challenge[base_key]['train']))))}" 256 | key_split = key.split('.') 257 | key_split[-1] = '-'.join(key_split[-1].split('-')[:-1]) 258 | assert len(key_split[-1]) > 2 and key_split[-1].startswith('ex') 259 | key = '.'.join(key_split) 260 | return key, fmt 261 | 262 | def repeat(self, n, seed=None): 263 | if seed is not None: 264 | np.random.seed(seed) 265 | new_keys = [] 266 | for i in range(n): 267 | new_keys.extend(self.keys if seed is None else np.random.permutation(self.keys)) 268 | return self.change_keys(new_keys) 269 | 270 | @staticmethod 271 | def count_tokens(data, replace_special=re.compile('<[^<]*>')): 272 | replaced = replace_special.sub('x', data) # replace '<...>' by a single char to count special tokens only once 273 | return len(replaced) 274 | 275 | @classmethod 276 | def max_new_tokens(cls, reply_end, lines_sep, max_size=30, safety_margin=1, **_): 277 | max_sized_reply = np.zeros([max_size, max_size], dtype=int) 278 | fmt = cls.fmt_output(max_sized_reply, reply_end=reply_end, lines_sep=lines_sep) 279 | return cls.count_tokens(fmt) + safety_margin 280 | 281 | def get_length(self, key, len_name, max_of_transposed=False, max_tokens=None, **fmt_opts): 282 | if not fmt_opts: 283 | fmt_opts = dict(preprompt='', query_beg='', reply_beg='', reply_end='', lines_sep='') 284 | length = self.count_tokens(self.fmt_task(key, **fmt_opts)[len_name]) 285 | else: 286 | length = self.count_tokens(self.fmt_task(key, **fmt_opts)[len_name]) 287 | if max_of_transposed: 288 | length = max(length, self.count_tokens(self.fmt_task(f'{key}.tp', fmt_opts)[len_name])) 289 | length += 1 # for bos token 290 | return length 291 | 292 | def sort_keys_by_len(self, keys, reverse=False, **kwargs): 293 | lengths = [(key, self.get_length(key, **kwargs)) for key in keys] 294 | return [x[0] for x in sorted(lengths, reverse=reverse, key=lambda x: x[1])] 295 | 296 | def sorted_by_len(self,**kwargs): 297 | return self.change_keys(self.sort_keys_by_len(self.keys, **kwargs)) 298 | 299 | def convert_with_token_limit(self, **kwargs): 300 | out_list = [] 301 | new_keys = [] 302 | for key in tqdm(self.keys, desc='convert dataset'): 303 | key, fmt = self.get_task(key, **kwargs) 304 | new_keys.append(key) 305 | out_list.append(fmt) 306 | return out_list, self.change_keys(new_keys) 307 | 308 | def as_list(self, **kwargs): 309 | return self.convert_with_token_limit(**kwargs)[0] 310 | 311 | @staticmethod 312 | def rand_perm(n, sep=None, keep_zero=False): 313 | permutation = np.random.permutation(n).tolist() 314 | if keep_zero: 315 | permutation = [0] + [x for x in permutation if x != 0] 316 | return permutation if sep is None else sep.join(map(str, permutation)) 317 | 318 | def augment_keys(self, keys, tp=False, rt=False, n=1, perm=False, keep_background=False, shfl_ex=False): 319 | keys = [k + n * '.tp' for n in range(2) for k in keys] if tp == 'all' else keys 320 | keys = [k + n * '.rt' for n in range(4) for k in keys] if rt == 'all' else keys 321 | keys = [k + bool(tp) * randint(0, 2) * '.tp' for k in keys] if tp != 'all' else keys 322 | keys = [k + bool(rt) * randint(0, 4) * '.rt' for k in keys] if rt != 'all' else keys 323 | keys = keys * n # repeat n times 324 | keys = [k + bool(perm) * ('.perm' + self.rand_perm(10, '', keep_background)) for k in keys] 325 | n_ex = lambda k: len(self.challenge[self.get_base_key(k)]['train']) 326 | keys = [k + bool(shfl_ex) * ('.ex' + self.rand_perm(n_ex(k), '-')) for k in keys] 327 | return keys 328 | 329 | def augment(self, seed, **kwargs): 330 | if seed is not None: 331 
| np.random.seed(seed) 332 | return self.change_keys([k for key in self.keys for k in self.augment_keys([key], **kwargs)]) 333 | 334 | def decode(self, text, lines_sep, key=None): 335 | correct, info = None, 'unknown' 336 | try: 337 | data = [[int(x) for x in row if x.isdigit()] for row in text.split(lines_sep)] 338 | data = [row for row in data if len(row)] 339 | data = np.array(data, dtype=int) 340 | assert data.ndim == 2 and all(0 < x <= 30 for x in data.shape) 341 | except: 342 | data = None 343 | correct, info = False, 'cant_decode' 344 | if key is not None and data is not None: 345 | key_num, *transforms = key.split('.') 346 | base_key, reply_num = self.get_base_key_and_reply_num(key_num) 347 | data = self.transform_array(data, transforms, invert=True) 348 | correct_solution = self.solutions.get(base_key) 349 | if correct_solution is None: 350 | info = 'sol_unknown' 351 | else: 352 | correct_solution = np.asarray(correct_solution[reply_num]) 353 | if np.array_equal(correct_solution, data): 354 | correct, info = True, 'ALL_CORRECT' 355 | else: 356 | correct, info = False, ('bad_content' if correct_solution.shape == data.shape else 'bad_xy_size') 357 | return data, correct, info 358 | 359 | def get_submission(self, results=None): 360 | assert self.is_orig, 'Must be run on original dataset.' 361 | submission = {k: [{f'attempt_{i+1}': [[0]] for i in range(2)} for _ in range(len(v['test']))] for k, v in self.challenge.items()} 362 | if results is not None: 363 | self.fill_submission(results, submission) 364 | return submission 365 | 366 | @staticmethod 367 | def fill_submission(results, submission): 368 | for base_key, data in results.items(): 369 | for reply_num, guesses in enumerate(data): 370 | target_dict = submission[base_key][reply_num] 371 | for i, g in enumerate(guesses[:len(target_dict)]): 372 | target_dict[f'attempt_{i + 1}'] = g['output'].tolist() 373 | 374 | def validate_submission(self, submission): 375 | assert self.is_orig, 'Must be run on original dataset.' 376 | assert self.solutions, 'Solutions must be loaded for submission verification.' 377 | score = 0 378 | for k, v in self.solutions.items(): 379 | for i, r in enumerate(v): 380 | for attempt in ['attempt_1', 'attempt_2']: 381 | if np.array_equal(r, submission[k][i][attempt]): 382 | score += 1 / len(v) 383 | break 384 | return score 385 | -------------------------------------------------------------------------------- /training_code/inference_tools.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Daniel Franzen and Jan Disselhoff 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import sys 16 | import torch 17 | import hashlib 18 | import numpy as np 19 | from tqdm import tqdm 20 | 21 | 22 | def is_unsloth_model(model): 23 | return model.model_tags is not None and 'unsloth' in model.model_tags 24 | 25 | 26 | def logits_to_score(sequence, logits): 27 | assert sequence.ndim == 1 28 | assert logits.ndim == 2 29 | assert len(sequence) == len(logits) 30 | return -logits.log_softmax(-1)[torch.arange(len(logits)), sequence].sum().item() 31 | 32 | 33 | def calc_score(input, reply, model_tok, cache=None, **_): 34 | if cache is not None: # try loading result from cache 35 | return cache(calc_score)(input=input, reply=reply, model_tok=model_tok) 36 | 37 | # prepare model and tokenizer 38 | model, tokenizer = model_tok if isinstance(model_tok, (list, tuple)) else model_tok() 39 | 40 | with torch.no_grad(): # calculate score 41 | input_len = len(tokenizer(input)['input_ids']) 42 | tokenized = tokenizer([input+reply], return_tensors='pt') 43 | tokenized.pop('token_type_ids', None) 44 | sequence = tokenized['input_ids'][0][input_len:].cpu() 45 | logits = model(**tokenized.to(model.device))['logits'][0, input_len-1: -1].float().cpu() 46 | return logits_to_score(sequence, logits) 47 | 48 | 49 | def explore(model, logits, path, eos, max_new_tokens, max_score, pos, cache, score=0.0): 50 | first_token_logits, logits = logits[0], (logits[1:] if len(logits) > 1 else None) 51 | softmax = list(enumerate(-first_token_logits.detach().float().log_softmax(-1).cpu())) 52 | 53 | if len(path): # follow precomputed path first 54 | softmax[0], softmax[path[0]], path = softmax[path[0]], softmax[0], path[1:] 55 | 56 | return_suffixes = [] 57 | for i, s in softmax: # loop over all possible tokens 58 | next_score = score + s.item() 59 | if next_score < max_score: # check if still below the score limit, otherwise stop exploration 60 | if i == eos: # candidate found, append to suffixes (tokens are aggregated on backward pass) 61 | suffixes = [([], next_score)] 62 | elif max_new_tokens > 1: # check if still below token limit, otherwise stop exploration 63 | if logits is None: # if not following the initial guess, calculate logits to pass to explore function 64 | if pos < cache[0][0][0].shape[2]: # cut back key-value-cache when backtracking 65 | cache[0] = tuple(tuple(c[:, :, :pos] for c in l) for l in cache[0]) 66 | logits, cache[0] = model( 67 | input_ids=torch.full((1, 1), i, device=model.device), 68 | position_ids=torch.full((1, 1), pos, device=model.device), 69 | past_key_values=cache[0], 70 | )[:2] 71 | logits = logits[0] # unbatch 72 | # explore suffixes 73 | suffixes = explore(model, logits, path, eos, max_new_tokens-1, max_score, pos+1, cache, next_score) 74 | else: suffixes = [] 75 | 76 | # update suffixes 77 | for suffix in suffixes: 78 | suffix[0].append(i) 79 | return_suffixes.extend(suffixes) 80 | 81 | logits = None 82 | return return_suffixes 83 | 84 | 85 | def dfs(model, input_ids, eos_token_id, max_new_tokens, min_prob, pos=None, attention_mask=None): 86 | assert not torch.is_grad_enabled() 87 | assert attention_mask is None or attention_mask.all(), 'not implemented' 88 | sys.setrecursionlimit(1000 + max_new_tokens) # avoid stack overflows 89 | 90 | # prepare inputs 91 | input_ids = torch.as_tensor(input_ids, device=model.device, dtype=int) 92 | if input_ids.ndim == 2: 93 | input_ids = input_ids.squeeze(0) 94 | assert input_ids.ndim == 1, 'batching not supported' 95 | 96 | if pos is None: 97 | # no guess passed, set generation starting position to length of input 98 | pos = 
len(input_ids) 99 | elif pos < len(input_ids): 100 | # if guess passed, remove final eos_token from input 101 | if input_ids[-1] == eos_token_id: 102 | input_ids = input_ids[:-1] 103 | 104 | # process prompt and best guess 105 | logits, cache = model(input_ids=input_ids[torch.newaxis])[:2] 106 | logits = logits[0, pos-1:] 107 | 108 | # run dfs 109 | result = explore(model, logits, input_ids[pos:], eos_token_id, max_new_tokens, -np.log(min_prob), pos, [cache]) 110 | 111 | # return results sorted by scores 112 | return sorted([(np.array(suffix[::-1]), score_val) for suffix, score_val in result], key=lambda x: x[1]) 113 | 114 | 115 | def infer_single(prompt, model_tok, guess=None, min_prob=None, cache=None, **kwargs): 116 | assert len(prompt) 117 | 118 | if cache is not None: # try loading result from cache 119 | return cache(infer_single)(prompt=prompt, model_tok=model_tok, guess=guess, min_prob=min_prob, **kwargs) 120 | 121 | # prepare model and tokenizer 122 | model, tokenizer = model_tok if isinstance(model_tok, (list, tuple)) else model_tok() 123 | 124 | with torch.no_grad(): 125 | # tokenize input 126 | tokenized = tokenizer(prompt, return_tensors='pt').to(model.device) 127 | input_len = tokenized['input_ids'].shape[-1] 128 | tokenized.pop('token_type_ids', None) 129 | 130 | if min_prob is not None: # run dfs if 'min_prob' is passed 131 | if guess is not None: 132 | tokenized = tokenizer(guess, return_tensors='pt').to(model.device) 133 | tokenized.pop('token_type_ids', None) 134 | ret = dfs(model, **tokenized, pos=input_len, min_prob=min_prob, eos_token_id=tokenizer.eos_token_id, **kwargs) 135 | 136 | else: # run model 'generate' function 137 | assert kwargs.get('num_beams', 1) == 1 or not is_unsloth_model(model) 138 | gen = model.generate(**tokenized, return_dict_in_generate=True, output_logits=True, use_cache=True, 139 | eos_token_id=tokenizer.eos_token_id, **kwargs) 140 | sequence = gen['sequences'][0, input_len:].cpu() 141 | logits = torch.stack(gen['logits'], axis=-2)[0].float().cpu() 142 | ret = [(sequence, logits_to_score(sequence, logits))] 143 | 144 | return [(tokenizer.decode(o), s) for o, s in ret] 145 | 146 | 147 | def infer_task(keys, dataset, fmt_opts, aug_score_opts=None, pass_guess=True, print_func=print, **kwargs): 148 | unique_results = {} 149 | best_guess = (None, float('inf')) 150 | for key in keys: 151 | # format task 152 | key, fmt = dataset.get_task(key, **fmt_opts) 153 | input_len = dataset.count_tokens(fmt['input']) 154 | reply_len = dataset.count_tokens(fmt['reply']) if 'reply' in fmt else '?' 
155 | 156 | # get current best guess 157 | guess = None 158 | if pass_guess and best_guess[0] is not None: 159 | guess = dataset.get_task(key, reply=best_guess[0], **fmt_opts)[1]['text'] 160 | assert guess.startswith(fmt['input']) 161 | 162 | # run inference 163 | data = infer_single(prompt=fmt['input'], guess=guess, **kwargs) 164 | 165 | # loop over inference outputs 166 | for i, (sequence, score) in enumerate(data): 167 | # decode output 168 | output, correct, corr_info = dataset.decode(sequence, fmt_opts['lines_sep'], key) 169 | 170 | # print some info 171 | token_info = f" in:{input_len:>4} out:{dataset.count_tokens(sequence):>3}/{reply_len:>3}" 172 | score_info = f"{min(np.exp(-score), 0.99):>3.0%}" 173 | shape_info = f'{output.shape[0]:>2}x{output.shape[1]:<2}' if output is not None else '--x--' 174 | print_func(f"{token_info} > {shape_info} {corr_info} p={score_info} [{key}.out{i}]") 175 | 176 | if output is not None: 177 | # add output to results 178 | hashable = tuple(map(tuple, output)) 179 | if hashable not in unique_results: 180 | unique_results[hashable] = dict(output=output, correct=correct, scores_inf={}) 181 | res = unique_results[hashable] 182 | 183 | # calculate score 184 | res['scores_inf'][key] = score 185 | if aug_score_opts and 'scores_aug' not in res: 186 | aug_score_opts_copy = aug_score_opts.copy() 187 | key_hash = int(hashlib.md5(key.split('.')[0].encode('utf-8')).hexdigest()[:6], 16) 188 | out_hash = int(hashlib.md5(str(hashable).encode('utf-8')).hexdigest()[:6], 16) 189 | np.random.seed(aug_score_opts_copy.pop('seed') + key_hash + out_hash) 190 | aug_keys = dataset.augment_keys([key.split('.', 1)[0]], **aug_score_opts_copy) 191 | aug_key_fmt = [dataset.get_task(k, reply=output, **fmt_opts) for k in aug_keys] 192 | res['scores_aug'] = {key: calc_score(**fmt, **kwargs) for key, fmt in aug_key_fmt} 193 | 194 | # update best guess 195 | new_score = min(res['scores_inf'].values()) 196 | if new_score < best_guess[1]: 197 | best_guess = res['output'], new_score 198 | 199 | return list(unique_results.values()) 200 | 201 | 202 | def inference_run(dataset, fmt_opts, max_new_tokens=None, callback=None, **kwargs): 203 | # set token limits 204 | if max_new_tokens is None: 205 | max_new_tokens = dataset.max_new_tokens(**fmt_opts) 206 | if 'max_tokens' in fmt_opts: 207 | fmt_opts = {**fmt_opts, 'max_tokens': fmt_opts['max_tokens'] - max_new_tokens, 'len_name': 'input'} 208 | 209 | # iterate over dataset 210 | results = {} 211 | with tqdm(dataset.grouped_keys().items(), desc='inference') as pbar: 212 | for base_key, tasks in pbar: 213 | results[base_key] = [] 214 | for task_num, task in enumerate(tasks): 215 | res = infer_task(keys=task, dataset=dataset, fmt_opts=fmt_opts, max_new_tokens=max_new_tokens, 216 | print_func=pbar.write, **kwargs) 217 | results[base_key].append(res) 218 | if callback is not None: 219 | callback(res, name=f'{base_key}_{task_num}', value=1/len(tasks), print_func=pbar.write) 220 | return results 221 | -------------------------------------------------------------------------------- /training_code/model_tools.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Daniel Franzen and Jan Disselhoff 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | import json 17 | import torch 18 | import peft 19 | from tokenizers import Tokenizer 20 | from huggingface_hub import snapshot_download 21 | from trl import DataCollatorForCompletionOnlyLM 22 | 23 | 24 | class InputMaskingDataCollator(DataCollatorForCompletionOnlyLM): 25 | def __init__(self, mask_first_n_examples=0, **kwargs): 26 | super().__init__(**kwargs) 27 | self.mask_first_n_examples = mask_first_n_examples 28 | 29 | def torch_call(self, examples): 30 | batch = super().torch_call(examples) # call super, masking all inputs 31 | for i in range(len(batch['labels'])): 32 | for _ in range(self.mask_first_n_examples): 33 | # mask first still unmasked output block 34 | beg_pos = ((batch['labels'][i] != -100).nonzero().min()).item() 35 | mid_pos = ((batch['labels'][i][beg_pos:] == -100).nonzero().min()).item() + beg_pos 36 | end_pos = ((batch['labels'][i] != -100).nonzero().max()).item() + 1 37 | if mid_pos < end_pos: 38 | batch['labels'][i][beg_pos:mid_pos] = -100 39 | return batch 40 | 41 | 42 | def load_unsloth_4bit(model_path): 43 | from unsloth import FastLanguageModel 44 | return FastLanguageModel.from_pretrained( 45 | model_name=model_path, 46 | dtype=None, 47 | load_in_4bit=True, 48 | ) 49 | 50 | 51 | def save_model_and_tokenizer(store_path, model, tokenizer): 52 | model.save_pretrained(store_path) 53 | tokenizer.save_pretrained(store_path) 54 | to_delete = os.path.join(store_path, 'tokenizer.model') # delete file, as it interferes with token removal 55 | if os.path.isfile(to_delete): 56 | os.remove(to_delete) 57 | 58 | 59 | def fix_dtypes(model, fix_weights=True, fix_quant_states=True): 60 | # fix some data types (workaround for unsloth) 61 | for module in model.modules(): 62 | weight = getattr(module, 'weight', None) 63 | if weight is not None: 64 | if torch.is_floating_point(weight): 65 | if fix_weights and weight.dtype != model.dtype: 66 | module.to(model.dtype) 67 | else: 68 | qs = getattr(weight, 'quant_state', None) 69 | if qs is not None: 70 | if fix_quant_states and qs.dtype != model.dtype: 71 | qs.dtype = model.dtype 72 | return model 73 | 74 | 75 | def is_peft_model(model): 76 | return hasattr(model, 'peft_type') 77 | 78 | 79 | def merge_peft_into_base(model): 80 | assert is_peft_model(model) 81 | return fix_dtypes(model.merge_and_unload()) 82 | 83 | 84 | def get_and_fix_peft_weights(store): 85 | # change some keys (workaround for added 'modules_to_save') 86 | state_dict = peft.load_peft_weights(store) 87 | for k in list(state_dict.keys()): 88 | if 'modules_to_save' in k: 89 | del state_dict[k] 90 | original_module_key = k.replace('.modules_to_save.', '.original_module.') 91 | if original_module_key in state_dict: del state_dict[original_module_key] 92 | assert k.replace('.modules_to_save.', '.') in state_dict 93 | return state_dict 94 | 95 | 96 | def set_peft_weights(model, state_dict): 97 | res = peft.set_peft_model_state_dict(model, state_dict) 98 | assert not res.unexpected_keys, 'error loading weights - some keys not available in model' 99 | 100 | 101 | def load_peft_state(model, store): 102 | # 
convenience method to load peft weights from file and set them for model 103 | set_peft_weights(model, get_and_fix_peft_weights(store)) 104 | 105 | 106 | def get_or_map_special_tokens(data, mapping=None): 107 | tokens = set() 108 | if isinstance(data, dict): 109 | special = data.get('special_tokens') 110 | if special is not None: # find and/or update special token mappings 111 | for v in special.values(): 112 | tokens.update(v['ids']) 113 | if mapping is not None: 114 | v['ids'] = [mapping.get(i) for i in v['ids'] if i in mapping] 115 | for v in data.values(): # recursively process dict values 116 | tokens.update(get_or_map_special_tokens(v, mapping)) 117 | if isinstance(data, list): 118 | for v in data: # recursively process lists 119 | tokens.update(get_or_map_special_tokens(v, mapping)) 120 | return tokens 121 | 122 | 123 | def remove_tokenizer_normalizer(tokenizer): 124 | assert tokenizer.is_fast 125 | tokenizer_json = json.loads(tokenizer._tokenizer.to_str()) 126 | if tokenizer_json.get('normalizer') is not None: 127 | tokenizer_json['normalizer'] = None 128 | tokenizer._tokenizer = Tokenizer.from_str(json.dumps(tokenizer_json)) 129 | 130 | 131 | def shrink_tokenizer_vocab(tokenizer, keep_indices, keep_special=True, remove_unk=False): 132 | assert tokenizer.is_fast 133 | tok_json = json.loads(tokenizer._tokenizer.to_str()) 134 | assert tok_json['model']['type'] == "BPE" 135 | 136 | if keep_special: # get special tokens to keep 137 | keep_indices.update(tokenizer.all_special_ids) 138 | keep_indices.update(get_or_map_special_tokens(tok_json.get('post_processor'))) 139 | 140 | if remove_unk: # remove unknown token 141 | keep_indices -= {tokenizer.unk_token_id} 142 | 143 | # build mapping from old to new id 144 | mapping = {old: new for new, old in enumerate(sorted(keep_indices))} 145 | 146 | # update tokenizer info 147 | tok_json['model']['vocab'] = {k: mapping[v] for k, v in tok_json['model']['vocab'].items() if v in mapping} 148 | tok_json['model']['merges'] = [] 149 | tok_json['added_tokens'] = [{**t, 'id': mapping[t['id']]} for t in tok_json['added_tokens'] if t['id'] in mapping] 150 | tok_json['added_tokens'] = sorted(tok_json['added_tokens'], key=lambda t: t['id']) 151 | get_or_map_special_tokens(tok_json.get('post_processor'), mapping) 152 | 153 | tokenizer._tokenizer = Tokenizer.from_str(json.dumps(tok_json)) # reload json, modifying tokenizer in-place 154 | 155 | if remove_unk: 156 | tokenizer.unk_token = None 157 | 158 | return mapping # token mapping to be used later 159 | 160 | 161 | def shrink_model_embeddings(model, mapping): 162 | with torch.no_grad(): 163 | # copy embeddings to keep 164 | row_select = torch.tensor([x[0] for x in sorted(mapping.items(), key=lambda x: x[1])]) 165 | row_select = row_select.to(model.get_input_embeddings().weight.data.device) 166 | new_embed_t = torch.index_select(model.get_input_embeddings().weight.data, 0, row_select) 167 | row_select = row_select.to(model.get_output_embeddings().weight.data.device) 168 | new_lm_head = torch.index_select(model.get_output_embeddings().weight.data, 0, row_select) 169 | 170 | # resize model embeddings 171 | model.resize_token_embeddings(len(row_select)) 172 | 173 | # set to copied values 174 | model.get_input_embeddings().weight.data[:] = new_embed_t 175 | model.get_output_embeddings().weight.data[:] = new_lm_head 176 | 177 | # map model tokens to new id 178 | for config in [model.config, model.generation_config]: 179 | for k, v in list(config.to_dict().items()): 180 | if k.endswith('token_id'): 181 | 
setattr(config, k, [mapping.get(t) for t in v] if isinstance(v, list) else mapping.get(v)) 182 | 183 | 184 | def keep_single_char_tokens(model, tokenizer, keep=None, keep_norm=False, keep_model_tok=True, **kwargs): 185 | if not keep_norm: 186 | remove_tokenizer_normalizer(tokenizer) # required for some models 187 | if keep is None: # keep all single_length tokens 188 | keep_indices = set(v for k, v in tokenizer.vocab.items() if len(k) == 1) 189 | else: # keep tokens that were passed 190 | keep_indices = set(tokenizer.vocab[t] for t in keep) 191 | if keep_model_tok: # keep tokens used by model 192 | for config in [model.config, model.generation_config]: 193 | for k, v in config.to_dict().items(): 194 | if k.endswith('token_id'): 195 | keep_indices.update(v if isinstance(v, list) else [v]) 196 | keep_indices -= {None} 197 | mapping = shrink_tokenizer_vocab(tokenizer, keep_indices, **kwargs) 198 | shrink_model_embeddings(model, mapping) 199 | return mapping 200 | -------------------------------------------------------------------------------- /training_code/run_evaluation_Llama-rearc_with_ttt.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Daniel Franzen and Jan Disselhoff 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import os 16 | import json 17 | from unsloth import FastLanguageModel 18 | from unsloth import UnslothTrainer as Trainer, unsloth_train, is_bfloat16_supported 19 | from unsloth import UnslothTrainingArguments as TrainingArguments 20 | from datasets import Dataset 21 | from diskcache import Cache 22 | 23 | from arc_loader import ArcDataset 24 | from model_tools import InputMaskingDataCollator 25 | from model_tools import load_unsloth_4bit, save_model_and_tokenizer 26 | from inference_tools import inference_run 27 | from selection import EvalTool 28 | from arc_downloader import download_arc_data 29 | 30 | # input paths 31 | base_model = 'da-fr/Llama-3.2-3B-ARChitects-ReArc-bnb-4bit' # auto-downloaded from huggingface.co 32 | arc_data_path = os.path.join('input', 'arc-prize-2024') # as on kaggle arc prize 2024 33 | download_arc_data(arc_data_path) 34 | 35 | # output paths 36 | output_path = 'output_evaluation_Llama-rearc_with_ttt' 37 | save_model_path = os.path.join(output_path, 'finetuned_model') 38 | inference_cache = os.path.join(output_path, 'inference_cache') 39 | submission_file = os.path.join(output_path, 'submission.json') 40 | 41 | # load evaluation dataset 42 | arc_eval_set = ArcDataset.load_from_json(os.path.join(arc_data_path, 'arc-agi_evaluation_challenges.json')) 43 | arc_eval_set = arc_eval_set.load_solutions(os.path.join(arc_data_path, 'arc-agi_evaluation_solutions.json')) 44 | 45 | # load model 46 | retrain = not os.path.exists(save_model_path) 47 | model, tokenizer = load_unsloth_4bit(base_model if retrain else save_model_path) 48 | 49 | # set formatting options 50 | fmt_opts = dict( 51 | preprompt='ABCDEFGHJKLMNPQRSTUVWXYZabcdefghjklmnpqrstuvwxyz', 52 | query_beg='I', 53 | reply_beg='\n+/-=O', 54 | reply_end='\n' + tokenizer.eos_token, 55 | lines_sep='\n', 56 | max_tokens=128000, 57 | ) 58 | 59 | if retrain: 60 | # create lora model 61 | model = FastLanguageModel.get_peft_model( 62 | model=model, 63 | target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj', 64 | 'embed_tokens', 'lm_head'], 65 | r=64, 66 | lora_alpha=16, 67 | lora_dropout=0, 68 | bias="none", 69 | use_gradient_checkpointing=True, 70 | random_state=42, 71 | use_rslora=True, 72 | loftq_config=None, 73 | ) 74 | 75 | # augment data set and transform to list (removing examples if necessary to stay below the max.
token count) 76 | train_aug_opts = dict(tp=True, rt=True, perm=True, shfl_ex=True, seed=0) 77 | train_dataset_augment = arc_eval_set.remove_test_data().repeat(n=48, seed=0).augment(**train_aug_opts) 78 | train_dataset_as_list = train_dataset_augment.as_list(len_name='text', **fmt_opts) 79 | 80 | # run test-time training 81 | FastLanguageModel.for_training(model) 82 | trainer = Trainer( 83 | model=model, 84 | tokenizer=tokenizer, 85 | train_dataset=Dataset.from_list(train_dataset_as_list), 86 | dataset_text_field="text", 87 | max_seq_length=fmt_opts['max_tokens'], 88 | data_collator=InputMaskingDataCollator( 89 | instruction_template=fmt_opts['query_beg'], 90 | response_template=fmt_opts['reply_beg'], 91 | mlm=False, 92 | tokenizer=tokenizer, 93 | mask_first_n_examples=0, 94 | ), 95 | args=TrainingArguments( 96 | per_device_train_batch_size=2, 97 | gradient_accumulation_steps=2, 98 | warmup_ratio=0.25, 99 | num_train_epochs=1, 100 | learning_rate=1e-4, 101 | embedding_learning_rate=1e-5, 102 | fp16=not is_bfloat16_supported(), 103 | bf16=is_bfloat16_supported(), 104 | logging_steps=10, 105 | optim="adamw_8bit", 106 | weight_decay=0.00, 107 | lr_scheduler_type='cosine', 108 | seed=42, 109 | output_dir='tmp_output', 110 | save_strategy='no', 111 | report_to='none', 112 | ), 113 | ) 114 | trainer_stats = unsloth_train(trainer) 115 | save_model_and_tokenizer(save_model_path, model, tokenizer) 116 | 117 | # run inference 118 | FastLanguageModel.for_inference(model) 119 | infer_aug_opts = dict(tp='all', rt='all', perm=True, shfl_ex=True, seed=10000) 120 | infer_dataset = arc_eval_set.repeat(2).augment(**infer_aug_opts) 121 | model_cache = Cache(inference_cache).memoize(typed=True, ignore=set(['model_tok', 'guess'])) 122 | eval_tool = EvalTool(n_guesses=2) 123 | inference_results = inference_run( 124 | model_tok=(model, tokenizer), 125 | fmt_opts=fmt_opts, 126 | dataset=infer_dataset, 127 | min_prob=0.1, 128 | aug_score_opts=infer_aug_opts, 129 | callback=eval_tool.process_result, 130 | cache=model_cache, 131 | ) 132 | 133 | # write submission 134 | with open(submission_file, 'w') as f: 135 | json.dump(arc_eval_set.get_submission(inference_results), f) 136 | with open(submission_file, 'r') as f: 137 | print(f"Score for '{submission_file}':", arc_eval_set.validate_submission(json.load(f))) 138 | -------------------------------------------------------------------------------- /training_code/run_evaluation_Llama-rearc_without_ttt.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Daniel Franzen and Jan Disselhoff 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import os 16 | import json 17 | from unsloth import FastLanguageModel 18 | from diskcache import Cache 19 | 20 | from arc_loader import ArcDataset 21 | from model_tools import load_unsloth_4bit 22 | from inference_tools import inference_run 23 | from selection import EvalTool 24 | from arc_downloader import download_arc_data 25 | 26 | # input paths 27 | base_model = 'da-fr/Llama-3.2-3B-ARChitects-ReArc-bnb-4bit' # auto-downloaded from huggingface.co 28 | arc_data_path = os.path.join('input', 'arc-prize-2024') # as on kaggle arc prize 2024 29 | download_arc_data(arc_data_path) 30 | 31 | # output paths 32 | output_path = 'output_evaluation_Llama-rearc_without_ttt' 33 | inference_cache = os.path.join(output_path, 'inference_cache') 34 | submission_file = os.path.join(output_path, 'submission.json') 35 | 36 | # load evaluation dataset 37 | arc_eval_set = ArcDataset.load_from_json(os.path.join(arc_data_path, 'arc-agi_evaluation_challenges.json')) 38 | arc_eval_set = arc_eval_set.load_solutions(os.path.join(arc_data_path, 'arc-agi_evaluation_solutions.json')) 39 | 40 | # load model 41 | model, tokenizer = load_unsloth_4bit(base_model) 42 | 43 | # set formatting options 44 | fmt_opts = dict( 45 | preprompt='ABCDEFGHJKLMNPQRSTUVWXYZabcdefghjklmnpqrstuvwxyz', 46 | query_beg='I', 47 | reply_beg='\n+/-=O', 48 | reply_end='\n' + tokenizer.eos_token, 49 | lines_sep='\n', 50 | max_tokens=128000, 51 | ) 52 | 53 | # run inference 54 | FastLanguageModel.for_inference(model) 55 | infer_aug_opts = dict(tp='all', rt='all', perm=True, shfl_ex=True, seed=10000) 56 | infer_dataset = arc_eval_set.repeat(2).augment(**infer_aug_opts) 57 | model_cache = Cache(inference_cache).memoize(typed=True, ignore=set(['model_tok', 'guess'])) 58 | eval_tool = EvalTool(n_guesses=2) 59 | inference_results = inference_run( 60 | model_tok=(model, tokenizer), 61 | fmt_opts=fmt_opts, 62 | dataset=infer_dataset, 63 | min_prob=0.1, 64 | aug_score_opts=infer_aug_opts, 65 | callback=eval_tool.process_result, 66 | cache=model_cache, 67 | ) 68 | 69 | # write submission 70 | with open(submission_file, 'w') as f: 71 | json.dump(arc_eval_set.get_submission(inference_results), f) 72 | with open(submission_file, 'r') as f: 73 | print(f"Score for '{submission_file}':", arc_eval_set.validate_submission(json.load(f))) 74 | -------------------------------------------------------------------------------- /training_code/run_finetuning_Llama-rearc.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Daniel Franzen and Jan Disselhoff 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import os 16 | from unsloth import FastLanguageModel 17 | from unsloth import UnslothTrainer as Trainer, unsloth_train, is_bfloat16_supported 18 | from unsloth import UnslothTrainingArguments as TrainingArguments 19 | from datasets import Dataset 20 | 21 | from arc_loader import ArcDataset 22 | from model_tools import InputMaskingDataCollator 23 | from model_tools import load_unsloth_4bit, keep_single_char_tokens, save_model_and_tokenizer 24 | from model_tools import load_peft_state, merge_peft_into_base 25 | from arc_downloader import download_arc_data 26 | 27 | # input paths 28 | base_model = 'chuanli11/Llama-3.2-3B-Instruct-uncensored' # auto-downloaded from huggingface.co 29 | re_arc_path = os.path.join('input', 're_arc') # https://github.com/michaelhodel/re-arc 30 | arc_data_path = os.path.join('input', 'arc-prize-2024') # as on kaggle arc prize 2024 31 | download_arc_data(arc_data_path) 32 | # output paths 33 | save_model_path = os.path.join('pretrained_models', "Llama-3.2-3B-ReArc") 34 | 35 | for action in ['train', 'merge']: 36 | # continue if task already accomplished 37 | if action == 'train' and os.path.exists(f'{save_model_path}-lora'): 38 | continue 39 | if action == 'merge' and os.path.exists(f'{save_model_path}-merged'): 40 | continue 41 | 42 | # load base model & reduce embedding size 43 | model = tokenizer = None # free memory 44 | model, tokenizer = load_unsloth_4bit(base_model) 45 | keep_tok = list('ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789!?.:,;*+/-=')+tokenizer.tokenize('\n') 46 | keep_single_char_tokens(model, tokenizer, keep=keep_tok, remove_unk=True) 47 | 48 | # set formatting options 49 | fmt_opts = dict( 50 | preprompt='ABCDEFGHJKLMNPQRSTUVWXYZabcdefghjklmnpqrstuvwxyz', 51 | query_beg='I', 52 | reply_beg='\n+/-=O', 53 | reply_end='\n' + tokenizer.eos_token, 54 | lines_sep='\n', 55 | max_tokens=128000, 56 | ) 57 | 58 | # create lora model 59 | lora_layers = ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj', 'embed_tokens', 'lm_head'] 60 | model = FastLanguageModel.get_peft_model( 61 | model=model, 62 | target_modules=lora_layers, 63 | r=256, 64 | lora_alpha=24, 65 | lora_dropout=0, 66 | bias="none", 67 | use_gradient_checkpointing=True, 68 | random_state=42, 69 | use_rslora=True, 70 | loftq_config=None, 71 | ) 72 | 73 | if action == 'train': 74 | # load training data 75 | train_dataset = ArcDataset.load_from_rearc(re_arc_path, n=368, sizes=[6], seed=42) 76 | 77 | # augment data set and transform to list (removing examples if necessary to stay below the max.
token count) 78 | train_aug_opts = dict(tp=True, rt=True, perm=True, shfl_ex=True, seed=0) 79 | train_dataset_augment = train_dataset.augment(**train_aug_opts) 80 | train_dataset_as_list = train_dataset_augment.as_list(len_name='text', **fmt_opts) 81 | 82 | 83 | # run training 84 | FastLanguageModel.for_training(model) 85 | tokenizer.padding_side = 'right' 86 | trainer = Trainer( 87 | model=model, 88 | tokenizer=tokenizer, 89 | train_dataset=Dataset.from_list(train_dataset_as_list), 90 | dataset_text_field="text", 91 | max_seq_length=fmt_opts['max_tokens'], 92 | packing=False, 93 | data_collator=InputMaskingDataCollator( 94 | instruction_template=fmt_opts['query_beg'], 95 | response_template=fmt_opts['reply_beg'], 96 | mlm=False, 97 | tokenizer=tokenizer, 98 | mask_first_n_examples=1, 99 | ), 100 | args=TrainingArguments( 101 | per_device_train_batch_size=4, 102 | gradient_accumulation_steps=2, 103 | warmup_ratio=0.25, 104 | num_train_epochs=1, 105 | learning_rate=1e-4, 106 | embedding_learning_rate=1e-5, 107 | fp16=not is_bfloat16_supported(), 108 | bf16=is_bfloat16_supported(), 109 | logging_steps=10, 110 | optim="adamw_8bit", 111 | weight_decay=0.00, 112 | lr_scheduler_type='cosine', 113 | seed=42, 114 | output_dir='tmp_output', 115 | save_strategy='no', 116 | report_to='none', 117 | ), 118 | ) 119 | trainer_stats = unsloth_train(trainer) 120 | save_model_and_tokenizer(f'{save_model_path}-lora', model, tokenizer) 121 | 122 | if action == 'merge': 123 | # load peft weights and merge 124 | load_peft_state(model, f'{save_model_path}-lora') 125 | model = merge_peft_into_base(model) 126 | save_model_and_tokenizer(f'{save_model_path}-merged', model, tokenizer) 127 | -------------------------------------------------------------------------------- /training_code/run_finetuning_Nemo-full.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Daniel Franzen and Jan Disselhoff 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import os 16 | from unsloth import FastLanguageModel 17 | from unsloth import UnslothTrainer as Trainer, unsloth_train, is_bfloat16_supported 18 | from unsloth import UnslothTrainingArguments as TrainingArguments 19 | from datasets import Dataset 20 | 21 | from arc_loader import ArcDataset 22 | from model_tools import InputMaskingDataCollator 23 | from model_tools import load_unsloth_4bit, keep_single_char_tokens, save_model_and_tokenizer 24 | from model_tools import load_peft_state, merge_peft_into_base 25 | from arc_downloader import download_arc_data 26 | 27 | # input paths 28 | base_model = 'nvidia/Mistral-NeMo-Minitron-8B-Base' # auto-downloaded from huggingface.co 29 | arc_data_path = os.path.join('input', 'arc-prize-2024') # as on kaggle arc prize 2024 30 | download_arc_data(arc_data_path) 31 | re_arc_path = os.path.join('input', 're_arc') # https://github.com/michaelhodel/re-arc 32 | neoneye_path = os.path.join('input', 'arc-dataset-collection') # https://github.com/neoneye/arc-dataset-collection 33 | 34 | # output paths 35 | save_model_path = os.path.join('pretrained_models', "Mistral-NeMo-Minitron-Full") 36 | 37 | for action in ['train', 'merge']: 38 | # continue if task already accomplished 39 | if action == 'train' and os.path.exists(f'{save_model_path}-lora'): 40 | continue 41 | if action == 'merge' and os.path.exists(f'{save_model_path}-merged'): 42 | continue 43 | 44 | # load base model & reduce embedding size 45 | model = tokenizer = None # free memory 46 | model, tokenizer = load_unsloth_4bit(base_model) 47 | keep_tok = list('ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789!?.:,;*+/-=')+tokenizer.tokenize('\n') 48 | keep_single_char_tokens(model, tokenizer, keep=keep_tok, remove_unk=True) 49 | 50 | # set formatting options 51 | fmt_opts = dict( 52 | preprompt='ABCDEFGHJKLMNPQRSTUVWXYZabcdefghjklmnpqrstuvwxyz', 53 | query_beg='I', 54 | reply_beg='\n+/-=O', 55 | reply_end='\n' + tokenizer.eos_token, 56 | lines_sep='\n', 57 | max_tokens=8192, 58 | ) 59 | 60 | # create lora model 61 | lora_layers = ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj', 'embed_tokens', 'lm_head'] 62 | model = FastLanguageModel.get_peft_model( 63 | model=model, 64 | target_modules=lora_layers, 65 | r=256, 66 | lora_alpha=24, 67 | lora_dropout=0, 68 | bias="none", 69 | use_gradient_checkpointing=True, 70 | random_state=42, 71 | use_rslora=True, 72 | loftq_config=None, 73 | ) 74 | 75 | if action == 'train': 76 | # load training data 77 | arc_eval_set = ArcDataset.load_from_json(os.path.join(arc_data_path, 'arc-agi_evaluation_challenges.json')) 78 | arc_eval_set = arc_eval_set.load_solutions(os.path.join(arc_data_path, 'arc-agi_evaluation_solutions.json')) 79 | concept_arc = ArcDataset.load_from_neoneye(os.path.join(neoneye_path, 'dataset', 'ConceptARC')) 80 | mix_datasets = { 81 | 'arceval': arc_eval_set.move_test_to_train().repeat(128), 82 | 'concept': concept_arc.move_test_to_train().repeat(128), 83 | } 84 | train_dataset = ArcDataset.load_from_rearc(re_arc_path, n=644, sizes=[6], seed=42, mix_datasets=mix_datasets) 85 | 86 | # augment data set and transform to list (removing examples if necessary to stay below the max.
token count) 87 | train_aug_opts = dict(tp=True, rt=True, perm=True, shfl_ex=True, seed=0) 88 | train_dataset_augment = train_dataset.augment(**train_aug_opts) 89 | train_dataset_as_list = train_dataset_augment.as_list(len_name='text', **fmt_opts) 90 | 91 | 92 | # run training 93 | FastLanguageModel.for_training(model) 94 | tokenizer.padding_side = 'right' 95 | trainer = Trainer( 96 | model=model, 97 | tokenizer=tokenizer, 98 | train_dataset=Dataset.from_list(train_dataset_as_list), 99 | dataset_text_field="text", 100 | max_seq_length=fmt_opts['max_tokens'], 101 | packing=False, 102 | data_collator=InputMaskingDataCollator( 103 | instruction_template=fmt_opts['query_beg'], 104 | response_template=fmt_opts['reply_beg'], 105 | mlm=False, 106 | tokenizer=tokenizer, 107 | mask_first_n_examples=1, 108 | ), 109 | args=TrainingArguments( 110 | per_device_train_batch_size=4, 111 | gradient_accumulation_steps=2, 112 | warmup_ratio=0.25, 113 | num_train_epochs=1, 114 | learning_rate=1e-4, 115 | embedding_learning_rate=1e-5, 116 | fp16=not is_bfloat16_supported(), 117 | bf16=is_bfloat16_supported(), 118 | logging_steps=10, 119 | optim="adamw_8bit", 120 | weight_decay=0.00, 121 | lr_scheduler_type='cosine', 122 | seed=42, 123 | output_dir='tmp_output', 124 | save_strategy='no', 125 | report_to='none', 126 | ), 127 | ) 128 | trainer_stats = unsloth_train(trainer) 129 | save_model_and_tokenizer(f'{save_model_path}-lora', model, tokenizer) 130 | 131 | if action == 'merge': 132 | # load peft weights and merge 133 | load_peft_state(model, f'{save_model_path}-lora') 134 | model = merge_peft_into_base(model) 135 | save_model_and_tokenizer(f'{save_model_path}-merged', model, tokenizer) 136 | -------------------------------------------------------------------------------- /training_code/selection.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Daniel Franzen and Jan Disselhoff 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import numpy as np 16 | 17 | 18 | def max_gen_prob(res): 19 | return min(res['scores_inf'].values()) 20 | 21 | 22 | def max_aug_prob(res): 23 | return min(res['scores_aug'].values()) 24 | 25 | 26 | def min_aug_prob(res): 27 | return max(res['scores_aug'].values()) 28 | 29 | 30 | def sum_aug_prob(res): 31 | scores = list(res['scores_aug'].values()) 32 | return sum([-np.exp(-s) for s in scores]) 33 | 34 | 35 | def mul_aug_prob(res, base_log_prob=3): 36 | scores = list(res['scores_aug'].values()) 37 | return sum([s - base_log_prob for s in scores]) 38 | 39 | 40 | def mul_all_prob(res, base_log_prob=3): 41 | scores = list(res['scores_aug'].values()) 42 | scores.extend(res['scores_inf'].values()) 43 | return sum([s - base_log_prob for s in scores]) 44 | 45 | 46 | all_score_algos = [ 47 | max_gen_prob, # highest probability from inference results 48 | max_aug_prob, # highest probability from augmented scoring 49 | min_aug_prob, # lowest probability from augmented scoring 50 | sum_aug_prob, # sum of probabilities from augmented scoring 51 | mul_aug_prob, # sum of log probabilities from augmented scoring 52 | mul_all_prob, # sum of log probabilities from inference results and augmented scoring combined 53 | ] 54 | 55 | 56 | class EvalTool(object): # providing on-the-fly evaluation of scoring algorithms 57 | def __init__(self, n_guesses, score_algos=all_score_algos, sorting_algo=-1): 58 | self.score_algos = score_algos 59 | self.n_guesses = n_guesses # number of guesses allowed 60 | self.sorting_algo = sorting_algo # sorting algorithm for results, relevant for final submission (default: last) 61 | self.n_acc = [0] * len(score_algos) # counting correct n-guesses for different scoring algorithms 62 | self.a_acc = 0 # counting cases where the solution is found at all 63 | self.count = 0 # counting number of tasks seen 64 | 65 | def process_result(self, res, name, value, print_func=print): 66 | for r in res: 67 | r['scores_alg'] = [algo(r) for algo in self.score_algos] 68 | pos = ([i for i, r in enumerate(res) if r['correct']] + [None])[0] 69 | self.count += value 70 | self.a_acc += value if pos is not None else 0 71 | corr_info = f"{len(res)} candidates, correct solution {'not found' if pos is None else 'FOUND'}" 72 | if print_func is not None: 73 | print_func(f" * task '{name}': {corr_info}") 74 | for i, algo in enumerate(self.score_algos): 75 | if pos is not None: 76 | scores = [r['scores_alg'][i] for r in res] 77 | rank = np.argsort(np.argsort(scores))[pos] 78 | if rank < self.n_guesses: 79 | self.n_acc[i] += value 80 | rank_info = f", corr_sol. @{rank + 1:>2} / {len(res)}" if pos is not None else '' 81 | n_acc_info = f"{self.n_acc[i] / self.count:7.2%} ({self.n_acc[i]:>6.2f}/{self.count:>6.2f})" 82 | if print_func is not None: 83 | print_func(f" {f'{self.score_algos[i].__name__}:':14} {n_acc_info}{rank_info}") 84 | a_acc_info = f"{self.a_acc / self.count:7.2%} ({self.a_acc:>6.2f}/{self.count:>6.2f})" 85 | if print_func is not None: 86 | print_func(f" {'correct_found:':14} {a_acc_info}\n") 87 | if self.sorting_algo is not None: 88 | res.sort(key=lambda x: x['scores_alg'][self.sorting_algo]) 89 | --------------------------------------------------------------------------------
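The scripts above all compose the same pipeline: arc_downloader.py fetches the public ARC-AGI data, ArcDataset wraps challenges and solutions and encodes every augmentation (transpose, rotation, color permutation, example shuffling) in the task key, and fmt_task renders a task into the flat character format the models consume. The following is a minimal usage sketch of that pipeline, not a file from this repository; the '<eos>' string is a placeholder for the tokenizer eos token that the run_* scripts splice into reply_end.

import os

from arc_downloader import download_arc_data
from arc_loader import ArcDataset

arc_data_path = os.path.join('input', 'arc-prize-2024')
download_arc_data(arc_data_path)  # returns early if the json files already exist

# load evaluation challenges and solutions, as in the run_evaluation_* scripts
dataset = ArcDataset.load_from_json(os.path.join(arc_data_path, 'arc-agi_evaluation_challenges.json'))
dataset = dataset.load_solutions(os.path.join(arc_data_path, 'arc-agi_evaluation_solutions.json'))

# draw one random augmented view per task; the augmentations are appended to the
# key string as suffixes (e.g. 'xxxxxxxx_0.tp.perm0213456789.ex2-0-1')
augmented = dataset.augment(seed=0, tp=True, rt=True, perm=True, shfl_ex=True)

# formatting options as in the scripts, minus the tokenizer-dependent eos token
fmt_opts = dict(
    preprompt='ABCDEFGHJKLMNPQRSTUVWXYZabcdefghjklmnpqrstuvwxyz',
    query_beg='I',
    reply_beg='\n+/-=O',
    reply_end='\n<eos>',  # placeholder; the scripts use '\n' + tokenizer.eos_token
    lines_sep='\n',
)

fmt = augmented.fmt_task(augmented.keys[0], **fmt_opts)
print(fmt['key'])         # key including augmentation suffixes
print(fmt['text'][:400])  # train examples + query grid + solution grid

Because the augmentations live only in the key string, a single ArcDataset instance can serve many distinct views of the same underlying grids without duplicating data, and decode() can invert the same transforms when checking model outputs against the stored solutions.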