├── .github
│   └── workflows
│       ├── deploy.yaml
│       └── test.yaml
├── .gitignore
├── LICENSE
├── MANIFEST.in
├── README.md
├── assets
│   └── treasure_trove.jpeg
├── examples
│   └── textbooks_A2YN
│       ├── gpt_labeling.py
│       ├── labeler_filtering.py
│       └── train_labeler.py
├── nbs
│   ├── 00_core.ipynb
│   ├── 02_tutorial.ipynb
│   ├── _quarto.yml
│   ├── index.ipynb
│   ├── nbdev.yml
│   └── styles.css
├── settings.ini
├── setup.py
└── treasure_trove
    ├── __init__.py
    ├── _modidx.py
    └── core.py

/.github/workflows/deploy.yaml:
--------------------------------------------------------------------------------
1 | name: Deploy to GitHub Pages
2 | 
3 | permissions:
4 |   contents: write
5 |   pages: write
6 | 
7 | on:
8 |   push:
9 |     branches: [ "main", "master" ]
10 |   workflow_dispatch:
11 | jobs:
12 |   deploy:
13 |     runs-on: ubuntu-latest
14 |     steps: [uses: fastai/workflows/quarto-ghp@master]
15 | 
--------------------------------------------------------------------------------
/.github/workflows/test.yaml:
--------------------------------------------------------------------------------
1 | name: CI
2 | on: [workflow_dispatch, pull_request, push]
3 | 
4 | jobs:
5 |   test:
6 |     runs-on: ubuntu-latest
7 |     steps: [uses: fastai/workflows/nbdev-ci@master]
8 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | data/
2 | logs/
3 | results/
4 | test-trainer/
5 | 
6 | _docs/
7 | _proc/
8 | 
9 | *.bak
10 | .gitattributes
11 | .last_checked
12 | .gitconfig
13 | 
14 | *.log
15 | *~
16 | ~*
17 | _tmp*
18 | tmp*
19 | tags
20 | *.pkg
21 | 
22 | # Byte-compiled / optimized / DLL files
23 | __pycache__/
24 | *.py[cod]
25 | *$py.class
26 | 
27 | # C extensions
28 | *.so
29 | 
30 | # Distribution / packaging
31 | .Python
32 | env/
33 | build/
34 | develop-eggs/
35 | dist/
36 | downloads/
37 | eggs/
38 | .eggs/
39 | lib/
40 | lib64/
41 | parts/
42 | sdist/
43 | var/
44 | wheels/
45 | *.egg-info/
46 | .installed.cfg
47 | *.egg
48 | 
49 | # PyInstaller
50 | # Usually these files are written by a python script from a template
51 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
52 | *.manifest
53 | *.spec
54 | 
55 | # Installer logs
56 | pip-log.txt
57 | pip-delete-this-directory.txt
58 | 
59 | # Unit test / coverage reports
60 | htmlcov/
61 | .tox/
62 | .coverage
63 | .coverage.*
64 | .cache
65 | nosetests.xml
66 | coverage.xml
67 | *.cover
68 | .hypothesis/
69 | 
70 | # Translations
71 | *.mo
72 | *.pot
73 | 
74 | # Django stuff:
75 | *.log
76 | local_settings.py
77 | 
78 | # Flask stuff:
79 | instance/
80 | .webassets-cache
81 | 
82 | # Scrapy stuff:
83 | .scrapy
84 | 
85 | # Sphinx documentation
86 | docs/_build/
87 | 
88 | # PyBuilder
89 | target/
90 | 
91 | # Jupyter Notebook
92 | .ipynb_checkpoints
93 | 
94 | # pyenv
95 | .python-version
96 | 
97 | # celery beat schedule file
98 | celerybeat-schedule
99 | 
100 | # SageMath parsed files
101 | *.sage.py
102 | 
103 | # dotenv
104 | .env
105 | 
106 | # virtualenv
107 | .venv
108 | venv/
109 | ENV/
110 | 
111 | # Spyder project settings
112 | .spyderproject
113 | .spyproject
114 | 
115 | # Rope project settings
116 | .ropeproject
117 | 
118 | # mkdocs documentation
119 | /site
120 | 
121 | # mypy
122 | .mypy_cache/
123 | 
124 | .vscode
125 | *.swp
126 | 
127 | # osx generated files
128 | .DS_Store
129 | .DS_Store?
130 | .Trashes 131 | ehthumbs.db 132 | Thumbs.db 133 | .idea 134 | 135 | # pytest 136 | .pytest_cache 137 | 138 | # tools/trust-doc-nbs 139 | docs_src/.last_checked 140 | 141 | # symlinks to fastai 142 | docs_src/fastai 143 | tools/fastai 144 | 145 | # link checker 146 | checklink/cookies.txt 147 | 148 | # .gitconfig is now autogenerated 149 | .gitconfig 150 | 151 | # Quarto installer 152 | .deb 153 | .pkg 154 | 155 | # Quarto 156 | .quarto 157 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2022, fastai 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | 
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include settings.ini
2 | include LICENSE
3 | include CONTRIBUTING.md
4 | include README.md
5 | recursive-exclude * __pycache__
6 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # treasure_trove
2 | 
3 | 
4 | 
5 | treasure_trove labels a sample of your data with an LLM, trains a lightweight classifier on those labels, and uses the classifier to filter the full dataset.
6 | 
7 | ## Install
8 | 
9 | ``` sh
10 | pip install git+https://github.com/CarperAI/treasure_trove
11 | ```
12 | 
13 | ## How to use
14 | 
15 | See `nbs/02_tutorial.ipynb` for the end-to-end flow: label a sample with an LLM labeler, train a classifier on those labels, and filter your dataset with it. A minimal smoke test of the installed package:
16 | 
17 | ``` python
18 | 1 + 1
19 | ```
20 | 
21 |     2
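22 | 
23 | The sketch below condenses `nbs/02_tutorial.ipynb` into one pass; the instruction text, API key, and sampling ratio are placeholders to adapt to your own data.
24 | 
25 | ``` python
26 | from datasets import load_dataset
27 | from squeakily.helpers import LLMLabeler
28 | from transformers import pipeline, TrainingArguments
29 | from treasure_trove.core import filter_dataset, label_dataset, train_labeler
30 | 
31 | instruction = "Label the following code as either educational or non-educational:"
32 | labels = ["educational", "non-educational"]
33 | labeler = LLMLabeler(instruction, labels, model_name="gpt-4", api_key="...")
34 | 
35 | ds = load_dataset("bigcode/the-stack-smol", data_dir="data/python")["train"]
36 | 
37 | # Label a small sample with the LLM, then train a classifier on those labels
38 | subset = label_dataset(ds, "content", labeler, labels, sample=0.001)
39 | model, tokenizer = train_labeler(
40 |     subset,
41 |     "content",
42 |     "bigcode/starencoder",
43 |     labels=labels,
44 |     training_args=TrainingArguments(output_dir="./code_edu", num_train_epochs=1),
45 |     max_length=512,
46 | )
47 | 
48 | # Filter the full dataset with the trained classifier
49 | pipe = pipeline("text-classification", model=model, tokenizer=tokenizer)
50 | filtered_ds = filter_dataset(ds, "content", pipe, ["educational"])
51 | ```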
52 | 
--------------------------------------------------------------------------------
/assets/treasure_trove.jpeg:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/CarperAI/treasure_trove/ff2ef6973c10d45f41c21e3b3fdbaae1a96e1a14/assets/treasure_trove.jpeg
--------------------------------------------------------------------------------
/examples/textbooks_A2YN/gpt_labeling.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | from datasets import concatenate_datasets, load_dataset
4 | from squeakily.helpers import LLMLabeler
5 | from treasure_trove.core import label_dataset
6 | 
7 | # Number of highly educational: 20
8 | # Number of medium educational: 93
9 | # Number of low educational: 7
10 | 
11 | # Number of high quality: 40
12 | # Number of medium quality: 73
13 | # Number of low quality: 7
14 | 
15 | # Number of highly educational: 14
16 | # Number of medium educational: 99
17 | # Number of low educational: 7
18 | 
19 | instruction = """You are a senior level software engineer and you are tasked with reviewing a given code snippet's educational value. Use the following guidelines for determining the code's educational value:
20 | Highly educational code has the following:
21 | * Readability: The code is written in a way that is easy to understand and follow, with consistent detailed comments, formatting, meaningful variable names, and appropriate code structure.
22 | * Modularity: The code is organized into reusable and independent modules or functions, making it easier to comprehend and reuse in other projects.
23 | * Detailed explanations: The code is accompanied by thorough explanations of the concepts and techniques used, providing learners with a deeper understanding of the underlying principles.
24 | * Good design principles: The code follows best practices for software design, such as encapsulation, separation of concerns, and adhering to design patterns, making it easier to understand and maintain.
25 | Medium educational code has the following:
26 | * Readability: The code is reasonably well-structured and readable, but there may be occasional inconsistencies, some comments, or less descriptive variable names.
27 | * Partial modularity: The code contains some reusable components, but not all parts of the code are organized into separate modules or functions.
28 | * Some explanations: The code may have limited explanations or comments that provide a general understanding of the code's logic and purpose.
29 | * Adequate design principles: The code follows basic design principles, such as separation of concerns, but may not fully adhere to advanced design patterns or best practices.
30 | Low educational code has the following:
31 | * Poor readability: The code is poorly structured and difficult to follow, with little to no comments, inconsistent formatting, and unclear variable names.
32 | * No modularity: The code is written in a monolithic style, lacking any organization into reusable or independent modules or functions.
33 | * Limited explanations: The code provides minimal or no explanations, leaving learners with little guidance on its logic or purpose.
34 | * Neglects design principles: The code shows a lack of consideration for design principles, making it harder to comprehend, maintain, and extend.
35 | * Boilerplate and autogenerated: The code contains a lot of boilerplate or autogenerated code, making it harder to comprehend and reuse.
36 | 
37 | Output nothing other than one of the following labels:
38 | """
39 | 
40 | labels = ["highly educational", "medium educational", "low educational"]
41 | api_key = os.environ["OPENAI_KEY"]
42 | labeler = LLMLabeler(instruction, labels, model_name="gpt-4", api_key=api_key)  # or model_name="gpt-3.5-turbo"
43 | 
44 | 
45 | languages = ["python", "go", "java", "javascript", "c", "cpp"]
46 | subsets = []
47 | for lang in languages:
48 |     print(f"Labeling {lang}...")
49 |     ds = load_dataset("CarperAI/starcoder_60k", data_dir=f"{lang}")["train"]
50 |     sample_ratio = 1.0
51 |     subset = label_dataset(ds, "cleaned_code", labeler, labels, sample=sample_ratio, num_workers=4)
52 |     # write to parquet
53 |     subset.to_parquet(f"data/{lang}.parquet")
54 |     subsets.append(subset)
55 | 
56 | labeled_ds = concatenate_datasets(subsets)
57 | 
58 | # upload to huggingface
59 | labeled_ds.push_to_hub("CarperAI/textbooks_A2YN_labeled_six_languages_60k", private=True)
60 | 
61 | # print number of each class
62 | print(f"Number of {labels[0]}: {len(labeled_ds.filter(lambda x: x['label'] == 0))}")
63 | print(f"Number of {labels[1]}: {len(labeled_ds.filter(lambda x: x['label'] == 1))}")
64 | print(f"Number of {labels[2]}: {len(labeled_ds.filter(lambda x: x['label'] == 2))}")
65 | 
--------------------------------------------------------------------------------
/examples/textbooks_A2YN/labeler_filtering.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset
2 | from transformers import pipeline, AutoTokenizer
3 | 
4 | MODEL_NAME = "CarperAI/code_edu_classifier_multi_lang"
5 | TOKENIZER_NAME = "bigcode/starencoder"
6 | 
7 | tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
8 | if tokenizer.pad_token is None:
9 |     tokenizer.pad_token = tokenizer.eos_token
10 | pipe = pipeline(
11 |     "text-classification", model=MODEL_NAME, tokenizer=tokenizer, device="cuda:0"
12 | )
13 | data_dir = ""
14 | languages = ["python", "java", "javascript", "go", "c", "cpp"]
15 | tokenizer_kwargs = {"padding": True, "truncation": True, "max_length": 1024}
16 | 
17 | def func(x):
18 |     labels = []
19 |     scores = []
20 |     for i in pipe(x["content"], batch_size=256, **tokenizer_kwargs):
21 |         labels.append(i["label"])
22 |         scores.append(i["score"])
23 |     return {"label": labels, "score": scores}
24 | 
25 | for lang in languages:
26 |     ds = load_dataset("parquet", data_dir=f"{data_dir}/{lang}", split="train")
27 |     print(f"Loaded {lang} dataset with {len(ds)} examples")
28 |     ds = ds.map(lambda x: func(x), batched=True, batch_size=256)
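29 |     # Optional post-filter sketch (hypothetical threshold -- tune for your data):
30 |     # keep only rows the classifier confidently marked "highly educational".
31 |     # ds = ds.filter(lambda x: x["label"] == "highly educational" and x["score"] > 0.9)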
32 |     ds.to_parquet(f"{data_dir}/{lang}_labeled.parquet")
--------------------------------------------------------------------------------
/examples/textbooks_A2YN/train_labeler.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset
2 | from transformers import TrainingArguments
3 | from treasure_trove.core import train_labeler
4 | 
5 | 
6 | ds = load_dataset("CarperAI/textbooks_A2YN_labeled_six_languages_60k")["train"]
7 | batch_size = 16
8 | training_args = TrainingArguments(
9 |     output_dir="./code_edu",
10 |     num_train_epochs=3,
11 |     per_device_train_batch_size=batch_size,
12 |     per_device_eval_batch_size=batch_size,
13 |     eval_accumulation_steps=2,
14 |     warmup_steps=500,
15 |     weight_decay=0.01,
16 |     logging_dir="./logs",
17 |     logging_steps=50,
18 |     evaluation_strategy="steps",
19 |     eval_steps=200,
20 |     save_strategy="epoch",
21 |     # load_best_model_at_end=True,
22 |     metric_for_best_model="accuracy",
23 |     greater_is_better=True,
24 |     seed=42,
25 |     push_to_hub=True,
26 |     hub_model_id="CarperAI/code_edu_classifier_multi_lang",
27 |     hub_private_repo=True,
28 | )
29 | base_model_name = "bigcode/starencoder"
30 | labels = ["highly educational", "medium educational", "low educational"]
31 | model, tokenizer = train_labeler(
32 |     ds,
33 |     "content",
34 |     base_model_name,
35 |     labels=labels,
36 |     training_args=training_args,
37 |     num_workers=64,
38 |     max_length=1024,
39 |     test_set_size=0.01,
40 |     push_to_hub=True,
41 | )
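42 | 
43 | # A hypothetical smoke test of the freshly trained labeler; this mirrors how
44 | # labeler_filtering.py consumes the pushed model:
45 | # from transformers import pipeline
46 | # pipe = pipeline("text-classification", model=model, tokenizer=tokenizer)
47 | # print(pipe("def add(a, b):\n    return a + b", truncation=True, max_length=1024))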
--------------------------------------------------------------------------------
/nbs/00_core.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "attachments": {},
5 |    "cell_type": "markdown",
6 |    "metadata": {},
7 |    "source": [
8 |     "# core\n",
9 |     "\n",
10 |     "> Core utilities for labeling data with an LLM, training a classifier on those labels, and filtering datasets"
11 |    ]
12 |   },
13 |   {
14 |    "cell_type": "code",
15 |    "execution_count": null,
16 |    "metadata": {},
17 |    "outputs": [],
18 |    "source": [
19 |     "# | default_exp core"
20 |    ]
21 |   },
22 |   {
23 |    "cell_type": "code",
24 |    "execution_count": null,
25 |    "metadata": {},
26 |    "outputs": [
27 |     {
28 |      "name": "stderr",
29 |      "output_type": "stream",
30 |      "text": [
31 |       "/admin/home-nathan/miniconda3/envs/trove/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
32 |       "  from .autonotebook import tqdm as notebook_tqdm\n"
33 |      ]
34 |     }
35 |    ],
36 |    "source": [
37 |     "# | export\n",
38 |     "import evaluate\n",
39 |     "import random\n",
40 |     "import time\n",
41 |     "\n",
42 |     "import numpy as np\n",
43 |     "\n",
44 |     "from transformers import (\n",
45 |     "    AutoModelForSequenceClassification,\n",
46 |     "    AutoTokenizer,\n",
47 |     "    DataCollatorWithPadding,\n",
48 |     "    Trainer,\n",
49 |     ")"
50 |    ]
51 |   },
52 |   {
53 |    "cell_type": "code",
54 |    "execution_count": null,
55 |    "metadata": {},
56 |    "outputs": [],
57 |    "source": [
58 |     "# | hide\n",
59 |     "from nbdev.showdoc import *"
60 |    ]
61 |   },
62 |   {
63 |    "cell_type": "code",
64 |    "execution_count": null,
65 |    "metadata": {},
66 |    "outputs": [],
67 |    "source": [
68 |     "# | export\n",
69 |     "\n",
70 |     "def classify(x, labels, llm_labeler, max_failures=5, default_label=0):\n",
71 |     "    \"\"\"Classify one text with `llm_labeler`, retrying up to `max_failures` times.\"\"\"\n",
72 |     "    # random sleep to avoid rate limiting\n",
73 |     "    time.sleep(random.randint(0, 5))\n",
74 |     "    failures = 0\n",
75 |     "    while failures < max_failures:\n",
76 |     "        try:\n",
77 |     "            label = labels.index(llm_labeler(x)[0])\n",
78 |     "            time.sleep(1)\n",
79 |     "            return label\n",
80 |     "        except Exception as e:\n",
81 |     "            failures += 1\n",
82 |     "            print(e)\n",
83 |     "            time.sleep(1)\n",
84 |     "\n",
85 |     "    # every attempt failed; fall back to the default label\n",
86 |     "    return default_label"
87 |    ]
88 |   },
89 |   {
90 |    "cell_type": "code",
91 |    "execution_count": null,
92 |    "metadata": {},
93 |    "outputs": [],
94 |    "source": [
95 |     "# | export\n",
96 |     "def label_dataset(\n",
97 |     "    dataset, text_column, labeler_model, labels, sample=0.1, num_workers=4, max_chars=4_096\n",
98 |     "):\n",
99 |     "    \"\"\"\n",
100 |     "    Labels a random sample of a dataset using a labeler model.\n",
101 |     "\n",
102 |     "    Args:\n",
103 |     "        dataset (datasets.Dataset): Dataset to label\n",
104 |     "        text_column (str): Name of the column containing the text to classify\n",
105 |     "        labeler_model (Any): Model to use for labeling\n",
106 |     "        labels (List[str]): List of labels\n",
107 |     "        sample (float): The fraction of the dataset to label\n",
108 |     "        num_workers (int): Number of workers for labeling\n",
109 |     "        max_chars (int): Maximum number of characters to truncate the text to\n",
110 |     "            before labeling (reduces rate limiting errors)\n",
111 |     "    \"\"\"\n",
112 |     "\n",
113 |     "    # Get a subset of the dataset\n",
114 |     "    subset = dataset.shuffle(seed=115).select(range(int(len(dataset) * sample)))\n",
115 |     "\n",
116 |     "    # Label the subset\n",
117 |     "    subset = subset.map(\n",
118 |     "        lambda x: {\"label\": classify(x[text_column][:max_chars], labels, labeler_model)},\n",
119 |     "        batched=False,\n",
120 |     "        num_proc=num_workers,\n",
121 |     "    )\n",
122 |     "\n",
123 |     "    return subset"
124 |    ]
125 |   },
126 |   {
127 |    "cell_type": "code",
128 |    "execution_count": null,
129 |    "metadata": {},
130 |    "outputs": [
131 |     {
132 |      "name": "stderr",
133 |      "output_type": "stream",
134 |      "text": [
135 |       "Using custom data configuration bigcode--the-stack-smol-8f8055c3a4e4b4e3\n",
136 |       "Found cached dataset json (/admin/home-nathan/.cache/huggingface/datasets/bigcode___json/bigcode--the-stack-smol-8f8055c3a4e4b4e3/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab)\n",
137 |       "100%|██████████| 1/1 [00:00<00:00, 2.87it/s]\n",
138 |       "Loading cached shuffled indices for dataset at 
/admin/home-nathan/.cache/huggingface/datasets/bigcode___json/bigcode--the-stack-smol-8f8055c3a4e4b4e3/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab/cache-feaf44b92e145e5a.arrow\n", 139 | "Loading cached processed dataset at /admin/home-nathan/.cache/huggingface/datasets/bigcode___json/bigcode--the-stack-smol-8f8055c3a4e4b4e3/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab/cache-17846c759c765b1d.arrow\n", 140 | "Loading cached processed dataset at /admin/home-nathan/.cache/huggingface/datasets/bigcode___json/bigcode--the-stack-smol-8f8055c3a4e4b4e3/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab/cache-8794fdd26ff8c584.arrow\n", 141 | "Loading cached processed dataset at /admin/home-nathan/.cache/huggingface/datasets/bigcode___json/bigcode--the-stack-smol-8f8055c3a4e4b4e3/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab/cache-6f9fd64122602836.arrow\n", 142 | "Loading cached processed dataset at /admin/home-nathan/.cache/huggingface/datasets/bigcode___json/bigcode--the-stack-smol-8f8055c3a4e4b4e3/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab/cache-548b3112b079028b.arrow\n" 143 | ] 144 | } 145 | ], 146 | "source": [ 147 | "from functools import partial\n", 148 | "from datasets import load_dataset\n", 149 | "\n", 150 | "\n", 151 | "def mock_labeler(x, labels):\n", 152 | " return [np.random.choice(labels, p=[0.25, 0.75])]\n", 153 | "\n", 154 | "\n", 155 | "labels = [\"positive\", \"negative\"]\n", 156 | "labeler = partial(mock_labeler, labels=labels)\n", 157 | "ds = load_dataset(\"bigcode/the-stack-smol\", data_dir=\"data/python\")[\"train\"]\n", 158 | "\n", 159 | "subset = label_dataset(ds, \"content\", labeler, labels, sample=0.1)\n", 160 | "\n", 161 | "assert \"label\" in subset.column_names" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": null, 167 | "metadata": {}, 168 | "outputs": [], 169 | "source": [ 170 | "# | export\n", 171 | "def train_labeler(\n", 172 | " dataset,\n", 173 | " text_column,\n", 174 | " base_model_name,\n", 175 | " labels,\n", 176 | " training_args,\n", 177 | " test_set_size=0.05,\n", 178 | " num_workers=4,\n", 179 | " max_length=512,\n", 180 | " push_to_hub=False,\n", 181 | "):\n", 182 | " \"\"\"\n", 183 | " Trains a labeler model on a labeled dataset.\n", 184 | "\n", 185 | " Args:\n", 186 | " dataset (datasets.Dataset): Dataset to train on\n", 187 | " text_column (str): Name of the text column\n", 188 | " base_model_name (str): Name of the base model to use\n", 189 | " labels (list): List of labels\n", 190 | " training_args (transformers.TrainingArguments): Training arguments\n", 191 | " test_set_size (float): Fraction of the dataset to use for testing\n", 192 | " num_workers (int): Number of workers for training\n", 193 | " max_length (int): Maximum length of the input\n", 194 | " \"\"\"\n", 195 | " # Load the tokenizer\n", 196 | " tokenizer = AutoTokenizer.from_pretrained(base_model_name, max_length=max_length)\n", 197 | " if tokenizer.pad_token is None:\n", 198 | " tokenizer.pad_token = tokenizer.eos_token\n", 199 | "\n", 200 | " # Load the model\n", 201 | " model = AutoModelForSequenceClassification.from_pretrained(\n", 202 | " base_model_name, num_labels=len(labels), max_length=max_length\n", 203 | " )\n", 204 | " model.config.id2label = {i: label for i, label in enumerate(labels)}\n", 205 | "\n", 206 | " # Preprocess the dataset\n", 207 | " dataset = dataset.map(\n", 208 | " lambda x: tokenizer(\n", 209 | " 
x[text_column], padding=\"max_length\", truncation=True, max_length=max_length\n",
210 |     "        ),\n",
211 |     "        batched=True,\n",
212 |     "        num_proc=num_workers,\n",
213 |     "    )\n",
214 |     "\n",
215 |     "    # Split the dataset\n",
216 |     "    dataset = dataset.train_test_split(test_size=test_set_size, seed=115)\n",
217 |     "\n",
218 |     "    # Get the data collator\n",
219 |     "    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)\n",
220 |     "\n",
221 |     "    def compute_metrics(eval_preds):\n",
222 |     "        acc_metric = evaluate.load(\"accuracy\")\n",
223 |     "        precision_metric = evaluate.load(\"precision\")\n",
224 |     "        recall_metric = evaluate.load(\"recall\")\n",
225 |     "        f1_metric = evaluate.load(\"f1\")\n",
226 |     "        logits, refs = eval_preds  # unpack as `refs` so we don't shadow `labels` (the class names)\n",
227 |     "        if isinstance(logits, tuple):  # Some models return tuples\n",
228 |     "            logits = logits[0]\n",
229 |     "\n",
230 |     "        predictions = np.argmax(logits, axis=-1)\n",
231 |     "        acc = acc_metric.compute(predictions=predictions, references=refs)\n",
232 |     "        precision = precision_metric.compute(predictions=predictions, references=refs, average=\"macro\" if len(labels) > 2 else \"binary\")\n",
233 |     "        recall = recall_metric.compute(predictions=predictions, references=refs, average=\"macro\" if len(labels) > 2 else \"binary\")\n",
234 |     "        f1 = f1_metric.compute(predictions=predictions, references=refs, average=\"macro\" if len(labels) > 2 else \"binary\")\n",
235 |     "        return {**acc, **precision, **recall, **f1}\n",
236 |     "\n",
237 |     "    # Get the trainer\n",
238 |     "    trainer = Trainer(\n",
239 |     "        model=model,\n",
240 |     "        args=training_args,\n",
241 |     "        train_dataset=dataset[\"train\"],\n",
242 |     "        eval_dataset=dataset[\"test\"],\n",
243 |     "        data_collator=data_collator,\n",
244 |     "        compute_metrics=compute_metrics,\n",
245 |     "    )\n",
246 |     "\n",
247 |     "    # Train the model\n",
248 |     "    trainer.train()\n",
249 |     "\n",
250 |     "    # Push the model to the hub\n",
251 |     "    if push_to_hub:\n",
252 |     "        trainer.push_to_hub()\n",
253 |     "\n",
254 |     "    # Return the model\n",
255 |     "    return model, tokenizer"
256 |    ]
257 |   },
258 |   {
259 |    "cell_type": "code",
260 |    "execution_count": null,
261 |    "metadata": {},
262 |    "outputs": [
263 |     {
264 |      "name": "stderr",
265 |      "output_type": "stream",
266 |      "text": [
267 |       "/admin/home-nathan/miniconda3/envs/trove/lib/python3.11/site-packages/torch/cuda/__init__.py:546: UserWarning: Can't initialize NVML\n",
268 |       "  warnings.warn(\"Can't initialize NVML\")\n",
269 |       "Some weights of the model checkpoint at prajjwal1/bert-small were not used when initializing BertForSequenceClassification: ['cls.predictions.decoder.bias', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight']\n",
270 |       "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", 271 | "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", 272 | "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at prajjwal1/bert-small and are newly initialized: ['classifier.bias', 'classifier.weight']\n", 273 | "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", 274 | "#0: 0%| | 0/3 [00:00╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮\n", 299 | " in <module>:19 \n", 300 | " \n", 301 | " 16 seed=115, \n", 302 | " 17 push_to_hub=False \n", 303 | " 18 ) \n", 304 | " 19 model, tokenizer = train_labeler( \n", 305 | " 20 ds, \n", 306 | " 21 \"content\", \n", 307 | " 22 base_model_name, \n", 308 | " \n", 309 | " in train_labeler:79 \n", 310 | " \n", 311 | " 76 ) \n", 312 | " 77 \n", 313 | " 78 # Train the model \n", 314 | " 79 trainer.train() \n", 315 | " 80 \n", 316 | " 81 # Push the model to the hub \n", 317 | " 82 if push_to_hub: \n", 318 | " \n", 319 | " /admin/home-nathan/miniconda3/envs/trove/lib/python3.11/site-packages/transformers/trainer.py:16 \n", 320 | " 45 in train \n", 321 | " \n", 322 | " 1642 │ │ inner_training_loop = find_executable_batch_size( \n", 323 | " 1643 │ │ │ self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size \n", 324 | " 1644 │ │ ) \n", 325 | " 1645 │ │ return inner_training_loop( \n", 326 | " 1646 │ │ │ args=args, \n", 327 | " 1647 │ │ │ resume_from_checkpoint=resume_from_checkpoint, \n", 328 | " 1648 │ │ │ trial=trial, \n", 329 | " \n", 330 | " /admin/home-nathan/miniconda3/envs/trove/lib/python3.11/site-packages/transformers/trainer.py:19 \n", 331 | " 38 in _inner_training_loop \n", 332 | " \n", 333 | " 1935 │ │ │ │ │ self.control = self.callback_handler.on_step_begin(args, self.state, \n", 334 | " 1936 │ │ │ │ \n", 335 | " 1937 │ │ │ │ with self.accelerator.accumulate(model): \n", 336 | " 1938 │ │ │ │ │ tr_loss_step = self.training_step(model, inputs) \n", 337 | " 1939 │ │ │ │ \n", 338 | " 1940 │ │ │ │ if ( \n", 339 | " 1941 │ │ │ │ │ args.logging_nan_inf_filter \n", 340 | " \n", 341 | " /admin/home-nathan/miniconda3/envs/trove/lib/python3.11/site-packages/transformers/trainer.py:27 \n", 342 | " 59 in training_step \n", 343 | " \n", 344 | " 2756 │ │ │ return loss_mb.reduce_mean().detach().to(self.args.device) \n", 345 | " 2757 │ │ \n", 346 | " 2758 │ │ with self.compute_loss_context_manager(): \n", 347 | " 2759 │ │ │ loss = self.compute_loss(model, inputs) \n", 348 | " 2760 │ │ \n", 349 | " 2761 │ │ if self.args.n_gpu > 1: \n", 350 | " 2762 │ │ │ loss = loss.mean() # mean() to average on multi-gpu parallel training \n", 351 | " \n", 352 | " /admin/home-nathan/miniconda3/envs/trove/lib/python3.11/site-packages/transformers/trainer.py:27 \n", 353 | " 97 in compute_loss \n", 354 | " \n", 355 | " 2794 │ │ │ │ loss = self.label_smoother(outputs, labels) \n", 356 | " 2795 │ │ else: \n", 357 | " 2796 │ │ │ if isinstance(outputs, dict) and \"loss\" not in outputs: \n", 358 | " 2797 │ │ │ │ raise ValueError( \n", 359 | " 2798 │ │ │ │ │ \"The model did not return a loss from the inputs, only the following \n", 360 | " 2799 │ │ │ │ │ f\"{','.join(outputs.keys())}. 
For reference, the inputs it received \n", 361 | " 2800 │ │ │ │ ) \n", 362 | "╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\n", 363 | "ValueError: The model did not return a loss from the inputs, only the following keys: logits. For reference, the \n", 364 | "inputs it received are input_ids,token_type_ids,attention_mask.\n", 365 | "\n" 366 | ], 367 | "text/plain": [ 368 | "\u001b[31m╭─\u001b[0m\u001b[31m──────────────────────────────\u001b[0m\u001b[31m \u001b[0m\u001b[1;31mTraceback \u001b[0m\u001b[1;2;31m(most recent call last)\u001b[0m\u001b[31m \u001b[0m\u001b[31m───────────────────────────────\u001b[0m\u001b[31m─╮\u001b[0m\n", 369 | "\u001b[31m│\u001b[0m in \u001b[92m\u001b[0m:\u001b[94m19\u001b[0m \u001b[31m│\u001b[0m\n", 370 | "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", 371 | "\u001b[31m│\u001b[0m \u001b[2m16 \u001b[0m\u001b[2m│ \u001b[0mseed=\u001b[94m115\u001b[0m, \u001b[31m│\u001b[0m\n", 372 | "\u001b[31m│\u001b[0m \u001b[2m17 \u001b[0m\u001b[2m│ \u001b[0mpush_to_hub=\u001b[94mFalse\u001b[0m \u001b[31m│\u001b[0m\n", 373 | "\u001b[31m│\u001b[0m \u001b[2m18 \u001b[0m) \u001b[31m│\u001b[0m\n", 374 | "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m19 model, tokenizer = train_labeler( \u001b[31m│\u001b[0m\n", 375 | "\u001b[31m│\u001b[0m \u001b[2m20 \u001b[0m\u001b[2m│ \u001b[0mds, \u001b[31m│\u001b[0m\n", 376 | "\u001b[31m│\u001b[0m \u001b[2m21 \u001b[0m\u001b[2m│ \u001b[0m\u001b[33m\"\u001b[0m\u001b[33mcontent\u001b[0m\u001b[33m\"\u001b[0m, \u001b[31m│\u001b[0m\n", 377 | "\u001b[31m│\u001b[0m \u001b[2m22 \u001b[0m\u001b[2m│ \u001b[0mbase_model_name, \u001b[31m│\u001b[0m\n", 378 | "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", 379 | "\u001b[31m│\u001b[0m in \u001b[92mtrain_labeler\u001b[0m:\u001b[94m79\u001b[0m \u001b[31m│\u001b[0m\n", 380 | "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", 381 | "\u001b[31m│\u001b[0m \u001b[2m76 \u001b[0m\u001b[2m│ \u001b[0m) \u001b[31m│\u001b[0m\n", 382 | "\u001b[31m│\u001b[0m \u001b[2m77 \u001b[0m\u001b[2m│ \u001b[0m \u001b[31m│\u001b[0m\n", 383 | "\u001b[31m│\u001b[0m \u001b[2m78 \u001b[0m\u001b[2m│ \u001b[0m\u001b[2m# Train the model\u001b[0m \u001b[31m│\u001b[0m\n", 384 | "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m79 \u001b[2m│ \u001b[0mtrainer.train() \u001b[31m│\u001b[0m\n", 385 | "\u001b[31m│\u001b[0m \u001b[2m80 \u001b[0m\u001b[2m│ \u001b[0m \u001b[31m│\u001b[0m\n", 386 | "\u001b[31m│\u001b[0m \u001b[2m81 \u001b[0m\u001b[2m│ \u001b[0m\u001b[2m# Push the model to the hub\u001b[0m \u001b[31m│\u001b[0m\n", 387 | "\u001b[31m│\u001b[0m \u001b[2m82 \u001b[0m\u001b[2m│ \u001b[0m\u001b[94mif\u001b[0m push_to_hub: \u001b[31m│\u001b[0m\n", 388 | "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", 389 | "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-nathan/miniconda3/envs/trove/lib/python3.11/site-packages/transformers/\u001b[0m\u001b[1;33mtrainer.py\u001b[0m:\u001b[94m16\u001b[0m \u001b[31m│\u001b[0m\n", 390 | "\u001b[31m│\u001b[0m \u001b[94m45\u001b[0m in \u001b[92mtrain\u001b[0m \u001b[31m│\u001b[0m\n", 391 | "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", 392 | "\u001b[31m│\u001b[0m \u001b[2m1642 \u001b[0m\u001b[2m│ │ \u001b[0minner_training_loop = find_executable_batch_size( \u001b[31m│\u001b[0m\n", 393 | "\u001b[31m│\u001b[0m \u001b[2m1643 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[96mself\u001b[0m._inner_training_loop, \u001b[96mself\u001b[0m._train_batch_size, args.auto_find_batch_size \u001b[31m│\u001b[0m\n", 394 | "\u001b[31m│\u001b[0m \u001b[2m1644 \u001b[0m\u001b[2m│ │ \u001b[0m) 
\u001b[31m│\u001b[0m\n", 395 | "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m1645 \u001b[2m│ │ \u001b[0m\u001b[94mreturn\u001b[0m inner_training_loop( \u001b[31m│\u001b[0m\n", 396 | "\u001b[31m│\u001b[0m \u001b[2m1646 \u001b[0m\u001b[2m│ │ │ \u001b[0margs=args, \u001b[31m│\u001b[0m\n", 397 | "\u001b[31m│\u001b[0m \u001b[2m1647 \u001b[0m\u001b[2m│ │ │ \u001b[0mresume_from_checkpoint=resume_from_checkpoint, \u001b[31m│\u001b[0m\n", 398 | "\u001b[31m│\u001b[0m \u001b[2m1648 \u001b[0m\u001b[2m│ │ │ \u001b[0mtrial=trial, \u001b[31m│\u001b[0m\n", 399 | "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", 400 | "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-nathan/miniconda3/envs/trove/lib/python3.11/site-packages/transformers/\u001b[0m\u001b[1;33mtrainer.py\u001b[0m:\u001b[94m19\u001b[0m \u001b[31m│\u001b[0m\n", 401 | "\u001b[31m│\u001b[0m \u001b[94m38\u001b[0m in \u001b[92m_inner_training_loop\u001b[0m \u001b[31m│\u001b[0m\n", 402 | "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", 403 | "\u001b[31m│\u001b[0m \u001b[2m1935 \u001b[0m\u001b[2m│ │ │ │ │ \u001b[0m\u001b[96mself\u001b[0m.control = \u001b[96mself\u001b[0m.callback_handler.on_step_begin(args, \u001b[96mself\u001b[0m.state, \u001b[31m│\u001b[0m\n", 404 | "\u001b[31m│\u001b[0m \u001b[2m1936 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m \u001b[31m│\u001b[0m\n", 405 | "\u001b[31m│\u001b[0m \u001b[2m1937 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[94mwith\u001b[0m \u001b[96mself\u001b[0m.accelerator.accumulate(model): \u001b[31m│\u001b[0m\n", 406 | "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m1938 \u001b[2m│ │ │ │ │ \u001b[0mtr_loss_step = \u001b[96mself\u001b[0m.training_step(model, inputs) \u001b[31m│\u001b[0m\n", 407 | "\u001b[31m│\u001b[0m \u001b[2m1939 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m \u001b[31m│\u001b[0m\n", 408 | "\u001b[31m│\u001b[0m \u001b[2m1940 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[94mif\u001b[0m ( \u001b[31m│\u001b[0m\n", 409 | "\u001b[31m│\u001b[0m \u001b[2m1941 \u001b[0m\u001b[2m│ │ │ │ │ \u001b[0margs.logging_nan_inf_filter \u001b[31m│\u001b[0m\n", 410 | "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", 411 | "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-nathan/miniconda3/envs/trove/lib/python3.11/site-packages/transformers/\u001b[0m\u001b[1;33mtrainer.py\u001b[0m:\u001b[94m27\u001b[0m \u001b[31m│\u001b[0m\n", 412 | "\u001b[31m│\u001b[0m \u001b[94m59\u001b[0m in \u001b[92mtraining_step\u001b[0m \u001b[31m│\u001b[0m\n", 413 | "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", 414 | "\u001b[31m│\u001b[0m \u001b[2m2756 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[94mreturn\u001b[0m loss_mb.reduce_mean().detach().to(\u001b[96mself\u001b[0m.args.device) \u001b[31m│\u001b[0m\n", 415 | "\u001b[31m│\u001b[0m \u001b[2m2757 \u001b[0m\u001b[2m│ │ \u001b[0m \u001b[31m│\u001b[0m\n", 416 | "\u001b[31m│\u001b[0m \u001b[2m2758 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mwith\u001b[0m \u001b[96mself\u001b[0m.compute_loss_context_manager(): \u001b[31m│\u001b[0m\n", 417 | "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m2759 \u001b[2m│ │ │ \u001b[0mloss = \u001b[96mself\u001b[0m.compute_loss(model, inputs) \u001b[31m│\u001b[0m\n", 418 | "\u001b[31m│\u001b[0m \u001b[2m2760 \u001b[0m\u001b[2m│ │ \u001b[0m \u001b[31m│\u001b[0m\n", 419 | "\u001b[31m│\u001b[0m \u001b[2m2761 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mif\u001b[0m \u001b[96mself\u001b[0m.args.n_gpu > \u001b[94m1\u001b[0m: \u001b[31m│\u001b[0m\n", 420 | "\u001b[31m│\u001b[0m \u001b[2m2762 \u001b[0m\u001b[2m│ │ │ \u001b[0mloss = loss.mean() \u001b[2m# mean() to average on multi-gpu parallel training\u001b[0m 
\u001b[31m│\u001b[0m\n", 421 | "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", 422 | "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-nathan/miniconda3/envs/trove/lib/python3.11/site-packages/transformers/\u001b[0m\u001b[1;33mtrainer.py\u001b[0m:\u001b[94m27\u001b[0m \u001b[31m│\u001b[0m\n", 423 | "\u001b[31m│\u001b[0m \u001b[94m97\u001b[0m in \u001b[92mcompute_loss\u001b[0m \u001b[31m│\u001b[0m\n", 424 | "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", 425 | "\u001b[31m│\u001b[0m \u001b[2m2794 \u001b[0m\u001b[2m│ │ │ │ \u001b[0mloss = \u001b[96mself\u001b[0m.label_smoother(outputs, labels) \u001b[31m│\u001b[0m\n", 426 | "\u001b[31m│\u001b[0m \u001b[2m2795 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94melse\u001b[0m: \u001b[31m│\u001b[0m\n", 427 | "\u001b[31m│\u001b[0m \u001b[2m2796 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[94mif\u001b[0m \u001b[96misinstance\u001b[0m(outputs, \u001b[96mdict\u001b[0m) \u001b[95mand\u001b[0m \u001b[33m\"\u001b[0m\u001b[33mloss\u001b[0m\u001b[33m\"\u001b[0m \u001b[95mnot\u001b[0m \u001b[95min\u001b[0m outputs: \u001b[31m│\u001b[0m\n", 428 | "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m2797 \u001b[2m│ │ │ │ \u001b[0m\u001b[94mraise\u001b[0m \u001b[96mValueError\u001b[0m( \u001b[31m│\u001b[0m\n", 429 | "\u001b[31m│\u001b[0m \u001b[2m2798 \u001b[0m\u001b[2m│ │ │ │ │ \u001b[0m\u001b[33m\"\u001b[0m\u001b[33mThe model did not return a loss from the inputs, only the following\u001b[0m \u001b[31m│\u001b[0m\n", 430 | "\u001b[31m│\u001b[0m \u001b[2m2799 \u001b[0m\u001b[2m│ │ │ │ │ \u001b[0m\u001b[33mf\u001b[0m\u001b[33m\"\u001b[0m\u001b[33m{\u001b[0m\u001b[33m'\u001b[0m\u001b[33m,\u001b[0m\u001b[33m'\u001b[0m.join(outputs.keys())\u001b[33m}\u001b[0m\u001b[33m. For reference, the inputs it received \u001b[0m \u001b[31m│\u001b[0m\n", 431 | "\u001b[31m│\u001b[0m \u001b[2m2800 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m) \u001b[31m│\u001b[0m\n", 432 | "\u001b[31m╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n", 433 | "\u001b[1;91mValueError: \u001b[0mThe model did not return a loss from the inputs, only the following keys: logits. 
For reference, the \n", 434 | "inputs it received are input_ids,token_type_ids,attention_mask.\n" 435 | ] 436 | }, 437 | "metadata": {}, 438 | "output_type": "display_data" 439 | } 440 | ], 441 | "source": [ 442 | "# from transformers import TrainingArguments\n", 443 | "\n", 444 | "# base_model_name = \"prajjwal1/bert-small\"\n", 445 | "# batch_size = 4\n", 446 | "# training_args = TrainingArguments(\n", 447 | "# output_dir=\"./data\",\n", 448 | "# num_train_epochs=1,\n", 449 | "# per_device_train_batch_size=batch_size,\n", 450 | "# per_device_eval_batch_size=batch_size,\n", 451 | "# logging_dir=\"./logs\",\n", 452 | "# logging_steps=50,\n", 453 | "# evaluation_strategy=\"epoch\",\n", 454 | "# save_strategy=\"epoch\",\n", 455 | "# metric_for_best_model=\"accuracy\",\n", 456 | "# greater_is_better=True,\n", 457 | "# seed=115,\n", 458 | "# push_to_hub=False\n", 459 | "# )\n", 460 | "# model, tokenizer = train_labeler(\n", 461 | "# ds,\n", 462 | "# \"content\",\n", 463 | "# base_model_name,\n", 464 | "# labels=labels,\n", 465 | "# training_args=training_args,\n", 466 | "# )\n", 467 | "# assert type(model) == AutoModelForSequenceClassification" 468 | ] 469 | }, 470 | { 471 | "cell_type": "code", 472 | "execution_count": null, 473 | "metadata": {}, 474 | "outputs": [], 475 | "source": [ 476 | "# | export\n", 477 | "def filter_dataset(\n", 478 | " dataset, text_column, labeler_model, labels_to_keep, batch_size=32, num_workers=4\n", 479 | "):\n", 480 | " \"\"\"\n", 481 | " Filters a dataset using a labeler model.\n", 482 | "\n", 483 | " Args:\n", 484 | " dataset (datasets.Dataset): Dataset to filter\n", 485 | " text_column (str): Name of the text column\n", 486 | " labeler_model (transformers.pipelines.TextClassificationPipeline): Model to use for labeling\n", 487 | " labels_to_keep (list): List of labels to keep\n", 488 | " batch_size (int): Batch size for labeling\n", 489 | " num_workers (int): Number of workers for labeling\n", 490 | " \"\"\"\n", 491 | "\n", 492 | " def label(x):\n", 493 | " predicted = labeler_model(x, padding=True, truncation=True, max_length=512)\n", 494 | " return {\n", 495 | " \"label\": [l[\"label\"] for l in predicted],\n", 496 | " \"score\": [l[\"score\"] for l in predicted],\n", 497 | " }\n", 498 | "\n", 499 | "\n", 500 | " # TODO: first just label the dataset with scores and everything\n", 501 | " # then just split the dataset into the number of subsets and configs so that people can specify which one they want\n", 502 | "\n", 503 | " # Label the dataset\n", 504 | " dataset = dataset.map(\n", 505 | " lambda x: label(x[text_column]),\n", 506 | " batched=True,\n", 507 | " batch_size=batch_size,\n", 508 | " num_proc=num_workers,\n", 509 | " )\n", 510 | "\n", 511 | " # Filter the dataset\n", 512 | " dataset = dataset.filter(lambda x: x[\"label\"] in labels_to_keep)\n", 513 | "\n", 514 | " return dataset" 515 | ] 516 | }, 517 | { 518 | "cell_type": "code", 519 | "execution_count": null, 520 | "metadata": {}, 521 | "outputs": [ 522 | { 523 | "data": { 524 | "application/vnd.jupyter.widget-view+json": { 525 | "model_id": "c14fd4c3288947358f8b9e01c4a50655", 526 | "version_major": 2, 527 | "version_minor": 0 528 | }, 529 | "text/plain": [ 530 | " 0%| | 0/10 [00:00 Find the treasure in your trove of data" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "#| eval: false\n", 31 | "from datasets import load_dataset\n", 32 | "from squeakily.helpers import LLMLabeler\n", 33 | "from 
transformers import pipeline, TrainingArguments\n",
34 |     "from treasure_trove.core import filter_dataset, label_dataset, train_labeler\n",
35 |     "\n",
36 |     "instruction = \"\"\"Please label the following code as either educational or non-educational.\n",
37 |     "Educational code is code that is well written, follows best practices, and has documentation such that it might be found in a textbook.\n",
38 |     "Non-educational code is code that is poorly written, lacks documentation, contains bugs, or is not idiomatic.\n",
39 |     "Labels:\n",
40 |     "\"\"\"\n",
41 |     "labels = [\"educational\", \"non-educational\"]\n",
42 |     "api_key = \"\"\n",
43 |     "labeler = LLMLabeler(instruction, labels, model_name=\"gpt-4\", api_key=api_key)\n",
44 |     "\n",
45 |     "ds = load_dataset(\"bigcode/the-stack-smol\", data_dir=\"data/python\")[\"train\"]\n",
46 |     "\n",
47 |     "# Get the training arguments\n",
48 |     "batch_size = 4\n",
49 |     "training_args = TrainingArguments(\n",
50 |     "    output_dir=\"./code_edu\",\n",
51 |     "    num_train_epochs=1,\n",
52 |     "    per_device_train_batch_size=batch_size,\n",
53 |     "    per_device_eval_batch_size=batch_size,\n",
54 |     "    warmup_steps=500,\n",
55 |     "    weight_decay=0.01,\n",
56 |     "    logging_dir=\"./logs\",\n",
57 |     "    logging_steps=10,\n",
58 |     "    evaluation_strategy=\"epoch\",\n",
59 |     "    save_strategy=\"epoch\",\n",
60 |     "    load_best_model_at_end=True,\n",
61 |     "    metric_for_best_model=\"accuracy\",\n",
62 |     "    greater_is_better=True,\n",
63 |     "    seed=42,\n",
64 |     "    push_to_hub=True,\n",
65 |     ")"
66 |    ]
67 |   },
68 |   {
69 |    "cell_type": "code",
70 |    "execution_count": null,
71 |    "metadata": {},
72 |    "outputs": [],
73 |    "source": [
74 |     "#| eval: false\n",
75 |     "subset = label_dataset(ds, \"content\", labeler, labels, sample=0.001)\n",
76 |     "base_model_name = \"bigcode/starencoder\"\n",
77 |     "model, tokenizer = train_labeler(\n",
78 |     "    subset,\n",
79 |     "    \"content\",\n",
80 |     "    base_model_name,\n",
81 |     "    labels=labels,\n",
82 |     "    training_args=training_args,\n",
83 |     "    num_workers=4,\n",
84 |     "    max_length=512,\n",
85 |     "    push_to_hub=True,\n",
86 |     ")\n",
87 |     "pipe = pipeline(\n",
88 |     "    \"text-classification\", model=model, tokenizer=tokenizer, device=model.device\n",
89 |     ")\n",
90 |     "filtered_ds = filter_dataset(ds, \"content\", pipe, [\"educational\"])\n",
91 |     "filtered_ds.push_to_hub(\"ncoop57/code_edu\")"
92 |    ]
93 |   }
94 |  ],
95 |  "metadata": {
96 |   "kernelspec": {
97 |    "display_name": "python3",
98 |    "language": "python",
99 |    "name": "python3"
100 |   }
101 |  },
102 |  "nbformat": 4,
103 |  "nbformat_minor": 4
104 | }
105 | 
--------------------------------------------------------------------------------
/nbs/_quarto.yml:
--------------------------------------------------------------------------------
1 | project:
2 |   type: website
3 | 
4 | format:
5 |   html:
6 |     theme: cosmo
7 |     css: styles.css
8 |     toc: true
9 | 
10 | website:
11 |   twitter-card: true
12 |   open-graph: true
13 |   repo-actions: [issue]
14 |   navbar:
15 |     background: primary
16 |     search: true
17 |   sidebar:
18 |     style: floating
19 | 
20 | metadata-files: [nbdev.yml, sidebar.yml]
--------------------------------------------------------------------------------
/nbs/index.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count": null,
6 |    "metadata": {},
7 |    "outputs": [],
8 |    "source": [
9 |     "# | hide\n",
10 |     "from treasure_trove.core import *"
11 |    ]
12 |   },
13 |   {
14 |    "cell_type": "markdown",
15 |    "metadata": {},
16 |    "source": [
17 |     "# treasure_trove\n",
18 |     "\n",
19 |     "![Find the treasure in your trove of data](assets/treasure_trove.jpeg)"
20 |    ]
21 |   },
22 |   {
23 |    "cell_type": "markdown",
24 |    "metadata": {},
25 |    "source": [
26 |     "treasure_trove labels a sample of your data with an LLM, trains a lightweight classifier on those labels, and uses the classifier to filter the full dataset. This notebook is also the index of the documentation."
27 |    ]
28 |   },
29 |   {
30 |    "cell_type": "markdown",
31 |    "metadata": {},
32 |    "source": [
33 |     "## Install"
34 |    ]
35 |   },
36 |   {
37 |    "cell_type": "markdown",
38 |    "metadata": {},
39 |    "source": [
40 |     "```sh\n",
41 |     "pip install git+https://github.com/CarperAI/treasure_trove\n",
42 |     "```"
43 |    ]
44 |   },
45 |   {
46 |    "cell_type": "markdown",
47 |    "metadata": {},
48 |    "source": [
49 |     "## How to use"
50 |    ]
51 |   },
52 |   {
53 |    "cell_type": "markdown",
54 |    "metadata": {},
55 |    "source": [
56 |     "See `02_tutorial.ipynb` for the end-to-end flow: label a sample with an LLM labeler, train a classifier on those labels, and filter your dataset with it. A minimal smoke test of the installed package:"
57 |    ]
58 |   },
59 |   {
60 |    "cell_type": "code",
61 |    "execution_count": null,
62 |    "metadata": {},
63 |    "outputs": [
64 |     {
65 |      "data": {
66 |       "text/plain": [
67 |        "2"
68 |       ]
69 |      },
70 |      "execution_count": null,
71 |      "metadata": {},
72 |      "output_type": "execute_result"
73 |     }
74 |    ],
75 |    "source": [
76 |     "1 + 1"
77 |    ]
78 |   },
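79 |   {
80 |    "cell_type": "markdown",
81 |    "metadata": {},
82 |    "source": [
83 |     "The sketch below condenses the tutorial into one pass; the instruction text, API key, and sampling ratio are placeholders to adapt to your own data."
84 |    ]
85 |   },
86 |   {
87 |    "cell_type": "code",
88 |    "execution_count": null,
89 |    "metadata": {},
90 |    "outputs": [],
91 |    "source": [
92 |     "#| eval: false\n",
93 |     "from datasets import load_dataset\n",
94 |     "from squeakily.helpers import LLMLabeler\n",
95 |     "from transformers import pipeline, TrainingArguments\n",
96 |     "from treasure_trove.core import filter_dataset, label_dataset, train_labeler\n",
97 |     "\n",
98 |     "labels = [\"educational\", \"non-educational\"]\n",
99 |     "labeler = LLMLabeler(\"Label the following code as either educational or non-educational:\", labels, model_name=\"gpt-4\", api_key=\"...\")\n",
100 |     "ds = load_dataset(\"bigcode/the-stack-smol\", data_dir=\"data/python\")[\"train\"]\n",
101 |     "\n",
102 |     "subset = label_dataset(ds, \"content\", labeler, labels, sample=0.001)\n",
103 |     "model, tokenizer = train_labeler(subset, \"content\", \"bigcode/starencoder\", labels=labels, training_args=TrainingArguments(output_dir=\"./code_edu\", num_train_epochs=1))\n",
104 |     "pipe = pipeline(\"text-classification\", model=model, tokenizer=tokenizer)\n",
105 |     "filtered_ds = filter_dataset(ds, \"content\", pipe, [\"educational\"])"
106 |    ]
107 |   },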
"source": [ 17 | "# treasure_trove\n", 18 | "\n", 19 | "![Find the treasure in your trove of data](assets/treasure_trove.jpeg)" 20 | ] 21 | }, 22 | { 23 | "cell_type": "markdown", 24 | "metadata": {}, 25 | "source": [ 26 | "This file will become your README and also the index of your documentation." 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "metadata": {}, 32 | "source": [ 33 | "## Install" 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "metadata": {}, 39 | "source": [ 40 | "```sh\n", 41 | "pip install git+https://github.com/CarperAI/treasure_trove\n", 42 | "```" 43 | ] 44 | }, 45 | { 46 | "cell_type": "markdown", 47 | "metadata": {}, 48 | "source": [ 49 | "## How to use" 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "metadata": {}, 55 | "source": [ 56 | "Fill me in please! Don't forget code examples:" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "metadata": {}, 63 | "outputs": [ 64 | { 65 | "data": { 66 | "text/plain": [ 67 | "2" 68 | ] 69 | }, 70 | "execution_count": null, 71 | "metadata": {}, 72 | "output_type": "execute_result" 73 | } 74 | ], 75 | "source": [ 76 | "1 + 1" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [] 85 | } 86 | ], 87 | "metadata": { 88 | "kernelspec": { 89 | "display_name": "python3", 90 | "language": "python", 91 | "name": "python3" 92 | } 93 | }, 94 | "nbformat": 4, 95 | "nbformat_minor": 4 96 | } 97 | -------------------------------------------------------------------------------- /nbs/nbdev.yml: -------------------------------------------------------------------------------- 1 | project: 2 | output-dir: _docs 3 | 4 | website: 5 | title: "treasure_trove" 6 | site-url: "https://CarperAI.github.io/treasure_trove" 7 | description: "Find the treasure in your trove of data" 8 | repo-branch: main 9 | repo-url: "https://github.com/CarperAI/treasure_trove" 10 | -------------------------------------------------------------------------------- /nbs/styles.css: -------------------------------------------------------------------------------- 1 | .cell { 2 | margin-bottom: 1rem; 3 | } 4 | 5 | .cell > .sourceCode { 6 | margin-bottom: 0; 7 | } 8 | 9 | .cell-output > pre { 10 | margin-bottom: 0; 11 | } 12 | 13 | .cell-output > pre, .cell-output > .sourceCode > pre, .cell-output-stdout > pre { 14 | margin-left: 0.8rem; 15 | margin-top: 0; 16 | background: none; 17 | border-left: 2px solid lightsalmon; 18 | border-top-left-radius: 0; 19 | border-top-right-radius: 0; 20 | } 21 | 22 | .cell-output > .sourceCode { 23 | border: none; 24 | } 25 | 26 | .cell-output > .sourceCode { 27 | background: none; 28 | margin-top: 0; 29 | } 30 | 31 | div.description { 32 | padding-left: 2px; 33 | padding-top: 5px; 34 | font-style: italic; 35 | font-size: 135%; 36 | opacity: 70%; 37 | } 38 | -------------------------------------------------------------------------------- /settings.ini: -------------------------------------------------------------------------------- 1 | [DEFAULT] 2 | # All sections below are required unless otherwise specified. 3 | # See https://github.com/fastai/nbdev/blob/master/settings.ini for examples. 
4 | 5 | ### Python library ### 6 | repo = treasure_trove 7 | lib_name = %(repo)s 8 | version = 0.0.1 9 | min_python = 3.7 10 | license = apache2 11 | black_formatting = False 12 | 13 | ### nbdev ### 14 | doc_path = _docs 15 | lib_path = treasure_trove 16 | nbs_path = nbs 17 | recursive = True 18 | tst_flags = notest 19 | put_version_in_init = True 20 | 21 | ### Docs ### 22 | branch = main 23 | custom_sidebar = False 24 | doc_host = https://%(user)s.github.io 25 | doc_baseurl = /%(repo)s 26 | git_url = https://github.com/%(user)s/%(repo)s 27 | title = %(lib_name)s 28 | 29 | ### PyPI ### 30 | audience = Developers 31 | author = ncoop57 32 | author_email = nacooper01@email.wm.edu 33 | copyright = 2023 onwards, %(author)s 34 | description = Find the treasure in your trove of data 35 | keywords = nbdev jupyter notebook python 36 | language = English 37 | status = 3 38 | user = CarperAI 39 | 40 | ### Optional ### 41 | requirements = accelerate datasets evaluate fastcore langchain openai squeakily transformers 42 | dev_requirements = black[jupyter] ipykernel 43 | # console_scripts = -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from pkg_resources import parse_version 2 | from configparser import ConfigParser 3 | import setuptools, shlex 4 | assert parse_version(setuptools.__version__)>=parse_version('36.2') 5 | 6 | # note: all settings are in settings.ini; edit there, not here 7 | config = ConfigParser(delimiters=['=']) 8 | config.read('settings.ini', encoding='utf-8') 9 | cfg = config['DEFAULT'] 10 | 11 | cfg_keys = 'version description keywords author author_email'.split() 12 | expected = cfg_keys + "lib_name user branch license status min_python audience language".split() 13 | for o in expected: assert o in cfg, "missing expected setting: {}".format(o) 14 | setup_cfg = {o:cfg[o] for o in cfg_keys} 15 | 16 | licenses = { 17 | 'apache2': ('Apache Software License 2.0','OSI Approved :: Apache Software License'), 18 | 'mit': ('MIT License', 'OSI Approved :: MIT License'), 19 | 'gpl2': ('GNU General Public License v2', 'OSI Approved :: GNU General Public License v2 (GPLv2)'), 20 | 'gpl3': ('GNU General Public License v3', 'OSI Approved :: GNU General Public License v3 (GPLv3)'), 21 | 'bsd3': ('BSD License', 'OSI Approved :: BSD License'), 22 | } 23 | statuses = [ '1 - Planning', '2 - Pre-Alpha', '3 - Alpha', 24 | '4 - Beta', '5 - Production/Stable', '6 - Mature', '7 - Inactive' ] 25 | py_versions = '3.6 3.7 3.8 3.9 3.10'.split() 26 | 27 | requirements = shlex.split(cfg.get('requirements', '')) 28 | if cfg.get('pip_requirements'): requirements += shlex.split(cfg.get('pip_requirements', '')) 29 | min_python = cfg['min_python'] 30 | lic = licenses.get(cfg['license'].lower(), (cfg['license'], None)) 31 | dev_requirements = (cfg.get('dev_requirements') or '').split() 32 | 33 | setuptools.setup( 34 | name = cfg['lib_name'], 35 | license = lic[0], 36 | classifiers = [ 37 | 'Development Status :: ' + statuses[int(cfg['status'])], 38 | 'Intended Audience :: ' + cfg['audience'].title(), 39 | 'Natural Language :: ' + cfg['language'].title(), 40 | ] + ['Programming Language :: Python :: '+o for o in py_versions[py_versions.index(min_python):]] + (['License :: ' + lic[1] ] if lic[1] else []), 41 | url = cfg['git_url'], 42 | packages = setuptools.find_packages(), 43 | include_package_data = True, 44 | install_requires = requirements, 45 | extras_require={ 'dev': dev_requirements }, 
46 |     dependency_links = cfg.get('dep_links','').split(),
47 |     python_requires  = '>=' + cfg['min_python'],
48 |     long_description = open('README.md', encoding='utf-8').read(),
49 |     long_description_content_type = 'text/markdown',
50 |     zip_safe = False,
51 |     entry_points = {
52 |         'console_scripts': cfg.get('console_scripts','').split(),
53 |         'nbdev': [f'{cfg.get("lib_path")}={cfg.get("lib_path")}._modidx:d']
54 |     },
55 |     **setup_cfg)
56 |
57 |
--------------------------------------------------------------------------------
/treasure_trove/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.0.1"
2 |
--------------------------------------------------------------------------------
/treasure_trove/_modidx.py:
--------------------------------------------------------------------------------
1 | # Autogenerated by nbdev
2 |
3 | d = { 'settings': { 'branch': 'main',
4 |                     'doc_baseurl': '/treasure_trove',
5 |                     'doc_host': 'https://CarperAI.github.io',
6 |                     'git_url': 'https://github.com/CarperAI/treasure_trove',
7 |                     'lib_path': 'treasure_trove'},
8 |   'syms': { 'treasure_trove.core': { 'treasure_trove.core.classify': ('core.html#classify', 'treasure_trove/core.py'),
9 |                                      'treasure_trove.core.filter_dataset': ('core.html#filter_dataset', 'treasure_trove/core.py'),
10 |                                     'treasure_trove.core.label_dataset': ('core.html#label_dataset', 'treasure_trove/core.py'),
11 |                                     'treasure_trove.core.train_labeler': ('core.html#train_labeler', 'treasure_trove/core.py')}}}
12 |
--------------------------------------------------------------------------------
/treasure_trove/core.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/00_core.ipynb.
2 |
3 | # %% auto 0
4 | __all__ = ['classify', 'label_dataset', 'train_labeler', 'filter_dataset']
5 |
6 | # %% ../nbs/00_core.ipynb 2
7 | import evaluate
8 | import random
9 | import time
10 |
11 | import numpy as np
12 |
13 | from transformers import (
14 |     AutoModelForSequenceClassification,
15 |     AutoTokenizer,
16 |     DataCollatorWithPadding,
17 |     Trainer,
18 | )
19 |
20 | # %% ../nbs/00_core.ipynb 4
21 | def classify(x, labels, llm_labeler, max_failures=5, default_label=0):
22 |     # Sleep for a random interval to avoid rate limiting
23 |     num_sleep = random.randint(0, 5)
24 |     time.sleep(num_sleep)
25 |     failures = 0
26 |     while failures < max_failures:
27 |         try:
28 |             label = labels.index(llm_labeler(x)[0])
29 |             time.sleep(1)
30 |             return label
31 |         except Exception as e:
32 |             failures += 1
33 |             print(e)
34 |             time.sleep(1)
35 |
36 |     # All attempts failed, so fall back to the default label
37 |     return default_label
38 |
39 | # %% ../nbs/00_core.ipynb 5
40 | def label_dataset(
41 |     dataset, text_column, labeler_model, labels, sample=0.1, num_workers=4, max_chars=4_096
42 | ):
43 |     """
44 |     Labels a random subset of a dataset using a labeler model.
45 |
46 |     Args:
47 |         dataset (datasets.Dataset): Dataset to label
48 |         text_column (str): Name of the column containing the text to classify
49 |         labeler_model (Any): Model to use for labeling
50 |         labels (List[str]): List of labels
51 |         sample (float): The fraction of the dataset to label
52 |         num_workers (int): Number of workers for labeling
53 |         max_chars (int): Maximum number of characters to truncate the text to before labeling (reduces rate limiting errors)
54 |     """
55 |
56 |     # Get a subset of the dataset
57 |     subset = dataset.shuffle(seed=115).select(range(int(len(dataset) * sample)))
58 |
59 |     # Label the subset
60 |     subset = subset.map(
61 |         lambda x: {"label": classify(x[text_column][:max_chars], labels, labeler_model)},
62 |         batched=False,
63 |         num_proc=num_workers,
64 |     )
65 |
66 |     return subset
67 |
68 | # %% ../nbs/00_core.ipynb 7
69 | def train_labeler(
70 |     dataset,
71 |     text_column,
72 |     base_model_name,
73 |     labels,
74 |     training_args,
75 |     test_set_size=0.05,
76 |     num_workers=4,
77 |     max_length=512,
78 |     push_to_hub=False,
79 | ):
80 |     """
81 |     Trains a labeler model on a labeled dataset.
82 |
83 |     Args:
84 |         dataset (datasets.Dataset): Dataset to train on
85 |         text_column (str): Name of the text column
86 |         base_model_name (str): Name of the base model to use
87 |         labels (list): List of labels
88 |         training_args (transformers.TrainingArguments): Training arguments
89 |         test_set_size (float): Fraction of the dataset to use for testing
90 |         num_workers (int): Number of workers for preprocessing
91 |         max_length (int): Maximum length of the input
92 |         push_to_hub (bool): Whether to push the trained model to the Hugging Face Hub
93 |     """
94 |     # Load the tokenizer
95 |     tokenizer = AutoTokenizer.from_pretrained(base_model_name, max_length=max_length)
96 |     if tokenizer.pad_token is None:
97 |         tokenizer.pad_token = tokenizer.eos_token
98 |
99 |     # Load the model
100 |     model = AutoModelForSequenceClassification.from_pretrained(
101 |         base_model_name, num_labels=len(labels), max_length=max_length
102 |     )
103 |     model.config.id2label = {i: label for i, label in enumerate(labels)}
104 |
105 |     # Preprocess the dataset
106 |     dataset = dataset.map(
107 |         lambda x: tokenizer(
108 |             x[text_column], padding="max_length", truncation=True, max_length=max_length
109 |         ),
110 |         batched=True,
111 |         num_proc=num_workers,
112 |     )
113 |
114 |     # Split the dataset
115 |     dataset = dataset.train_test_split(test_size=test_set_size, seed=115)
116 |
117 |     # Get the data collator
118 |     data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
119 |
120 |     def compute_metrics(eval_preds):
121 |         acc_metric = evaluate.load("accuracy")
122 |         precision_metric = evaluate.load("precision")
123 |         recall_metric = evaluate.load("recall")
124 |         f1_metric = evaluate.load("f1")
125 |         logits, references = eval_preds  # don't shadow `labels` (the list of class names)
126 |         if isinstance(logits, tuple):  # Some models return tuples
127 |             logits = logits[0]
128 |
129 |         predictions = np.argmax(logits, axis=-1)
130 |         acc = acc_metric.compute(predictions=predictions, references=references)
131 |         precision = precision_metric.compute(predictions=predictions, references=references, average="macro" if len(labels) > 2 else "binary")
132 |         recall = recall_metric.compute(predictions=predictions, references=references, average="macro" if len(labels) > 2 else "binary")
133 |         f1 = f1_metric.compute(predictions=predictions, references=references, average="macro" if len(labels) > 2 else "binary")
134 |         return {**acc, **precision, **recall, **f1}
135 |
136 |     # Get the trainer
137 |     trainer = Trainer(
138 |         model=model,
139 | 
args=training_args, 140 | train_dataset=dataset["train"], 141 | eval_dataset=dataset["test"], 142 | data_collator=data_collator, 143 | compute_metrics=compute_metrics, 144 | ) 145 | 146 | # Train the model 147 | trainer.train() 148 | 149 | # Push the model to the hub 150 | if push_to_hub: 151 | trainer.push_to_hub() 152 | 153 | # Return the model 154 | return model, tokenizer 155 | 156 | # %% ../nbs/00_core.ipynb 9 157 | def filter_dataset( 158 | dataset, text_column, labeler_model, labels_to_keep, batch_size=32, num_workers=4 159 | ): 160 | """ 161 | Filters a dataset using a labeler model. 162 | 163 | Args: 164 | dataset (datasets.Dataset): Dataset to filter 165 | text_column (str): Name of the text column 166 | labeler_model (transformers.pipelines.TextClassificationPipeline): Model to use for labeling 167 | labels_to_keep (list): List of labels to keep 168 | batch_size (int): Batch size for labeling 169 | num_workers (int): Number of workers for labeling 170 | """ 171 | 172 | def label(x): 173 | predicted = labeler_model(x, padding=True, truncation=True, max_length=512) 174 | return { 175 | "label": [l["label"] for l in predicted], 176 | "score": [l["score"] for l in predicted], 177 | } 178 | 179 | 180 | # TODO: first just label the dataset with scores and everything 181 | # then just split the dataset into the number of subsets and configs so that people can specify which one they want 182 | 183 | # Label the dataset 184 | dataset = dataset.map( 185 | lambda x: label(x[text_column]), 186 | batched=True, 187 | batch_size=batch_size, 188 | num_proc=num_workers, 189 | ) 190 | 191 | # Filter the dataset 192 | dataset = dataset.filter(lambda x: x["label"] in labels_to_keep) 193 | 194 | return dataset 195 | --------------------------------------------------------------------------------