├── LICENSE
├── README.md
├── classification
    ├── README.md
    ├── run_edu_bert.py
    ├── run_edu_bert.slurm
    ├── train_edu_bert.py
    └── train_edu_bert.slurm
├── decontamination
    ├── README.md
    └── decontaminate.py
├── deduplication
    ├── README.md
    └── deduplicate_dataset.py
├── evaluation
    ├── README.md
    ├── eval.slurm
    └── lighteval_tasks.py
├── fulltext_search
    ├── README.md
    ├── index_docs.py
    ├── index_docs.slurm
    ├── manticore.conf
    ├── search_sharded.py
    └── search_sharded.slurm
├── generation
    ├── README.md
    ├── boilerplate_cleanup.py
    └── llm_swarm_script.py
├── plots
    ├── clusters_map.png
    ├── cover.png
    ├── cover_01.png
    ├── educational_score.png
    └── topics_distpng.png
└── prompts
    ├── README.md
    ├── auto_math_text
        ├── README.md
        └── build_science_prompts.py
    ├── khanacademy
        ├── README.md
        ├── generate_textbooks.py
        └── khan_dl
        │   ├── khan_dl.py
        │   ├── main.py
        │   └── requirements.txt
    ├── openstax
        ├── README.md
        └── build_openstax_prompts.py
    ├── stanford
        ├── 1_scraper.ipynb
        ├── 2_generate_course_outlines.ipynb
        └── README.md
    ├── stories
        ├── README.md
        ├── build_openhermes_stories_prompts.py
        ├── build_ultrachat_stories_prompts.py
        └── filter_openhermes.py
    ├── web_samples
        ├── README.md
        ├── build_web_prompts.py
        └── filter_and_classify_clusters.py
    └── wikihow
        ├── README.md
        └── wikihowcom-20231012-titles.txt

/LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Cosmopedia 2 | 3 |
4 | Description of Image 5 |

Image generated by DALL-E, the prompt was generated by Mixtral-8x7B-Instruct-v0.1.

6 |
7 | 8 |

[🤗 Cosmopedia dataset] | [🤖 1B-LLM trained on Cosmopedia] | [📰 Blog post] 9 |

10 | 11 |
12 | 13 | ## Description 14 | Here you can find the code used for creating [Cosmopedia](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia), a dataset of synthetic textbooks, blogposts, stories, posts and WikiHow articles generated by Mixtral-8x7B-Instruct-v0.1. It contains over **30 million files and 25 billion tokens**, making it the largest open synthetic dataset to date. 15 | 16 | Cosmopedia covers a variety of topics; we tried to map world knowledge present in Web datasets like RefinedWeb and RedPajama, and generate synthetic content that covers them. This is the v0.1 of Cosmopedia, with ample room for improvement and topics to be more comprehensively covered. We hope this dataset will help the community's research efforts in the increasingly intriguing domain of synthetic data. 17 | 18 |
19 | clusters 20 |

The clusters of Cosmopedia.

21 |
22 | 23 | You can also find a plot of file counts for single-topic clusters in `plots/topics_distpng.png`. 24 | 25 | ## Code structure 26 | - `prompts`: the code for building the prompts in each `seed_data` in Cosmopedia. In `web_samples`, you can also find pointers for the topic clustering we did. 27 | - `generation`: the code to run large-scale synthetic generations with [llm-swarm](https://github.com/huggingface/llm-swarm) using the prompts you built. Cosmopedia consists of 25B tokens and took more than 10k H100 GPU hours to generate. 28 | - `deduplication`: the script we used to run MinHash deduplication with [datatrove](https://github.com/huggingface/datatrove). 29 | - `decontamination`: the code we used to run n-gram decontamination against evaluation benchmarks, used when training models on the dataset, such as [cosmopedian-1b](https://huggingface.co/HuggingFaceTB/cosmopedian-1b). 30 | -------------------------------------------------------------------------------- /classification/README.md: -------------------------------------------------------------------------------- 1 | # Educational value classifier 2 | 3 | ### 1. Finetune a model for educational value regression 4 | 5 | * edit `train_edu_bert.slurm` 6 | ```bash 7 | --base_model_name="Snowflake/snowflake-arctic-embed-m" \ # BERT-like base model 8 | --dataset_name="HuggingFaceTB/LLM_juries_fineweb_430k_annotations" \ # Llama3-annotated educational value dataset 9 | --target_column="score" 10 | ``` 11 | * run the training script on a SLURM cluster: 12 | ```bash 13 | sbatch train_edu_bert.slurm 14 | ``` 15 | 16 | ### 2. 
Annotate a dataset with the educational scores predicted by the model 17 | 18 | ```bash 19 | sbatch run_edu_bert.slurm 20 | ``` -------------------------------------------------------------------------------- /classification/run_edu_bert.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | from transformers import AutoTokenizer, AutoModelForSequenceClassification 4 | from datasets import load_dataset 5 | 6 | 7 | def main(args): 8 | tokenizer = AutoTokenizer.from_pretrained(args.model_name) 9 | model = AutoModelForSequenceClassification.from_pretrained( 10 | args.model_name, torch_dtype=torch.bfloat16 11 | ) 12 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 13 | model.to(device) 14 | 15 | dataset = load_dataset( 16 | args.dataset_name, 17 | args.dataset_config, 18 | split="train", 19 | cache_dir="/scratch/cosmo/cache/", 20 | num_proc=12, 21 | ) 22 | dataset = dataset.filter( 23 | lambda x, i: i % args.num_shards == args.shard, with_indices=True, num_proc=12 24 | ) 25 | 26 | def compute_scores(batch): 27 | inputs = tokenizer( 28 | batch[args.text_column], 29 | return_tensors="pt", 30 | padding="longest", 31 | truncation=True, 32 | ).to(device) 33 | with torch.no_grad(): 34 | outputs = model(**inputs) 35 | logits = outputs.logits.squeeze(-1).float().cpu().numpy() 36 | 37 | batch["score"] = logits.tolist() 38 | batch["int_score"] = [int(round(max(0, min(score, 5)))) for score in logits] 39 | return batch 40 | 41 | dataset = dataset.map(compute_scores, batched=True, batch_size=512) 42 | 43 | while True: 44 | try: 45 | config_name = f"{args.output_dataset_config}_{args.shard}" 46 | dataset.push_to_hub( 47 | args.output_dataset_name, 48 | config_name=config_name, 49 | private=True, 50 | max_shard_size="4096MB", 51 | ) 52 | break 53 | except Exception as e: 54 | print(e) 55 | continue 56 | 57 | 58 | if __name__ == "__main__": 59 | parser = argparse.ArgumentParser() 60 | 61 | 
parser.add_argument( 62 | "--model_name", type=str, default="HuggingFaceFW/fineweb-edu-classifier" 63 | ) 64 | parser.add_argument("--dataset_name", type=str, default="HuggingFaceFW/fineweb") 65 | parser.add_argument("--dataset_config", type=str, default="default") 66 | parser.add_argument( 67 | "--output_dataset_name", type=str, default="HuggingFaceFW/fineweb-edu" 68 | ) 69 | parser.add_argument("--output_dataset_config", type=str, default="default") 70 | parser.add_argument("--text_column", type=str, default="text") 71 | parser.add_argument("--shard", type=int, required=True) 72 | parser.add_argument("--num_shards", type=int, required=True) 73 | 74 | args = parser.parse_args() 75 | main(args) 76 | -------------------------------------------------------------------------------- /classification/run_edu_bert.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=run_edu_bert 3 | #SBATCH --partition hopper-prod 4 | #SBATCH --qos=normal 5 | #SBATCH --requeue 6 | #SBATCH --nodes=1 7 | #SBATCH --ntasks-per-node=1 8 | #SBATCH --cpus-per-task=12 9 | #SBATCH --mem-per-cpu=20G 10 | #SBATCH --gpus=1 11 | #SBATCH -o %x_%j.out 12 | #SBATCH -e %x_%j.err 13 | #SBATCH --time=7-00:00:00 14 | #SBATCH --array=0-127%128 15 | 16 | set -x -e 17 | source ~/.bashrc 18 | source "$CONDA_PREFIX/etc/profile.d/conda.sh" 19 | source activate pytorch 20 | 21 | python run_edu_bert.py \ 22 | --model_name="HuggingFaceFW/fineweb-edu-classifier" \ 23 | --dataset_name="HuggingFaceFW/fineweb" \ 24 | --dataset_config="CC-MAIN-2019-04" \ 25 | --output_dataset_name="HuggingFaceFW/fineweb-edu-annotations" \ 26 | --output_dataset_config="CC-MAIN-2019-04" \ 27 | --text_column="text" \ 28 | --shard ${SLURM_ARRAY_TASK_ID} \ 29 | --num_shards 128 30 | -------------------------------------------------------------------------------- /classification/train_edu_bert.py: 
-------------------------------------------------------------------------------- 1 | from transformers import ( 2 | AutoTokenizer, 3 | DataCollatorWithPadding, 4 | TrainingArguments, 5 | Trainer, 6 | AutoModelForSequenceClassification, 7 | ) 8 | from datasets import load_dataset, ClassLabel 9 | import numpy as np 10 | import evaluate 11 | import argparse 12 | import os 13 | from sklearn.metrics import classification_report, confusion_matrix 14 | 15 | 16 | def compute_metrics(eval_pred): 17 | precision_metric = evaluate.load("precision") 18 | recall_metric = evaluate.load("recall") 19 | f1_metric = evaluate.load("f1") 20 | accuracy_metric = evaluate.load("accuracy") 21 | 22 | logits, labels = eval_pred 23 | preds = np.round(logits.squeeze()).clip(0, 5).astype(int) 24 | labels = np.round(labels.squeeze()).astype(int) 25 | precision = precision_metric.compute( 26 | predictions=preds, references=labels, average="macro" 27 | )["precision"] 28 | recall = recall_metric.compute( 29 | predictions=preds, references=labels, average="macro" 30 | )["recall"] 31 | f1 = f1_metric.compute(predictions=preds, references=labels, average="macro")["f1"] 32 | accuracy = accuracy_metric.compute(predictions=preds, references=labels)["accuracy"] 33 | 34 | report = classification_report(labels, preds) 35 | cm = confusion_matrix(labels, preds) 36 | print("Validation Report:\n" + report) 37 | print("Confusion Matrix:\n" + str(cm)) 38 | 39 | return { 40 | "precision": precision, 41 | "recall": recall, 42 | "f1_macro": f1, 43 | "accuracy": accuracy, 44 | } 45 | 46 | 47 | def main(args): 48 | dataset = load_dataset( 49 | args.dataset_name, split="train", cache_dir="/scratch/cosmo/cache/", num_proc=8 50 | ) 51 | dataset = dataset.map( 52 | lambda x: {args.target_column: np.clip(int(x[args.target_column]), 0, 5)}, 53 | num_proc=8, 54 | ) 55 | 56 | dataset = dataset.cast_column( 57 | args.target_column, ClassLabel(names=[str(i) for i in range(6)]) 58 | ) 59 | dataset = dataset.train_test_split( 60 
| train_size=0.9, seed=42, stratify_by_column=args.target_column 61 | ) 62 | 63 | model = AutoModelForSequenceClassification.from_pretrained( 64 | args.base_model_name, 65 | num_labels=1, 66 | classifier_dropout=0.0, 67 | hidden_dropout_prob=0.0, 68 | output_hidden_states=False, 69 | ) 70 | tokenizer = AutoTokenizer.from_pretrained( 71 | args.base_model_name, 72 | model_max_length=min(model.config.max_position_embeddings, 512), 73 | ) 74 | if not tokenizer.pad_token: 75 | tokenizer.pad_token = tokenizer.eos_token 76 | 77 | def preprocess(examples): 78 | batch = tokenizer(examples["text"], truncation=True) 79 | batch["labels"] = np.float32(examples[args.target_column]) 80 | return batch 81 | 82 | dataset = dataset.map(preprocess, batched=True) 83 | data_collator = DataCollatorWithPadding(tokenizer=tokenizer) 84 | 85 | for param in model.bert.embeddings.parameters(): 86 | param.requires_grad = False 87 | for param in model.bert.encoder.parameters(): 88 | param.requires_grad = False 89 | 90 | training_args = TrainingArguments( 91 | output_dir=args.checkpoint_dir, 92 | hub_model_id=args.output_model_name, 93 | eval_strategy="steps", 94 | save_strategy="steps", 95 | eval_steps=1000, 96 | save_steps=1000, 97 | logging_steps=100, 98 | learning_rate=3e-4, 99 | num_train_epochs=20, 100 | seed=0, 101 | per_device_train_batch_size=256, 102 | per_device_eval_batch_size=128, 103 | eval_on_start=True, 104 | load_best_model_at_end=True, 105 | metric_for_best_model="f1_macro", 106 | greater_is_better=True, 107 | bf16=True, 108 | push_to_hub=True, 109 | ) 110 | 111 | trainer = Trainer( 112 | model=model, 113 | args=training_args, 114 | train_dataset=dataset["train"], 115 | eval_dataset=dataset["test"], 116 | tokenizer=tokenizer, 117 | data_collator=data_collator, 118 | compute_metrics=compute_metrics, 119 | ) 120 | 121 | trainer.train() 122 | trainer.save_model(os.path.join(args.checkpoint_dir, "final")) 123 | 124 | 125 | if __name__ == "__main__": 126 | parser = 
argparse.ArgumentParser() 127 | parser.add_argument( 128 | "--base_model_name", type=str, default="Snowflake/snowflake-arctic-embed-m" 129 | ) 130 | parser.add_argument( 131 | "--dataset_name", 132 | type=str, 133 | default="HuggingFaceFW/fineweb-edu-llama3-annotations", 134 | ) 135 | parser.add_argument("--target_column", type=str, default="score") 136 | parser.add_argument( 137 | "--checkpoint_dir", 138 | type=str, 139 | default="/fsx/anton/cosmopedia/edu_score/bert_snowflake_regression", 140 | ) 141 | parser.add_argument( 142 | "--output_model_name", type=str, default="HuggingFaceTB/fineweb-edu-scorer" 143 | ) 144 | args = parser.parse_args() 145 | 146 | main(args) 147 | -------------------------------------------------------------------------------- /classification/train_edu_bert.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=train_edu_bert 3 | #SBATCH --partition hopper-prod 4 | #SBATCH --nodes=1 5 | #SBATCH --ntasks-per-node=1 6 | #SBATCH --cpus-per-task=16 7 | #SBATCH --mem-per-cpu=20G 8 | #SBATCH --gpus=1 9 | #SBATCH -o %x_%j.out 10 | #SBATCH -e %x_%j.err 11 | #SBATCH --time=1-00:00:00 12 | 13 | set -x -e 14 | source ~/.bashrc 15 | source "$CONDA_PREFIX/etc/profile.d/conda.sh" 16 | source activate pytorch 17 | 18 | python train_edu_bert.py \ 19 | --base_model_name="Snowflake/snowflake-arctic-embed-m" \ 20 | --dataset_name="HuggingFaceFW/fineweb-edu-llama3-annotations" \ 21 | --target_column="score" \ 22 | --checkpoint_dir="/fsx/anton/cosmopedia/edu_score/snowflake_regression_median_jury" \ 23 | --output_model_name="HuggingFaceTB/fineweb-edu-scorer" 24 | -------------------------------------------------------------------------------- /decontamination/README.md: -------------------------------------------------------------------------------- 1 | # Decontamination 2 | 3 | We use a 10-gram overlap to retrieve potentially contaminated samples, similarly to 
[Phi-1](https://huggingface.co/papers/2306.11644). 4 | After retrieving the candidates, we run a diff between the dataset sample and the benchmark sample using `difflib.SequenceMatcher` and discard the sample if `len(matched_substrings)/len(benchmark_sample) > 0.5`. 5 | We run decontamination against all the benchmarks we evaluated the Cosmo-1B model on: MMLU, HellaSwag, PIQA, SIQA, Winogrande, OpenBookQA, ARC-easy, ARC-challenge. 6 | 7 | Usage: 8 | ```bash 9 | export HF_DATASETS_CACHE=/scratch/cosmo/cache 10 | export HUGGINGFACE_HUB_CACHE=/scratch/cosmo/cache 11 | 12 | python decontaminate.py --train_dataset "HuggingFaceTB/AMT_2M_Khanacademy_24k" --report_dataset_name "HuggingFaceTB/AMT_2M_Khanacademy_24k_decont_report" --save_decontaminated --decontaminated_dataset_name "HuggingFaceTB/AMT_2M_Khanacademy_24k_decont" 13 | ``` 14 | 15 | 16 | -------------------------------------------------------------------------------- /decontamination/decontaminate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import difflib 3 | import re 4 | import unicodedata 5 | from pathlib import Path 6 | from tqdm.auto import tqdm 7 | from datasets import load_dataset, Dataset 8 | 9 | 10 | def tokenize(text): 11 | """Normalize text by removing diacritics and tokenize.""" 12 | text = "".join(c for c in unicodedata.normalize("NFD", text) if unicodedata.category(c) != "Mn") 13 | tokens = re.findall(r"\w+", text.lower()) 14 | return tokens 15 | 16 | 17 | def get_ngrams(tokens, n): 18 | """Generate n-grams from tokens.""" 19 | return set(zip(*[tokens[i:] for i in range(n)])) 20 | 21 | 22 | def retrieve_ngrams_batch(batch, eval_ngrams, eval_datasets, eval_texts, ngram_len): 23 | """Find contaminated samples based on n-grams.""" 24 | new_batch = {"completion": [], "ngram": [], "bench_name": [], "bench_text": []} 25 | for completion in batch["completion"]: 26 | tokens = tokenize(completion) 27 | ngrams = get_ngrams(tokens, ngram_len) 28 
| for ngram in ngrams: 29 | if ngram in eval_ngrams: 30 | idx = eval_ngrams[ngram] 31 | new_batch["completion"].append(completion) 32 | new_batch["ngram"].append(ngram) 33 | new_batch["bench_name"].append(eval_datasets[idx]) 34 | new_batch["bench_text"].append(eval_texts[idx]) 35 | break 36 | return new_batch 37 | 38 | 39 | def diff_strings(string1, string2): 40 | """Find matching parts between two strings.""" 41 | matcher = difflib.SequenceMatcher(None, string1.lower(), string2.lower(), autojunk=False) 42 | matching_blocks = matcher.get_matching_blocks() 43 | matches = [] 44 | for block in matching_blocks: 45 | start_a, start_b, length = block 46 | if length > 5: 47 | match = string1[start_a:start_a + length] 48 | matches.append(match) 49 | return matches 50 | 51 | 52 | def add_match_stats(example): 53 | gen_text = " ".join(tokenize(example["completion"])) 54 | bench_text = " ".join(tokenize(example["bench_text"])) 55 | matching_parts = diff_strings(gen_text, bench_text) 56 | match = " ".join("".join(matching_parts).split()) 57 | example["diff"] = matching_parts 58 | example["diff_ratio"] = len(match) / len(bench_text) if len(bench_text) > 0 else 0 59 | example["diff_length"] = len(match) 60 | example["longest_diff_part"] = max(matching_parts, key=len, default="") 61 | example["longest_diff_part_length"] = len(example["longest_diff_part"]) 62 | return example 63 | 64 | 65 | def main(args): 66 | # Load the evaluation data to build n-grams index 67 | eval_ngrams, eval_datasets, eval_texts = {}, [], [] 68 | eval_data = load_dataset(args.eval_dataset, split="train", num_proc=args.num_proc) 69 | for example in tqdm(eval_data): 70 | tokens = tokenize(example["text"]) 71 | ngrams = get_ngrams(tokens, args.ngram_length) 72 | if ngrams: 73 | idx = len(eval_texts) 74 | eval_ngrams.update(zip(ngrams, [idx] * len(ngrams))) 75 | eval_datasets.append(example.get("task_name", "unknown")) 76 | eval_texts.append(example["text"]) 77 | 78 | train_dataset_path = 
Path(args.train_dataset) 79 | if train_dataset_path.exists() and train_dataset_path.suffix in [".json", ".csv"]: 80 | if train_dataset_path.suffix == ".json": 81 | train_data = Dataset.from_json(args.train_dataset) 82 | elif train_dataset_path.suffix == ".csv": 83 | train_data = Dataset.from_csv(args.train_dataset) 84 | else: 85 | train_data = load_dataset(args.train_dataset, split="train", num_proc=args.num_proc) 86 | 87 | contamination_report = train_data.map( 88 | lambda batch: retrieve_ngrams_batch(batch, eval_ngrams, eval_datasets, eval_texts, args.ngram_length), 89 | batched=True, batch_size=1000, num_proc=args.num_proc, remove_columns=train_data.column_names 90 | ) 91 | 92 | contamination_report = contamination_report.map( 93 | lambda example: add_match_stats(example), num_proc=args.num_proc 94 | ) 95 | 96 | contamination_report.push_to_hub(args.report_dataset_name, private=args.private) 97 | 98 | contamination_report = contamination_report.filter(lambda x: x["diff_ratio"] > args.diff_threshold) 99 | 100 | if args.save_decontaminated: 101 | contaminated_completions = set(contamination_report["completion"]) 102 | filtered_data = train_data.filter(lambda x: x["completion"] not in contaminated_completions) 103 | filtered_data.push_to_hub(args.decontaminated_dataset_name, private=args.private) 104 | 105 | 106 | if __name__ == "__main__": 107 | parser = argparse.ArgumentParser(description="Generate a decontamination report for a dataset.") 108 | parser.add_argument("--eval_dataset", type=str, 109 | default="HuggingFaceTB/phi2_eval_data_for_decontamination", 110 | help="Name of the dataset with benchmark samples to use for decontamination.") 111 | parser.add_argument("--train_dataset", type=str, required=True, 112 | help="Path or name of the training dataset to process.") 113 | parser.add_argument("--report_dataset_name", type=str, required=True, 114 | help="Name for the output dataset with decontamination report.") 115 | 
parser.add_argument("--decontaminated_dataset_name", type=str, help="Name for the decontaminated dataset.") 116 | parser.add_argument("--private", action='store_true', help="Whether to make the output dataset private.") 117 | parser.add_argument("--ngram_length", type=int, default=10, help="Length of the n-grams to consider.") 118 | parser.add_argument("--diff_threshold", type=float, default=0.5, 119 | help="Threshold for filtering based on difference ratio.") 120 | parser.add_argument("--num_proc", type=int, default=16, help="Number of processes to use for map operations.") 121 | parser.add_argument("--save_decontaminated", action='store_true', 122 | help="Whether to save the decontaminated dataset.") 123 | 124 | args = parser.parse_args() 125 | main(args) -------------------------------------------------------------------------------- /deduplication/README.md: -------------------------------------------------------------------------------- 1 | # Deduplication 2 | 3 | We run deduplication on the dataset using MinHash from [datatrove](https://github.com/huggingface/datatrove). 4 | Considering that the seed samples had already undergone deduplication, and we carefully crafted the prompts to ensure distinct outputs even with identical seeds, the volume of duplicates found in Cosmopedia was less than 1% of the files, which were subsequently removed. 5 | 6 | The deduplication script is available at `deduplicate_dataset.py`. Make sure to follow the installation guidelines in `datatrove` and to change the paths in the file before running it. 
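
For readers unfamiliar with MinHash, the idea behind the deduplication step can be sketched in a few lines of standard-library Python. This is only an illustration under stated assumptions, not the `datatrove` implementation: the function names, the 5-word shingles, and the 64-hash signature length are hypothetical choices made for this example, and the real pipeline additionally groups signatures into buckets and clusters duplicates across tasks.

```python
import hashlib
import re

NUM_HASHES = 64  # assumed signature length for this sketch, not a datatrove default


def shingles(text, n=5):
    """Return the set of word n-grams (shingles) of a text."""
    words = re.findall(r"\w+", text.lower())
    # Texts shorter than n words yield a single shingle holding all words.
    return {" ".join(words[i:i + n]) for i in range(max(len(words) - n + 1, 1))}


def minhash_signature(text, num_hashes=NUM_HASHES):
    """For each salted hash function, keep the minimum hash over all shingles."""
    grams = shingles(text)
    signature = []
    for seed in range(num_hashes):
        salt = seed.to_bytes(8, "little")  # a different salt gives a different hash function
        signature.append(
            min(
                int.from_bytes(
                    hashlib.blake2b(g.encode("utf-8"), digest_size=8, salt=salt).digest(),
                    "little",
                )
                for g in grams
            )
        )
    return signature


def estimated_jaccard(sig_a, sig_b):
    """The fraction of matching signature slots estimates the Jaccard similarity."""
    return sum(a == b for a, b in zip(sig_a, sig_b)) / len(sig_a)


if __name__ == "__main__":
    original = " ".join(f"w{i}" for i in range(30))
    near_duplicate = original.replace("w29", "changed")  # only the last word differs
    unrelated = " ".join(f"x{i}" for i in range(30))

    sig = minhash_signature(original)
    print("near duplicate:", estimated_jaccard(sig, minhash_signature(near_duplicate)))
    print("unrelated:", estimated_jaccard(sig, minhash_signature(unrelated)))
```

Two documents whose estimated similarity exceeds a chosen threshold would be treated as duplicates, and all but one of them dropped. Comparing short signatures instead of full texts is what makes this tractable on tens of millions of files.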
-------------------------------------------------------------------------------- /deduplication/deduplicate_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from datatrove.executor.slurm import SlurmPipelineExecutor 4 | from datatrove.pipeline.dedup import MinhashDedupSignature 5 | from datatrove.pipeline.dedup.minhash import ( 6 | MinhashConfig, 7 | MinhashDedupBuckets, 8 | MinhashDedupCluster, 9 | MinhashDedupFilter, 10 | ) 11 | from datatrove.pipeline.readers import HuggingFaceDatasetReader 12 | from datatrove.pipeline.tokens import TokensCounter 13 | from datatrove.pipeline.writers.jsonl import JsonlWriter 14 | 15 | 16 | # you can also change ngrams or the number of buckets and their size here 17 | minhash_config = MinhashConfig() 18 | HF_DATA = "cosmopedia-100k" 19 | 20 | S3_MINHASH_BASE_PATH = f"s3://synthetic-datasets-phi/{HF_DATA}/minhash" 21 | S3_LOGS_FOLDER = f"s3://synthetic-datasets-phi/{HF_DATA}/minhash_logs/" 22 | LOCAL_LOGS_FOLDER = f"./logs/dedup_extras/{HF_DATA}" 23 | os.makedirs(LOCAL_LOGS_FOLDER, exist_ok = True) 24 | 25 | TOTAL_TASKS = 120 26 | 27 | INPUT_READER = HuggingFaceDatasetReader( 28 | dataset=f"HuggingFaceTB/{HF_DATA}", # dataset name 29 | dataset_options={ 30 | "split": "train" 31 | }, 32 | text_key="completion" 33 | ) 34 | # stage 1 computes minhash signatures for each task (each task gets a set of files) 35 | stage1 = SlurmPipelineExecutor( 36 | job_name="mh1", 37 | pipeline=[ 38 | INPUT_READER, 39 | MinhashDedupSignature(output_folder=f"{S3_MINHASH_BASE_PATH}/signatures", config=minhash_config), 40 | ], 41 | tasks=TOTAL_TASKS, 42 | time="5:00:00", 43 | partition="hopper-cpu", 44 | logging_dir=f"{S3_LOGS_FOLDER}/signatures", 45 | slurm_logs_folder=f"{LOCAL_LOGS_FOLDER}/signatures/slurm_logs", 46 | qos="high", 47 | ) 48 | 49 | # stage 2 finds matches between signatures in each bucket 50 | stage2 = SlurmPipelineExecutor( 51 | job_name="mh2", 52 | pipeline=[ 53 | 
MinhashDedupBuckets( 54 | input_folder=f"{S3_MINHASH_BASE_PATH}/signatures", 55 | output_folder=f"{S3_MINHASH_BASE_PATH}/buckets", 56 | config=minhash_config, 57 | ), 58 | ], 59 | tasks=minhash_config.num_buckets, 60 | time="90:00:00", 61 | partition="hopper-prod", 62 | logging_dir=f"{S3_LOGS_FOLDER}/buckets", 63 | depends=stage1, 64 | slurm_logs_folder=f"{LOCAL_LOGS_FOLDER}/buckets/slurm_logs", 65 | qos="high", 66 | ) 67 | 68 | # stage 3 creates clusters of duplicates using the results from all buckets 69 | stage3 = SlurmPipelineExecutor( 70 | job_name="mh3", 71 | pipeline=[ 72 | MinhashDedupCluster( 73 | input_folder=f"{S3_MINHASH_BASE_PATH}/buckets", 74 | output_folder=f"{S3_MINHASH_BASE_PATH}/remove_ids", 75 | config=minhash_config, 76 | ), 77 | ], 78 | tasks=1, 79 | time="90:00:00", 80 | partition="hopper-prod", 81 | logging_dir=f"{S3_LOGS_FOLDER}/clusters", 82 | mem_per_cpu_gb=70, 83 | cpus_per_task=2, 84 | depends=stage2, 85 | slurm_logs_folder=f"{LOCAL_LOGS_FOLDER}/clusters/slurm_logs", 86 | ) 87 | 88 | # stage 4 reads the original input data and removes all but 1 sample per duplicate cluster 89 | # the data must match exactly stage 1, so number of tasks and the input source must be the same 90 | stage4 = SlurmPipelineExecutor( 91 | job_name="mh4", 92 | pipeline=[ 93 | INPUT_READER, 94 | TokensCounter(), # nice way to see how many tokens we had before and after deduplication 95 | MinhashDedupFilter( 96 | input_folder=f"{S3_MINHASH_BASE_PATH}/remove_ids", 97 | exclusion_writer=JsonlWriter(f"{S3_MINHASH_BASE_PATH}/removed"), 98 | ), 99 | JsonlWriter(output_folder=f"{S3_MINHASH_BASE_PATH}/deduplicated_output"), # output_folder="hf_stack" 100 | ], 101 | tasks=TOTAL_TASKS, 102 | time="50:00:00", 103 | partition="hopper-cpu", 104 | logging_dir=f"{S3_LOGS_FOLDER}/filter", 105 | depends=stage3, 106 | slurm_logs_folder=f"{LOCAL_LOGS_FOLDER}/filter/slurm_logs", 107 | ) 108 | 109 | 110 | stage4.run() 111 | 
-------------------------------------------------------------------------------- /evaluation/README.md: -------------------------------------------------------------------------------- 1 | # Benchmark evaluation 2 | 3 | This is an extended version of the [FineWeb-v1 evaluation script](https://huggingface.co/datasets/HuggingFaceFW/fineweb/blob/main/lighteval_tasks.py). 4 | In particular, we add the MMLU-Pro, TriviaQA, and GSM8k benchmarks. 5 | 6 | To run the script, please install the latest version of the `lighteval` library: 7 | ```bash 8 | git clone https://github.com/huggingface/lighteval.git 9 | cd lighteval 10 | conda create -n lighteval python=3.10 && conda activate lighteval 11 | pip install '.[accelerate,quantization,adapters]' 12 | ``` 13 | 14 | Then, you can run the evaluation script with the following command: 15 | ```bash 16 | MODEL="openai-community/gpt2" 17 | accelerate launch --num_processes=1 --main_process_port=29600 "lighteval/run_evals_accelerate.py" --model_args="pretrained=$MODEL" \ 18 | --custom_tasks "lighteval_tasks.py" --output_dir $OUTPUT_DIR --override_batch_size 16 \ 19 | --tasks 
"custom|hellaswag|0|1,custom|winogrande|0|1,custom|piqa|0|1,custom|siqa|0|1,custom|openbookqa|0|1,custom|arc:easy|0|1,custom|arc:challenge|0|1,custom|commonsense_qa|0|1,custom|trivia_qa|0|1,custom|mmlu_pro_cloze|0|1,custom|gsm8k|5|1,custom|mmlu_cloze:abstract_algebra|0|1,custom|mmlu_cloze:anatomy|0|1,custom|mmlu_cloze:astronomy|0|1,custom|mmlu_cloze:business_ethics|0|1,custom|mmlu_cloze:clinical_knowledge|0|1,custom|mmlu_cloze:college_biology|0|1,custom|mmlu_cloze:college_chemistry|0|1,custom|mmlu_cloze:college_computer_science|0|1,custom|mmlu_cloze:college_mathematics|0|1,custom|mmlu_cloze:college_medicine|0|1,custom|mmlu_cloze:college_physics|0|1,custom|mmlu_cloze:computer_security|0|1,custom|mmlu_cloze:conceptual_physics|0|1,custom|mmlu_cloze:econometrics|0|1,custom|mmlu_cloze:electrical_engineering|0|1,custom|mmlu_cloze:elementary_mathematics|0|1,custom|mmlu_cloze:formal_logic|0|1,custom|mmlu_cloze:global_facts|0|1,custom|mmlu_cloze:high_school_biology|0|1,custom|mmlu_cloze:high_school_chemistry|0|1,custom|mmlu_cloze:high_school_computer_science|0|1,custom|mmlu_cloze:high_school_european_history|0|1,custom|mmlu_cloze:high_school_geography|0|1,custom|mmlu_cloze:high_school_government_and_politics|0|1,custom|mmlu_cloze:high_school_macroeconomics|0|1,custom|mmlu_cloze:high_school_mathematics|0|1,custom|mmlu_cloze:high_school_microeconomics|0|1,custom|mmlu_cloze:high_school_physics|0|1,custom|mmlu_cloze:high_school_psychology|0|1,custom|mmlu_cloze:high_school_statistics|0|1,custom|mmlu_cloze:high_school_us_history|0|1,custom|mmlu_cloze:high_school_world_history|0|1,custom|mmlu_cloze:human_aging|0|1,custom|mmlu_cloze:human_sexuality|0|1,custom|mmlu_cloze:international_law|0|1,custom|mmlu_cloze:jurisprudence|0|1,custom|mmlu_cloze:logical_fallacies|0|1,custom|mmlu_cloze:machine_learning|0|1,custom|mmlu_cloze:management|0|1,custom|mmlu_cloze:marketing|0|1,custom|mmlu_cloze:medical_genetics|0|1,custom|mmlu_cloze:miscellaneous|0|1,custom|mmlu_cloze:moral_disputes|0|1,cust
om|mmlu_cloze:moral_scenarios|0|1,custom|mmlu_cloze:nutrition|0|1,custom|mmlu_cloze:philosophy|0|1,custom|mmlu_cloze:prehistory|0|1,custom|mmlu_cloze:professional_accounting|0|1,custom|mmlu_cloze:professional_law|0|1,custom|mmlu_cloze:professional_medicine|0|1,custom|mmlu_cloze:professional_psychology|0|1,custom|mmlu_cloze:public_relations|0|1,custom|mmlu_cloze:security_studies|0|1,custom|mmlu_cloze:sociology|0|1,custom|mmlu_cloze:us_foreign_policy|0|1,custom|mmlu_cloze:virology|0|1,custom|mmlu_cloze:world_religions|0|1" 20 | ``` 21 | -------------------------------------------------------------------------------- /evaluation/eval.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=eval_cosmo 3 | #SBATCH --partition hopper-prod 4 | #SBATCH --nodes=1 5 | #SBATCH --ntasks-per-node=1 6 | #SBATCH --cpus-per-task=48 7 | #SBATCH --mem-per-cpu=20G 8 | #SBATCH --gpus=8 9 | #SBATCH -o %x_%j.out 10 | #SBATCH -e %x_%j.err 11 | #SBATCH --time=1-00:00:00 12 | 13 | set -x -e 14 | source ~/.bashrc 15 | source "/admin/home/anton/miniforge3/etc/profile.d/conda.sh" 16 | source activate cosmolighteval 17 | 18 | export HF_HOME="/fsx/anton/cosmo/cache/" 19 | export HF_DATASETS_CACHE="/fsx/anton/cosmo/cache/" 20 | 21 | MODELS=( 22 | "openai-community/gpt2" 23 | "openai-community/gpt2-medium" 24 | "openai-community/gpt2-xl" 25 | "karpathy/gpt2_1558M_final4_hf" 26 | "EleutherAI/pythia-160m" 27 | "Qwen/Qwen2-0.5B" 28 | "HuggingFaceTB/cosmo2-1.7B-1T" 29 | "HuggingFaceTB/cosmo2-149M-600B-fp32" 30 | "HuggingFaceTB/cosmo2-362M-600B-fp32" 31 | "HuggingFaceTB/cosmo2-1.7B-900B" 32 | "HuggingFaceTB/mixture11-600B" 33 | "HuggingFaceTB/cosmo2-base-magpie-lr-5e-5" 34 | "HuggingFaceTB/cosmo-300B-with-decay-instruct-mixture-5-bis" 35 | "HuggingFaceTB/cosmo2-600B-tokens-base-mixture" 36 | "microsoft/phi-1_5" 37 | "microsoft/phi-2" 38 | "HuggingFaceTB/cosmo-1b" 39 | "HuggingFaceTB/cosmo2-test-classic" 40 | 
"HuggingFaceFW/ablation-model-fineweb-edu" 41 | "HuggingFaceFW/ablation-model-fineweb-v1" 42 | "Qwen/Qwen1.5-1.8B" 43 | "Qwen/Qwen2-1.5B" 44 | "Qwen/Qwen1.5-0.5B" 45 | "stabilityai/stablelm-2-1_6b" 46 | "allenai/OLMo-1B-hf" 47 | "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T" 48 | "Qwen/Qwen2-1.5B-Instruct" 49 | "Qwen/Qwen2-0.5B-Instruct" 50 | "HuggingFaceFW/ablation-model-refinedweb" 51 | "HuggingFaceFW/ablation-model-c4" 52 | "HuggingFaceFW/ablation-model-dolma-v1_6" 53 | "HuggingFaceFW/ablation-model-slimpajama" 54 | "HuggingFaceFW/ablation-model-the-pile" 55 | "HuggingFaceFW/ablation-model-redpajama2" 56 | ) 57 | OUTPUT_DIR="/fsx/anton/cosmopedia/eval_results_cosmo2/" 58 | OUTPUT_DATASET="HuggingFaceTB/eval_results_cosmo2" 59 | 60 | for model in "${MODELS[@]}" 61 | do 62 | accelerate launch --num_processes=8 --main_process_port=29600 "/admin/home/anton/repos/lighteval/run_evals_accelerate.py" --model_args="pretrained=$model" \ 63 | --custom_tasks "lighteval_tasks.py" --output_dir $OUTPUT_DIR --override_batch_size 16 \ 64 | --tasks 
"custom|hellaswag|0|1,custom|winogrande|0|1,custom|piqa|0|1,custom|siqa|0|1,custom|openbookqa|0|1,custom|arc:easy|0|1,custom|arc:challenge|0|1,custom|commonsense_qa|0|1,custom|boolq|0|1,custom|trivia_qa|0|1,custom|trivia_qa|5|1,custom|mmlu_pro_cloze|0|1,custom|mmlu_stem_mc|0|1,custom|mmlu_stem_cloze|0|1,custom|gsm8k|5|1,custom|mmlu_mc:abstract_algebra|0|1,custom|mmlu_mc:anatomy|0|1,custom|mmlu_mc:astronomy|0|1,custom|mmlu_mc:business_ethics|0|1,custom|mmlu_mc:clinical_knowledge|0|1,custom|mmlu_mc:college_biology|0|1,custom|mmlu_mc:college_chemistry|0|1,custom|mmlu_mc:college_computer_science|0|1,custom|mmlu_mc:college_mathematics|0|1,custom|mmlu_mc:college_medicine|0|1,custom|mmlu_mc:college_physics|0|1,custom|mmlu_mc:computer_security|0|1,custom|mmlu_mc:conceptual_physics|0|1,custom|mmlu_mc:econometrics|0|1,custom|mmlu_mc:electrical_engineering|0|1,custom|mmlu_mc:elementary_mathematics|0|1,custom|mmlu_mc:formal_logic|0|1,custom|mmlu_mc:global_facts|0|1,custom|mmlu_mc:high_school_biology|0|1,custom|mmlu_mc:high_school_chemistry|0|1,custom|mmlu_mc:high_school_computer_science|0|1,custom|mmlu_mc:high_school_european_history|0|1,custom|mmlu_mc:high_school_geography|0|1,custom|mmlu_mc:high_school_government_and_politics|0|1,custom|mmlu_mc:high_school_macroeconomics|0|1,custom|mmlu_mc:high_school_mathematics|0|1,custom|mmlu_mc:high_school_microeconomics|0|1,custom|mmlu_mc:high_school_physics|0|1,custom|mmlu_mc:high_school_psychology|0|1,custom|mmlu_mc:high_school_statistics|0|1,custom|mmlu_mc:high_school_us_history|0|1,custom|mmlu_mc:high_school_world_history|0|1,custom|mmlu_mc:human_aging|0|1,custom|mmlu_mc:human_sexuality|0|1,custom|mmlu_mc:international_law|0|1,custom|mmlu_mc:jurisprudence|0|1,custom|mmlu_mc:logical_fallacies|0|1,custom|mmlu_mc:machine_learning|0|1,custom|mmlu_mc:management|0|1,custom|mmlu_mc:marketing|0|1,custom|mmlu_mc:medical_genetics|0|1,custom|mmlu_mc:miscellaneous|0|1,custom|mmlu_mc:moral_disputes|0|1,custom|mmlu_mc:moral_scenarios|0|1,custom|mm
lu_mc:nutrition|0|1,custom|mmlu_mc:philosophy|0|1,custom|mmlu_mc:prehistory|0|1,custom|mmlu_mc:professional_accounting|0|1,custom|mmlu_mc:professional_law|0|1,custom|mmlu_mc:professional_medicine|0|1,custom|mmlu_mc:professional_psychology|0|1,custom|mmlu_mc:public_relations|0|1,custom|mmlu_mc:security_studies|0|1,custom|mmlu_mc:sociology|0|1,custom|mmlu_mc:us_foreign_policy|0|1,custom|mmlu_mc:virology|0|1,custom|mmlu_mc:world_religions|0|1,custom|mmlu_cloze:abstract_algebra|0|1,custom|mmlu_cloze:anatomy|0|1,custom|mmlu_cloze:astronomy|0|1,custom|mmlu_cloze:business_ethics|0|1,custom|mmlu_cloze:clinical_knowledge|0|1,custom|mmlu_cloze:college_biology|0|1,custom|mmlu_cloze:college_chemistry|0|1,custom|mmlu_cloze:college_computer_science|0|1,custom|mmlu_cloze:college_mathematics|0|1,custom|mmlu_cloze:college_medicine|0|1,custom|mmlu_cloze:college_physics|0|1,custom|mmlu_cloze:computer_security|0|1,custom|mmlu_cloze:conceptual_physics|0|1,custom|mmlu_cloze:econometrics|0|1,custom|mmlu_cloze:electrical_engineering|0|1,custom|mmlu_cloze:elementary_mathematics|0|1,custom|mmlu_cloze:formal_logic|0|1,custom|mmlu_cloze:global_facts|0|1,custom|mmlu_cloze:high_school_biology|0|1,custom|mmlu_cloze:high_school_chemistry|0|1,custom|mmlu_cloze:high_school_computer_science|0|1,custom|mmlu_cloze:high_school_european_history|0|1,custom|mmlu_cloze:high_school_geography|0|1,custom|mmlu_cloze:high_school_government_and_politics|0|1,custom|mmlu_cloze:high_school_macroeconomics|0|1,custom|mmlu_cloze:high_school_mathematics|0|1,custom|mmlu_cloze:high_school_microeconomics|0|1,custom|mmlu_cloze:high_school_physics|0|1,custom|mmlu_cloze:high_school_psychology|0|1,custom|mmlu_cloze:high_school_statistics|0|1,custom|mmlu_cloze:high_school_us_history|0|1,custom|mmlu_cloze:high_school_world_history|0|1,custom|mmlu_cloze:human_aging|0|1,custom|mmlu_cloze:human_sexuality|0|1,custom|mmlu_cloze:international_law|0|1,custom|mmlu_cloze:jurisprudence|0|1,custom|mmlu_cloze:logical_fallacies|0|1,custom|mm
lu_cloze:machine_learning|0|1,custom|mmlu_cloze:management|0|1,custom|mmlu_cloze:marketing|0|1,custom|mmlu_cloze:medical_genetics|0|1,custom|mmlu_cloze:miscellaneous|0|1,custom|mmlu_cloze:moral_disputes|0|1,custom|mmlu_cloze:moral_scenarios|0|1,custom|mmlu_cloze:nutrition|0|1,custom|mmlu_cloze:philosophy|0|1,custom|mmlu_cloze:prehistory|0|1,custom|mmlu_cloze:professional_accounting|0|1,custom|mmlu_cloze:professional_law|0|1,custom|mmlu_cloze:professional_medicine|0|1,custom|mmlu_cloze:professional_psychology|0|1,custom|mmlu_cloze:public_relations|0|1,custom|mmlu_cloze:security_studies|0|1,custom|mmlu_cloze:sociology|0|1,custom|mmlu_cloze:us_foreign_policy|0|1,custom|mmlu_cloze:virology|0|1,custom|mmlu_cloze:world_religions|0|1" 65 | done 66 | 67 | huggingface-cli upload $OUTPUT_DATASET $OUTPUT_DIR / --repo-type dataset --delete="*" 68 | # huggingface-cli upload HuggingFaceTB/eval_results_cosmo2 /fsx/anton/cosmopedia/eval_results_cosmo2/ / --repo-type dataset --delete="*" -------------------------------------------------------------------------------- /evaluation/lighteval_tasks.py: -------------------------------------------------------------------------------- 1 | # ruff: noqa: F405, F403, F401 2 | """ 3 | Custom evaluation tasks for lighteval 4 | 5 | Do note that we ran the evals with `max_samples=1000` to speed up large evals. 6 | Most custom prompt changes were in an attempt to improve signal for small models in general. 7 | 8 | This file generally creates just a TASKS_TABLE and TASKS_GROUPS which are then imported by LightEval. 
9 | 10 | Example usage (lighteval_tasks.py is the path to this file): 11 | =================== 12 | accelerate launch --num_processes=1 lighteval/run_evals_accelerate.py --model_args="pretrained=HuggingFaceTB/cosmo-1b" \ 13 | --custom_tasks "lighteval_tasks.py" --output_dir [OUTPUTPATH] --max_samples 1000 \ 14 | --tasks "custom|hellaswag|0|1,custom|winogrande|0|1,custom|piqa|0|1,custom|siqa|0|1,custom|openbookqa|0|1,custom|arc:easy|0|1,custom|arc:challenge|0|1,custom|commonsense_qa|0|1,custom|mmlu:abstract_algebra|0|1,custom|mmlu:anatomy|0|1,custom|mmlu:astronomy|0|1,custom|mmlu:business_ethics|0|1,custom|mmlu:clinical_knowledge|0|1,custom|mmlu:college_biology|0|1,custom|mmlu:college_chemistry|0|1,custom|mmlu:college_computer_science|0|1,custom|mmlu:college_mathematics|0|1,custom|mmlu:college_medicine|0|1,custom|mmlu:college_physics|0|1,custom|mmlu:computer_security|0|1,custom|mmlu:conceptual_physics|0|1,custom|mmlu:econometrics|0|1,custom|mmlu:electrical_engineering|0|1,custom|mmlu:elementary_mathematics|0|1,custom|mmlu:formal_logic|0|1,custom|mmlu:global_facts|0|1,custom|mmlu:high_school_biology|0|1,custom|mmlu:high_school_chemistry|0|1,custom|mmlu:high_school_computer_science|0|1,custom|mmlu:high_school_european_history|0|1,custom|mmlu:high_school_geography|0|1,custom|mmlu:high_school_government_and_politics|0|1,custom|mmlu:high_school_macroeconomics|0|1,custom|mmlu:high_school_mathematics|0|1,custom|mmlu:high_school_microeconomics|0|1,custom|mmlu:high_school_physics|0|1,custom|mmlu:high_school_psychology|0|1,custom|mmlu:high_school_statistics|0|1,custom|mmlu:high_school_us_history|0|1,custom|mmlu:high_school_world_history|0|1,custom|mmlu:human_aging|0|1,custom|mmlu:human_sexuality|0|1,custom|mmlu:international_law|0|1,custom|mmlu:jurisprudence|0|1,custom|mmlu:logical_fallacies|0|1,custom|mmlu:machine_learning|0|1,custom|mmlu:management|0|1,custom|mmlu:marketing|0|1,custom|mmlu:medical_genetics|0|1,custom|mmlu:miscellaneous|0|1,custom|mmlu:moral_disputes|0|1,cust
om|mmlu:moral_scenarios|0|1,custom|mmlu:nutrition|0|1,custom|mmlu:philosophy|0|1,custom|mmlu:prehistory|0|1,custom|mmlu:professional_accounting|0|1,custom|mmlu:professional_law|0|1,custom|mmlu:professional_medicine|0|1,custom|mmlu:professional_psychology|0|1,custom|mmlu:public_relations|0|1,custom|mmlu:security_studies|0|1,custom|mmlu:sociology|0|1,custom|mmlu:us_foreign_policy|0|1,custom|mmlu:virology|0|1,custom|mmlu:world_religions|0|1" 15 | =================== 16 | 17 | More info here: https://github.com/huggingface/lighteval?tab=readme-ov-file#evaluate-a-model-on-extended-community-or-custom-tasks 18 | For more info on differences between MMLU implementations: https://huggingface.co/blog/open-llm-leaderboard-mmlu#1001-flavors-of-mmlu 19 | In particular, the default leaderboard MMLU implementation (which uses "A", "B", etc as answer targets) gives generally random results on small/non instruction tuned models. 20 | Instead, we use the full MMLU answer as the target. 21 | """ 22 | import re 23 | from typing import List, Tuple 24 | 25 | from lighteval.metrics.metrics import Metrics 26 | from lighteval.tasks.lighteval_task import LightevalTaskConfig 27 | from lighteval.tasks.requests import Doc 28 | from lighteval.tasks.default_prompts import LETTER_INDICES 29 | 30 | _TASKS_STRINGS: List[Tuple[LightevalTaskConfig, str]] = [] 31 | _TASKS: List[LightevalTaskConfig] = [] 32 | 33 | ## COMMON_SENSE_REASONING_TASKS ## 34 | COMMON_SENSE_REASONING_TASKS = [ 35 | LightevalTaskConfig( 36 | name="hellaswag", 37 | prompt_function="hellaswag_prompt", 38 | hf_repo="hellaswag", 39 | hf_subset="default", 40 | metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm_nospace], 41 | ), 42 | LightevalTaskConfig( 43 | name="winogrande", 44 | prompt_function="winogrande", 45 | hf_repo="winogrande", 46 | hf_subset="winogrande_xl", 47 | metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm_nospace], 48 | ), 49 | LightevalTaskConfig( 50 | name="piqa", 51 | 
prompt_function="piqa_harness", 52 | hf_repo="piqa", 53 | hf_subset="plain_text", 54 | metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm_nospace], 55 | ), 56 | LightevalTaskConfig( 57 | name="siqa", 58 | prompt_function="siqa_prompt", 59 | hf_repo="lighteval/siqa", 60 | hf_subset="default", 61 | hf_avail_splits=["train", "validation"], 62 | metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm_nospace], 63 | ), 64 | LightevalTaskConfig( 65 | name="openbookqa", 66 | prompt_function="openbookqa", 67 | hf_repo="openbookqa", 68 | hf_subset="main", 69 | metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm_nospace], 70 | ), 71 | LightevalTaskConfig( 72 | name="arc:easy", 73 | prompt_function="arc", 74 | hf_repo="ai2_arc", 75 | hf_subset="ARC-Easy", 76 | evaluation_splits=["test"], 77 | generation_size=1, 78 | metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm_nospace], 79 | ), 80 | LightevalTaskConfig( 81 | name="arc:challenge", 82 | prompt_function="arc", 83 | hf_repo="ai2_arc", 84 | hf_subset="ARC-Challenge", 85 | evaluation_splits=["test"], 86 | generation_size=1, 87 | metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm_nospace], 88 | ), 89 | LightevalTaskConfig( 90 | name="commonsense_qa", 91 | prompt_function="commonsense_qa_prompt", 92 | hf_repo="commonsense_qa", 93 | hf_subset="default", 94 | metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm_nospace], 95 | ), 96 | LightevalTaskConfig( 97 | name="mmlu_pro_cloze", 98 | prompt_function="mmlu_pro_cloze_prompt", 99 | hf_repo="TIGER-Lab/MMLU-Pro", 100 | hf_subset="default", 101 | metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm_nospace], 102 | evaluation_splits=["test"], 103 | few_shots_split="validation", 104 | few_shots_select=None, 105 | generation_size=-1, 106 | stop_sequence=None, 107 | output_regex=None, 108 | frozen=False, 109 | ), 110 | LightevalTaskConfig( 111 | name="mmlu_pro_mc", 112 | 
prompt_function="mmlu_pro_mc_prompt", 113 | hf_repo="TIGER-Lab/MMLU-Pro", 114 | hf_subset="default", 115 | metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm_nospace], 116 | evaluation_splits=["test"], 117 | few_shots_split="validation", 118 | few_shots_select=None, 119 | generation_size=1, 120 | stop_sequence=None, 121 | output_regex=None, 122 | frozen=False, 123 | ), 124 | LightevalTaskConfig( 125 | name="boolq", 126 | prompt_function="boolq_prompt", 127 | hf_repo="super_glue", 128 | hf_subset="boolq", 129 | metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm_nospace], 130 | trust_dataset=True, 131 | stop_sequence=["\n"], 132 | ), 133 | LightevalTaskConfig( 134 | name="trivia_qa", 135 | prompt_function="triviaqa", 136 | hf_repo="mandarjoshi/trivia_qa", 137 | hf_subset="rc.nocontext", 138 | hf_avail_splits=["train", "validation"], 139 | evaluation_splits=["validation"], 140 | metric=[Metrics.quasi_exact_match_triviaqa], 141 | generation_size=20, 142 | trust_dataset=True, 143 | stop_sequence=["\n", ".", ","], 144 | few_shots_select="random_sampling_from_train", 145 | ), 146 | ] 147 | 148 | 149 | def boolq_prompt(line, task_name: str = None): 150 | return Doc( 151 | task_name=task_name, 152 | query=f"{line['passage']}\nQuestion: {line['question'].capitalize()}?\nAnswer:", 153 | choices=[" No", " Yes"], # Only gold 154 | gold_index=int(line["label"]), 155 | ) 156 | 157 | 158 | def mmlu_pro_cloze_prompt(line, task_name: str = None): 159 | """MMLU-Pro prompt without letters""" 160 | topic = line["category"] 161 | prompt = f"The following are questions about {topic.replace('_', ' ')}.\nQuestion: " 162 | prompt += line["question"] + "\nAnswer:" 163 | 164 | return Doc( 165 | task_name=task_name, 166 | query=prompt, 167 | choices=[f" {c}" for c in line["options"]], 168 | gold_index=line["answer_index"], 169 | instruction=f"The following are questions about {topic.replace('_', ' ')}.\n", 170 | ) 171 | 172 | 173 | def mmlu_pro_mc_prompt(line, 
task_name: str = None): 174 | topic = line["category"] 175 | query = f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n" 176 | query += line["question"] + "\n" 177 | query += "".join([f"{key}. {choice}\n" for key, choice in zip(LETTER_INDICES, line["options"])]) 178 | query += "Answer:" 179 | 180 | return Doc( 181 | task_name=task_name, 182 | query=query, 183 | choices=LETTER_INDICES[: len(line["options"])], 184 | gold_index=line["answer_index"], 185 | instruction=f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n", 186 | target_for_fewshot_sorting=LETTER_INDICES[line["answer_index"]], 187 | ) 188 | 189 | 190 | def commonsense_qa_prompt(line, task_name: str = None): 191 | return Doc( 192 | task_name=task_name, 193 | query=line["question"], 194 | choices=[f" {c}" for c in line["choices"]["text"]], 195 | gold_index=LETTER_INDICES.index(line["answerKey"].strip()), 196 | instruction="", 197 | ) 198 | 199 | 200 | def siqa_prompt(line, task_name: str = None): 201 | return Doc( 202 | task_name=task_name, 203 | query=line["context"] + " " + line["question"], 204 | choices=[f" {c}" for c in [line["answerA"], line["answerB"], line["answerC"]]], 205 | gold_index=int(line["label"]) - 1, 206 | instruction="", 207 | ) 208 | 209 | 210 | def hellaswag_prompt(line, task_name: str = None): 211 | def preprocess(text): 212 | """Comes from AiHarness""" 213 | # text = text.strip() 214 | # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag. 215 | text = text.replace(" [title]", ". 
") 216 | text = re.sub("\\[.*?\\]", "", text) 217 | text = text.replace(" ", " ") 218 | return text 219 | 220 | ctx = f"{line['ctx_a']} {line['ctx_b'].capitalize()} " 221 | return Doc( 222 | task_name=task_name, 223 | query=preprocess(line["activity_label"] + ": " + ctx), 224 | choices=[" " + preprocess(ending) for ending in line["endings"]], 225 | gold_index=int(line["label"]) if line["label"] != "" else -1, # -1 for test 226 | # "metric": "choices_loglikelihood", 227 | ) 228 | 229 | 230 | GSM8K = LightevalTaskConfig( 231 | name="gsm8k", 232 | prompt_function="gsm8k", 233 | hf_repo="gsm8k", 234 | hf_subset="main", 235 | hf_avail_splits=["train", "test"], 236 | evaluation_splits=["test"], 237 | metric=[Metrics.quasi_exact_match_gsm8k], 238 | generation_size=256, 239 | stop_sequence=["Question:", "Question"], 240 | few_shots_select="random_sampling_from_train", 241 | ) 242 | MATH_TASKS = [ 243 | LightevalTaskConfig( 244 | name=f"math:{subset}", 245 | prompt_function="math", 246 | hf_repo="lighteval/MATH", 247 | hf_subset=subset, 248 | hf_avail_splits=["train", "test"], 249 | evaluation_splits=["test"], 250 | metric=[Metrics.quasi_exact_match_math], 251 | generation_size=256, 252 | stop_sequence=["Problem:", "Problem"], 253 | few_shots_select="random_sampling_from_train", 254 | ) 255 | for subset in [ 256 | "algebra", 257 | "counting_and_probability", 258 | "geometry", 259 | "intermediate_algebra", 260 | "number_theory", 261 | "prealgebra", 262 | "precalculus", 263 | ] 264 | ] 265 | 266 | # 0 short for common sense 267 | COMMON_SENSE_REASONING_STRING = [(t, f"custom|{t.name}|0|1") for t in COMMON_SENSE_REASONING_TASKS] 268 | _TASKS_STRINGS.extend(COMMON_SENSE_REASONING_STRING) 269 | _TASKS_STRINGS.extend([(GSM8K, f"custom|{GSM8K.name}|5|1")]) 270 | _TASKS_STRINGS.extend([(t, f"custom|{t.name}|4|1") for t in MATH_TASKS]) 271 | _TASKS += COMMON_SENSE_REASONING_TASKS 272 | _TASKS += [GSM8K] + MATH_TASKS 273 | 274 | ## MMLU ## 275 | class 
CustomMMLUEvaluationTask(LightevalTaskConfig): 276 | def __init__( 277 | self, 278 | name, 279 | prompt_function="mmlu_prompt", 280 | hf_repo="lighteval/mmlu", 281 | hf_subset=None, 282 | # metric=[Metrics.loglikelihood_acc_single_token], 283 | metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm_nospace], 284 | hf_avail_splits=None, 285 | evaluation_splits=["test"], 286 | few_shots_split="dev", 287 | few_shots_select=None, 288 | generation_size=-1, 289 | stop_sequence=None, 290 | output_regex=None, 291 | frozen=False, 292 | ): 293 | super().__init__( 294 | name=name, 295 | prompt_function=prompt_function, 296 | hf_repo=hf_repo, 297 | hf_subset=hf_subset, 298 | metric=metric, 299 | hf_avail_splits=hf_avail_splits, 300 | evaluation_splits=evaluation_splits, 301 | few_shots_split=few_shots_split, 302 | few_shots_select=few_shots_select, 303 | generation_size=generation_size, 304 | stop_sequence=stop_sequence, 305 | output_regex=output_regex, 306 | frozen=frozen, 307 | ) 308 | 309 | MMLU_TASKS = [] 310 | mmlu_subsets = [ 311 | "abstract_algebra", 312 | "anatomy", 313 | "astronomy", 314 | "business_ethics", 315 | "clinical_knowledge", 316 | "college_biology", 317 | "college_chemistry", 318 | "college_computer_science", 319 | "college_mathematics", 320 | "college_medicine", 321 | "college_physics", 322 | "computer_security", 323 | "conceptual_physics", 324 | "econometrics", 325 | "electrical_engineering", 326 | "elementary_mathematics", 327 | "formal_logic", 328 | "global_facts", 329 | "high_school_biology", 330 | "high_school_chemistry", 331 | "high_school_computer_science", 332 | "high_school_european_history", 333 | "high_school_geography", 334 | "high_school_government_and_politics", 335 | "high_school_macroeconomics", 336 | "high_school_mathematics", 337 | "high_school_microeconomics", 338 | "high_school_physics", 339 | "high_school_psychology", 340 | "high_school_statistics", 341 | "high_school_us_history", 342 | "high_school_world_history", 343 | 
"human_aging", 344 | "human_sexuality", 345 | "international_law", 346 | "jurisprudence", 347 | "logical_fallacies", 348 | "machine_learning", 349 | "management", 350 | "marketing", 351 | "medical_genetics", 352 | "miscellaneous", 353 | "moral_disputes", 354 | "moral_scenarios", 355 | "nutrition", 356 | "philosophy", 357 | "prehistory", 358 | "professional_accounting", 359 | "professional_law", 360 | "professional_medicine", 361 | "professional_psychology", 362 | "public_relations", 363 | "security_studies", 364 | "sociology", 365 | "us_foreign_policy", 366 | "virology", 367 | "world_religions", 368 | ] 369 | 370 | for answer_type in ("mc", "cloze"): 371 | prompt_function = f"mmlu_{answer_type}_prompt" 372 | generation_size = -1 if answer_type == "cloze" else 1 373 | for subset in mmlu_subsets: 374 | MMLU_TASKS.append( 375 | CustomMMLUEvaluationTask( 376 | name=f"mmlu_{answer_type}:{subset}", 377 | prompt_function=prompt_function, 378 | hf_subset=subset, 379 | generation_size=generation_size 380 | ) 381 | ) 382 | 383 | MMLU_TASKS += [ 384 | CustomMMLUEvaluationTask( 385 | name=f"mmlu_stem_mc", 386 | hf_repo="TIGER-Lab/MMLU-STEM", 387 | prompt_function="mmlu_mc_prompt", 388 | hf_subset="default", 389 | generation_size=1 390 | ), 391 | CustomMMLUEvaluationTask( 392 | name=f"mmlu_stem_cloze", 393 | hf_repo="TIGER-Lab/MMLU-STEM", 394 | prompt_function="mmlu_cloze_prompt", 395 | hf_subset="default", 396 | generation_size=-1 397 | ), 398 | ] 399 | 400 | 401 | def mmlu_cloze_prompt(line, task_name: str = None): 402 | """MMLU prompt without letters""" 403 | topic = line["subject"] 404 | prompt = f"The following are questions about {topic.replace('_', ' ')}.\nQuestion: " 405 | prompt += line["question"] + "\nAnswer:" 406 | 407 | return Doc( 408 | task_name=task_name, 409 | query=prompt, 410 | choices=[f" {c}" for c in line["choices"]], 411 | gold_index=line["answer"], 412 | instruction=f"The following are questions about {topic.replace('_', ' ')}.\n", 413 | ) 414 | 415 | 
416 | def mmlu_mc_prompt(line, task_name: str = None): 417 | topic = line["subject"] 418 | query = f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n" 419 | query += line["question"] + "\n" 420 | query += "".join([f"{key}. {choice}\n" for key, choice in zip(LETTER_INDICES, line["choices"])]) 421 | query += "Answer:" 422 | 423 | gold_ix = LETTER_INDICES.index(line["answer"]) if isinstance(line["answer"], str) else line["answer"] 424 | 425 | return Doc( 426 | task_name=task_name, 427 | query=query, 428 | choices=[" A", " B", " C", " D"], 429 | gold_index=gold_ix, 430 | instruction=f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n", 431 | target_for_fewshot_sorting=[" A", " B", " C", " D"][gold_ix], 432 | ) 433 | 434 | 435 | MMLU_STRING = [(t, f"custom|{t.name}|0|1") for t in MMLU_TASKS] 436 | _TASKS_STRINGS.extend(MMLU_STRING) 437 | _TASKS += MMLU_TASKS 438 | 439 | # common sense reasoning + mmlu 440 | EARLY_SIGNAL_TASKS = ",".join([t[1] for t in COMMON_SENSE_REASONING_STRING] + [t[1] for t in MMLU_STRING]) 441 | 442 | # Convert to dict for lighteval 443 | TASKS_TABLE = _TASKS 444 | # You can have a few pre-organised groups of tasks 445 | TASKS_GROUPS = { 446 | "early-signal": EARLY_SIGNAL_TASKS, 447 | "math": f"custom|{GSM8K.name}|5|1" + "," + ",".join([f"custom|{t.name}|4|1" for t in MATH_TASKS]), 448 | } 449 | -------------------------------------------------------------------------------- /fulltext_search/README.md: -------------------------------------------------------------------------------- 1 | # Fulltext search with BISAC topics 2 | 3 | This is a simple example of how to use Manticore for fulltext search over CommonCrawl pages. 4 | 5 | Due to the size of the corpus and the specifics of the HuggingFace cluster, we only provide these scripts as a reference. 
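
As a minimal sketch of what a query against such an index could look like, the snippet below builds the same JSON match request that `index_docs.py` sends to the `/search` endpoint as a smoke test. The helper names (`build_search_body`, `extract_hits`, `search`) are hypothetical, and the response layout (`hits.hits[*]._source`) is an assumption about the server's JSON reply rather than something these scripts confirm.

```python
import json
import urllib.request

# Assumption: a Manticore daemon on the local HTTP port used in index_docs.py.
MANTICORE_URL = "http://127.0.0.1:9308"


def build_search_body(query: str, index: str = "fineweb") -> str:
    """Serialize a full-text match over all text fields of `index`."""
    return json.dumps({"index": index, "query": {"match": {"*": query}}})


def extract_hits(response: dict) -> list[dict]:
    """Flatten the (assumed) hits.hits[*]._source layout into plain documents."""
    return [
        {"score": hit.get("_score"), **hit.get("_source", {})}
        for hit in response.get("hits", {}).get("hits", [])
    ]


def search(query: str, index: str = "fineweb") -> list[dict]:
    """POST the query to the /search endpoint and return the matching documents."""
    request = urllib.request.Request(
        f"{MANTICORE_URL}/search",
        data=build_search_body(query, index).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(request, timeout=60) as reply:
        return extract_hits(json.loads(reply.read()))
```

With a running server, `search("hello world")` would return the matching pages with whatever stored fields the server includes in the reply. Without one, `build_search_body` and `extract_hits` can still be exercised on their own.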
-------------------------------------------------------------------------------- /fulltext_search/index_docs.py: -------------------------------------------------------------------------------- 1 | import json 2 | import time 3 | import sys 4 | import random 5 | 6 | import requests 7 | from datasets import load_dataset 8 | 9 | 10 | def insert_batch(batch): 11 | ndjson = "" 12 | 13 | index_name = f"fineweb{random.randint(0, 63)}" 14 | 15 | for text, _id, url, language_score, token_count in zip( 16 | batch["text"], 17 | batch["id"], 18 | batch["url"], 19 | batch["language_score"], 20 | batch["token_count"], 21 | ): 22 | doc = { 23 | "insert": { 24 | "index": index_name, 25 | "_id": _id.split(":")[-1].strip(">"), 26 | "doc": { 27 | "content": text, 28 | "fw_id": _id.split(":")[-1].strip(">"), 29 | "url": url, 30 | "language_score": language_score, 31 | "token_count": token_count, 32 | }, 33 | } 34 | } 35 | ndjson += json.dumps(doc) + "\n" 36 | 37 | response = None 38 | while response is None: 39 | try: 40 | response = requests.post( 41 | "http://127.0.0.1:9308/bulk", 42 | headers={"Content-Type": "application/x-ndjson"}, 43 | data=ndjson, 44 | ) 45 | except requests.exceptions.ConnectionError as e: 46 | print(e, file=sys.stderr) 47 | time.sleep(1) 48 | pass 49 | 50 | return {"response": [response.status_code]} 51 | 52 | 53 | def main(): 54 | sql_url = "http://127.0.0.1:9308/sql?mode=raw" 55 | 56 | print("Removing table", file=sys.stderr) 57 | while True: 58 | try: 59 | requests.post(sql_url, data={"query": "drop table if exists fineweb"}) 60 | break 61 | except requests.exceptions.ConnectionError as e: 62 | print(e, file=sys.stderr) 63 | time.sleep(5) 64 | pass 65 | 66 | print("Creating table", file=sys.stderr) 67 | for i in range(64): 68 | response = requests.post( 69 | sql_url, data={"query": f"drop table if exists fineweb{i}"} 70 | ) 71 | print(response.text, file=sys.stderr) 72 | local_query = f"create table fineweb{i}(content text, fw_id string, url string, 
language_score float, token_count int) charset_table='non_cjk' stopwords='en' morphology='stem_en'" 73 | response = requests.post(sql_url, data={"query": local_query}) 74 | print(response.text, file=sys.stderr) 75 | 76 | distributed_query = "create table fineweb type='distributed'" 77 | for i in range(64): 78 | distributed_query += f" local='fineweb{i}'" 79 | response = requests.post(sql_url, data={"query": distributed_query}) 80 | print(response.text, file=sys.stderr) 81 | 82 | for dump in ["CC-MAIN-2024-10", "CC-MAIN-2023-50"]: 83 | print("Loading dataset", file=sys.stderr) 84 | dataset = load_dataset( 85 | "HuggingFaceFW/fineweb", 86 | dump, 87 | split="train", 88 | num_proc=64, 89 | cache_dir="/scratch/cosmo/.cache", 90 | ) 91 | dataset = dataset.select_columns( 92 | ["text", "id", "url", "language_score", "token_count"] 93 | ) 94 | dataset = dataset.map( 95 | insert_batch, 96 | batched=True, 97 | batch_size=10000, 98 | remove_columns=["text", "id", "url", "language_score", "token_count"], 99 | num_proc=64, 100 | ) 101 | for _ in dataset: 102 | pass 103 | 104 | time.sleep(30) 105 | for i in range(64): 106 | print(f"Optimizing table fineweb{i}", file=sys.stderr) 107 | response = requests.post( 108 | sql_url, 109 | data={"query": f"FLUSH TABLE fineweb{i}"}, 110 | timeout=600, 111 | ) 112 | print(response.text, file=sys.stderr) 113 | response = requests.post( 114 | sql_url, 115 | data={"query": f"OPTIMIZE TABLE fineweb{i} OPTION cutoff=16, sync=1"}, 116 | timeout=600, 117 | ) 118 | print(response.text, file=sys.stderr) 119 | response = requests.post( 120 | sql_url, 121 | data={"query": f"FREEZE fineweb{i}"}, 122 | timeout=600, 123 | ) 124 | print(response.text, file=sys.stderr) 125 | 126 | response = requests.post( 127 | "http://127.0.0.1:9308/search", 128 | data='{"index":"fineweb","query":{"match":{"*":"hello world"}}}', 129 | ) 130 | print(response.text, file=sys.stderr) 131 | 132 | # print("Backing up the index", file=sys.stderr) 133 | # time.sleep(30) 134 | # 
response = requests.post( 135 | # sql_url, 136 | # data={"query": "BACKUP TO /tmp/backups"}, 137 | # ) 138 | # print(response.text, file=sys.stderr) 139 | 140 | 141 | if __name__ == "__main__": 142 | main() 143 | -------------------------------------------------------------------------------- /fulltext_search/index_docs.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=index_fineweb 3 | #SBATCH --partition hopper-prod 4 | #SBATCH --nodes=1 5 | #SBATCH --ntasks-per-node=1 6 | #SBATCH --cpus-per-task=96 7 | #SBATCH --mem-per-cpu=20G 8 | #SBATCH -o %x_%j.out 9 | #SBATCH -e %x_%j.err 10 | #SBATCH --time=7-00:00:00 11 | 12 | set -x -e 13 | source ~/.bashrc 14 | source "$CONDA_PREFIX/etc/profile.d/conda.sh" 15 | source activate pyspark 16 | 17 | ulimit -n 99999 18 | 19 | mkdir -p /scratch/cosmo/manticore_idx 20 | rm -rf /scratch/cosmo/manticore_idx/* 21 | srun --container-image='manticoresearch/manticore:6.2.12' \ 22 | --container-env=EXTRA=1 \ 23 | --container-mounts="/scratch/cosmo/manticore_idx:/var/lib/manticore:z,$(pwd)/manticore.conf:/etc/manticoresearch/manticore.conf" \ 24 | --no-container-mount-home \ 25 | --qos high \ 26 | /bin/bash -c 'mkdir -p /var/run/manticore && chown manticore:manticore /var/run/manticore && mkdir -p /var/run/mysqld && chown manticore:manticore /var/run/mysqld && export EXTRA=1 && source /entrypoint.sh && docker_setup_env && /entrypoint.sh searchd -c /etc/manticoresearch/manticore.conf --nodetach' & 27 | 28 | python index_docs.py 29 | 30 | sleep 1000 31 | 32 | rclone copy -P --transfers 32 /scratch/cosmo/manticore_idx/ s3:cosmopedia-data/manticore_idx/CC-MAIN-2024-10-2023-50/ 33 | 34 | sleep 1000000000 -------------------------------------------------------------------------------- /fulltext_search/manticore.conf: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | ip=`hostname -i|rev|cut -d\ -f 1|rev` 3 | cat << EOF 
4 | searchd { 5 | 6 | # https://manual.manticoresearch.com/Server_settings/Searchd#access_plain_attrs 7 | # access_plain_attrs = mmap_preread 8 | 9 | # https://manual.manticoresearch.com/Server_settings/Searchd#access_blob_attrs 10 | # access_blob_attrs = mmap_preread 11 | 12 | # https://manual.manticoresearch.com/Server_settings/Searchd#access_doclists 13 | # access_doclists = file 14 | 15 | # https://manual.manticoresearch.com/Server_settings/Searchd#access_hitlists 16 | # access_hitlists = file 17 | 18 | # https://manual.manticoresearch.com/Server_settings/Searchd#agent_connect_timeout 19 | # agent_connect_timeout = 20 | 21 | # https://manual.manticoresearch.com/Server_settings/Searchd#agent_query_timeout 22 | # agent_query_timeout = 23 | 24 | # https://manual.manticoresearch.com/Server_settings/Searchd#agent_retry_count 25 | # agent_retry_count = 0 26 | 27 | # https://manual.manticoresearch.com/Server_settings/Searchd#agent_retry_delay 28 | # agent_retry_delay = 500 29 | 30 | # https://manual.manticoresearch.com/Server_settings/Searchd#attr_flush_period 31 | # attr_flush_period = 0 32 | 33 | # https://manual.manticoresearch.com/Server_settings/Searchd#binlog_flush 34 | # binlog_flush = 2 35 | 36 | # https://manual.manticoresearch.com/Server_settings/Searchd#binlog_max_log_size 37 | # binlog_max_log_size = 268435456 38 | 39 | # https://manual.manticoresearch.com/Server_settings/Searchd#binlog_path 40 | # binlog_path = 41 | 42 | # https://manual.manticoresearch.com/Server_settings/Searchd#client_timeout 43 | # client_timeout = 300 44 | 45 | # https://manual.manticoresearch.com/Server_settings/Searchd#collation_libc_locale 46 | # collation_libc_locale = C 47 | 48 | # https://manual.manticoresearch.com/Server_settings/Searchd#collation_server 49 | # collation_server = libc_ci 50 | 51 | # https://manual.manticoresearch.com/Server_settings/Searchd#data_dir 52 | data_dir = /var/lib/manticore 53 | 54 | # 
https://manual.manticoresearch.com/Server_settings/Searchd#docstore_cache_size 55 | # docstore_cache_size = 16m 56 | 57 | # https://manual.manticoresearch.com/Server_settings/Searchd#expansion_limit 58 | # expansion_limit = 0 59 | 60 | # https://manual.manticoresearch.com/Server_settings/Searchd#grouping_in_utc 61 | # grouping_in_utc = 0 62 | 63 | # https://manual.manticoresearch.com/Server_settings/Searchd#ha_period_karma 64 | # ha_period_karma = 60 65 | 66 | # https://manual.manticoresearch.com/Server_settings/Searchd#ha_ping_interval 67 | # ha_ping_interval = 1000 68 | 69 | # https://manual.manticoresearch.com/Server_settings/Searchd#hostname_lookup 70 | # hostname_lookup = 71 | 72 | # https://manual.manticoresearch.com/Server_settings/Searchd#jobs_queue_size 73 | # jobs_queue_size = 74 | 75 | # https://manual.manticoresearch.com/Server_settings/Searchd#listen_backlog 76 | # listen_backlog = 5 77 | 78 | # https://manual.manticoresearch.com/Server_settings/Searchd#listen 79 | # listen_env = this directive allows to append listeners from environment variables 80 | 81 | listen = 9306:mysql41 82 | listen = /var/run/mysqld/mysqld.sock:mysql41 83 | listen = $ip:9312 84 | listen = 9308:http 85 | listen = $ip:9315-9325:replication 86 | 87 | # https://manual.manticoresearch.com/Server_settings/Searchd#listen_tfo 88 | # listen_tfo = 0 89 | 90 | # https://manual.manticoresearch.com/Server_settings/Searchd#log 91 | log = /var/log/manticore/searchd.log 92 | 93 | # https://manual.manticoresearch.com/Server_settings/Searchd#max_batch_queries 94 | # max_batch_queries = 32 95 | 96 | # https://manual.manticoresearch.com/Server_settings/Searchd#threads 97 | threads = 64 98 | 99 | # https://manual.manticoresearch.com/Server_settings/Searchd#max_filters 100 | # max_filters = 256 101 | 102 | # https://manual.manticoresearch.com/Server_settings/Searchd#max_filter_values 103 | # max_filter_values = 4096 104 | 105 | # 
https://manual.manticoresearch.com/Server_settings/Searchd#max_open_files 106 | max_open_files = 65535 107 | 108 | optimize_cutoff = 2 109 | 110 | # https://manual.manticoresearch.com/Server_settings/Searchd#max_packet_size 111 | max_packet_size = 128M 112 | 113 | # https://manual.manticoresearch.com/Server_settings/Searchd#mysql_version_string 114 | # mysql_version_string = 115 | 116 | # https://manual.manticoresearch.com/Server_settings/Searchd#net_workers 117 | # net_workers = 1 118 | 119 | # https://manual.manticoresearch.com/Server_settings/Searchd#net_wait_tm 120 | # net_wait_tm = -1 121 | 122 | # https://manual.manticoresearch.com/Server_settings/Searchd#net_throttle_accept 123 | # net_throttle_accept = 0 124 | 125 | # https://manual.manticoresearch.com/Server_settings/Searchd#net_throttle_action 126 | # net_throttle_action = 0 127 | 128 | # https://manual.manticoresearch.com/Server_settings/Searchd#node_address 129 | # node_address = 130 | 131 | # https://manual.manticoresearch.com/Server_settings/Searchd#ondisk_attrs_default 132 | # ondisk_attrs_default = 0 133 | 134 | # https://manual.manticoresearch.com/Server_settings/Searchd#persistent_connections_limit 135 | # persistent_connections_limit = 136 | 137 | # https://manual.manticoresearch.com/Server_settings/Searchd#pid_file 138 | pid_file = /var/run/manticore/searchd.pid 139 | 140 | # https://manual.manticoresearch.com/Server_settings/Searchd#predicted_time_costs 141 | # predicted_time_costs = doc=64, hit=48, skip=2048, match=64 142 | 143 | # https://manual.manticoresearch.com/Server_settings/Searchd#preopen_indexes 144 | # preopen_indexes = 1 145 | 146 | # https://manual.manticoresearch.com/Server_settings/Searchd#qcache_max_bytes 147 | # qcache_max_bytes = 16Mb 148 | 149 | # https://manual.manticoresearch.com/Server_settings/Searchd#qcache_thresh_msec 150 | # qcache_thresh_msec = 3000 151 | 152 | # https://manual.manticoresearch.com/Server_settings/Searchd#qcache_ttl_sec 153 | # qcache_ttl_sec = 60 154 
| 155 | # https://manual.manticoresearch.com/Server_settings/Searchd#query_log_format 156 | query_log_format = sphinxql 157 | 158 | # https://manual.manticoresearch.com/Server_settings/Searchd#query_log_min_msec 159 | # query_log_min_msec = 0 160 | 161 | # https://manual.manticoresearch.com/Server_settings/Searchd#query_log 162 | # query_log = /var/log/manticore/query.log 163 | 164 | # https://manual.manticoresearch.com/Server_settings/Searchd#query_log_mode 165 | # query_log_mode = 600 166 | 167 | # https://manual.manticoresearch.com/Server_settings/Searchd#max_connections 168 | # max_connections = 169 | 170 | # https://manual.manticoresearch.com/Server_settings/Searchd#network_timeout 171 | # network_timeout = 5 172 | 173 | # https://manual.manticoresearch.com/Server_settings/Searchd#read_buffer 174 | # read_buffer = 256K 175 | 176 | # https://manual.manticoresearch.com/Server_settings/Searchd#read_buffer_docs 177 | # read_buffer_docs = 256K 178 | 179 | # https://manual.manticoresearch.com/Server_settings/Searchd#read_buffer_hits 180 | # read_buffer_hits = 256K 181 | 182 | # https://manual.manticoresearch.com/Server_settings/Searchd#read_unhinted 183 | # read_unhinted 32K 184 | 185 | # https://manual.manticoresearch.com/Server_settings/Searchd#rt_flush_period 186 | # rt_flush_period = 187 | 188 | # https://manual.manticoresearch.com/Server_settings/Searchd#rt_merge_iops 189 | # rt_merge_iops = 0 190 | 191 | # https://manual.manticoresearch.com/Server_settings/Searchd#rt_merge_maxiosize 192 | # rt_merge_maxiosize = 0 193 | 194 | # https://manual.manticoresearch.com/Server_settings/Searchd#seamless_rotate 195 | # seamless_rotate = 1 196 | 197 | # https://manual.manticoresearch.com/Server_settings/Searchd#server_id 198 | # server_id = 199 | 200 | # https://manual.manticoresearch.com/Server_settings/Searchd#shutdown_timeout 201 | # shutdown_timeout = 3 202 | 203 | # https://manual.manticoresearch.com/Server_settings/Searchd#shutdown_token 204 | # shutdown_token = 205 
| 206 | # https://manual.manticoresearch.com/Server_settings/Searchd#snippets_file_prefix 207 | # snippets_file_prefix = 208 | 209 | # https://manual.manticoresearch.com/Server_settings/Searchd#sphinxql_state 210 | # sphinxql_state = 211 | 212 | # https://manual.manticoresearch.com/Server_settings/Searchd#sphinxql_timeout 213 | # sphinxql_timeout = 900 214 | 215 | # https://manual.manticoresearch.com/Server_settings/Searchd#ssl_ca 216 | # ssl_ca = 217 | 218 | # https://manual.manticoresearch.com/Server_settings/Searchd#ssl_cert 219 | # ssl_cert = 220 | 221 | # https://manual.manticoresearch.com/Server_settings/Searchd#ssl_key 222 | # ssl_key = 223 | 224 | # https://manual.manticoresearch.com/Server_settings/Searchd#subtree_docs_cache 225 | # subtree_docs_cache = 0 226 | 227 | # https://manual.manticoresearch.com/Server_settings/Searchd#subtree_hits_cache 228 | # subtree_hits_cache = 0 229 | 230 | # https://manual.manticoresearch.com/Server_settings/Searchd#thread_stack 231 | # thread_stack = 232 | 233 | # https://manual.manticoresearch.com/Server_settings/Searchd#unlink_old 234 | # unlink_old = 1 235 | 236 | # https://manual.manticoresearch.com/Server_settings/Searchd#watchdog 237 | # watchdog = 1 238 | 239 | # https://manual.manticoresearch.com/Server_settings/Searchd#secondary_indexes 240 | secondary_indexes = 0 241 | } 242 | 243 | common { 244 | 245 | # https://manual.manticoresearch.com/Server_settings/Common#lemmatizer_base 246 | # lemmatizer_base = /usr/local/share 247 | 248 | # https://manual.manticoresearch.com/Server_settings/Common#progressive_merge 249 | # progressive_merge = 250 | 251 | # https://manual.manticoresearch.com/Server_settings/Common#json_autoconv_keynames 252 | # json_autoconv_keynames = 253 | 254 | # https://manual.manticoresearch.com/Server_settings/Common#json_autoconv_numbers 255 | # json_autoconv_numbers = 0 256 | 257 | # https://manual.manticoresearch.com/Server_settings/Common#on_json_attr_error 258 | # on_json_attr_error = 
ignore_attr 259 | 260 | # https://manual.manticoresearch.com/Server_settings/Common#plugin_dir 261 | # plugin_dir = 262 | 263 | } 264 | 265 | EOF -------------------------------------------------------------------------------- /fulltext_search/search_sharded.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import sys 4 | import time 5 | 6 | import requests 7 | from datasets import load_dataset 8 | 9 | 10 | def get_args(): 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument( 13 | "--input_dataset", type=str, default="HuggingFaceTB/bisac_expanded_final" 14 | ) 15 | parser.add_argument("--n_pages", type=int, default=2000) 16 | parser.add_argument( 17 | "--output_dataset", 18 | type=str, 19 | default="HuggingFaceTB/bisac_boosted_new_index_2000", 20 | ) 21 | parser.add_argument("--shard", type=int, required=True) 22 | parser.add_argument("--num_shards", type=int, required=True) 23 | return parser.parse_args() 24 | 25 | 26 | # wait until the server is up 27 | while True: 28 | try: 29 | requests.post( 30 | "http://127.0.0.1:9308/search", 31 | data='{"index": "fineweb", "query": {"match": {"content": "ping"}}}', 32 | ) 33 | break 34 | except requests.exceptions.ConnectionError: 35 | time.sleep(10) 36 | pass 37 | 38 | 39 | args = get_args() 40 | data = load_dataset( 41 | args.input_dataset, split="train", cache_dir="/scratch/cosmo/.cache" 42 | ) 43 | data = data.filter(lambda x, i: i % args.num_shards == args.shard, with_indices=True) 44 | data = data.select_columns(["top_category", "subcategory", "subtopic"]) 45 | 46 | 47 | def run_query(query, n_pages): 48 | while True: 49 | try: 50 | max_pages = 4_000 51 | response = requests.post( 52 | "http://127.0.0.1:9308/search", 53 | data=json.dumps( 54 | { 55 | "index": "fineweb", 56 | "size": n_pages, 57 | "query": query, 58 | "max_matches": max_pages, 59 | } 60 | ), 61 | timeout=1000, 62 | ) 63 | if response.status_code != 200: 64 | 
print(response.text, file=sys.stderr) 65 | time.sleep(5) 66 | continue 67 | else: 68 | hits = response.json()["hits"]["hits"] 69 | return hits 70 | except requests.exceptions.ConnectionError as e: 71 | print(e, file=sys.stderr) 72 | time.sleep(5) 73 | continue 74 | 75 | 76 | def search_topic(sample): 77 | top_category = sample["top_category"][0].strip() 78 | subcategory = sample["subcategory"][0].strip() 79 | subtopic = sample["subtopic"][0].strip() 80 | for c in ["!", '"', "$", "'", "(", ")", "/", "<", "@", "\\", "^", "|", "~"]: 81 | top_category = top_category.replace(c, " ") 82 | subcategory = subcategory.replace(c, " ") 83 | subtopic = subtopic.replace(c, " ") 84 | # boosting the IDF score of subtopic tokens 85 | boosted_subtopic = " ".join([w + "^2" for w in subtopic.split()]) 86 | match_query = " ".join([top_category, subcategory, subtopic]) 87 | boosted_query = " ".join([top_category, subcategory, boosted_subtopic]) 88 | 89 | boosted_hits = run_query({"query_string": boosted_query}, args.n_pages) 90 | print(f"Boosted hits: {len(boosted_hits)} for {boosted_query}", file=sys.stderr) 91 | if len(boosted_hits) < args.n_pages: 92 | match_hits = run_query( 93 | {"match": {"content": match_query}}, args.n_pages + len(boosted_hits) 94 | ) 95 | print(f"Match hits: {len(match_hits)} for {match_query}", file=sys.stderr) 96 | else: 97 | match_hits = [] 98 | 99 | hit_ids = set() 100 | hits = [] 101 | for hit in boosted_hits + match_hits: 102 | if hit["_id"] not in hit_ids: 103 | hits.append(hit) 104 | hit_ids.add(hit["_id"]) 105 | hits = hits[: args.n_pages] 106 | 107 | results = { 108 | "top_category": sample["top_category"] * len(hits), 109 | "subcategory": sample["subcategory"] * len(hits), 110 | "subtopic": sample["subtopic"] * len(hits), 111 | "topic_hits": hits, 112 | "num_hits": [len(hits)] * len(hits), 113 | } 114 | return results 115 | 116 | 117 | data = data.map(search_topic, batched=True, batch_size=1, num_proc=2) 118 | data.push_to_hub( 119 | 
f"{args.output_dataset}_{args.shard}", private=True, max_shard_size="4096MB" 120 | ) 121 | -------------------------------------------------------------------------------- /fulltext_search/search_sharded.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=cosmo_search_sharded 3 | #SBATCH --partition hopper-prod 4 | #SBATCH --nodes=1 5 | #SBATCH --ntasks-per-node=1 6 | #SBATCH --cpus-per-task=96 7 | #SBATCH --mem-per-cpu=20G 8 | #SBATCH -o %x_%j.out 9 | #SBATCH -e %x_%j.err 10 | #SBATCH --array=0-15%8 11 | 12 | set -x -e 13 | source ~/.bashrc 14 | source "$CONDA_PREFIX/etc/profile.d/conda.sh" 15 | source activate pyspark 16 | 17 | ulimit -n 99999 18 | 19 | mkdir -p /scratch/cosmo/manticore_idx 20 | rm -rf /scratch/cosmo/manticore_idx/* 21 | rclone copy -P --transfers 32 s3:cosmopedia-data/manticore_idx/CC-MAIN-2024-10-2023-50/ /scratch/cosmo/manticore_idx/ 22 | 23 | srun --container-image='manticoresearch/manticore:6.2.12' \ 24 | --container-mounts="/scratch/cosmo/manticore_idx:/var/lib/manticore:z,$(pwd)/manticore.conf:/etc/manticoresearch/manticore.conf" \ 25 | --no-container-mount-home \ 26 | --qos high \ 27 | /bin/bash -c 'mkdir -p /var/run/manticore && chown manticore:manticore /var/run/manticore && mkdir -p /var/run/mysqld && chown manticore:manticore /var/run/mysqld && /entrypoint.sh searchd -c /etc/manticoresearch/manticore.conf --nodetach' & 28 | 29 | python search_sharded.py --shard ${SLURM_ARRAY_TASK_ID} --num_shards 16 -------------------------------------------------------------------------------- /generation/README.md: -------------------------------------------------------------------------------- 1 | 2 | # Synthetic data generation 3 | 4 | If you have a large dataset of prompts and want to generate content using an Open-Source LLM like [Mixtral-8x7B-Instruct-v0.1](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1), you can use `llm-swarm` which spins TGI or vLLM 
instances on `slurm` clusters. We used it to generate Cosmopedia, which consists of 25B tokens. The full generation took more than 10k H100 GPU hours. 5 | 6 | You can find the instructions for running the generation in `llm-swarm` here: https://github.com/huggingface/llm-swarm/tree/loubna/examples/textbooks 7 | 8 | The generation script is also available here (to be used within `examples/textbooks` of the library). 9 | 10 | ```bash 11 | # after having followed all the installation guidelines in llm-swarm + install wandb 12 | # 100k subset 13 | python generate_syntehtic_textbooks.py \ 14 | --model mistralai/Mixtral-8x7B-Instruct-v0.1 \ 15 | --instances 2 \ 16 | --prompts_dataset "HuggingFaceTB/cosmopedia-100k" \ 17 | --prompt_column prompt \ 18 | --max_samples 2000 \ 19 | --checkpoint_path "./synthetic_data" \ 20 | --checkpoint_interval 1000 21 | ``` 22 | -------------------------------------------------------------------------------- /generation/boilerplate_cleanup.py: -------------------------------------------------------------------------------- 1 | import re 2 | import argparse 3 | from datasets import load_dataset 4 | 5 | 6 | patterns = [ 7 | # alien stories 8 | r"^Hello.*?[.!]\s+", 9 | #r"^I'm( so)? 
excited to.*?[.!]\s+", 10 | r"^My name is.*?[.!]\s+", 11 | r"^You've just arrived.*?[.!]\s+", 12 | 13 | # wikihow 14 | r"^\*\*Welcome, .*?[.!]\*\*\s+", 15 | r"^(\*\*)?Warning:.*?[.!]\s+", 16 | r"^We're thrilled.*?[.!]\s+", 17 | 18 | # middle school 19 | r"^Welcome, .*?[.!]\s+", 20 | ] 21 | patterns = [re.compile(p, flags=re.IGNORECASE|re.MULTILINE) for p in patterns] 22 | 23 | def clean_text(sample): 24 | sample['completion_unfiltered'] = sample['completion'] 25 | for pattern in patterns: 26 | sample['completion'] = pattern.sub('', sample['completion'].strip()) 27 | return sample 28 | 29 | 30 | if __name__ == "__main__": 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument("--dataset", type=str, default="HuggingFaceTB/alien_stories_0_1M_llama3") 33 | args = parser.parse_args() 34 | 35 | data = load_dataset(args.dataset, split="train", cache_dir="/scratch/cosmo/cache", num_proc=32) 36 | data = data.map(clean_text, num_proc=32) 37 | data.push_to_hub(args.dataset, private=True) -------------------------------------------------------------------------------- /generation/llm_swarm_script.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import multiprocessing 3 | import os 4 | import time 5 | from dataclasses import asdict, dataclass 6 | 7 | from datasets import Dataset, load_dataset 8 | from huggingface_hub import AsyncInferenceClient 9 | from llm_swarm import LLMSwarm, LLMSwarmConfig 10 | from tqdm.asyncio import tqdm_asyncio 11 | from transformers import AutoTokenizer, HfArgumentParser 12 | 13 | import wandb 14 | 15 | HF_TOKEN = os.environ.get("HF_TOKEN", None) 16 | 17 | 18 | @dataclass 19 | class Args: 20 | # gneration parameters 21 | max_new_tokens: int = 2500 22 | """Max new tokens""" 23 | temperature: float = 0.6 24 | """Generation temperature""" 25 | top_p: float = 0.95 26 | """Generation top_p""" 27 | top_k: int = 50 28 | """Generation top_k""" 29 | repetition_penalty: float = 1.2 30 | 
"""Generation repetition_penalty""" 31 | # prompts dataset parameters 32 | prompts_dataset: str = "HuggingFaceTB/cosmopedia-100k" 33 | """Dataset containing the prompts""" 34 | max_samples: int = 5000 35 | """The maximum number of samples to generate (use -1 for all))""" 36 | start_sample: int = -1 37 | """First sample to process""" 38 | end_sample: int = -1 39 | """Last sample to process""" 40 | seed: int = 42 41 | """Seed for shuffling""" 42 | prompt_column: str = "prompt" 43 | """Name of the column containing the prompt""" 44 | shuffle_dataset: bool = False 45 | """Whether to shuffle the prompts""" 46 | debug: bool = False 47 | """Debugging mode""" 48 | # logging parameters 49 | repo_id: str = "HuggingFaceTB/synthetic_data_test" 50 | """The repo id to push to""" 51 | checkpoint_path: str = "./synthetic_data" 52 | """Path for saving intermediate generations""" 53 | checkpoint_interval: int = 1_000 54 | """Interval for saving intermediate generations""" 55 | wandb_username: str = "loubnabnl" 56 | """Wandb username""" 57 | min_token_length: int = 150 58 | """Minimum number of tokens in a generation to be kept in the final dataset""" 59 | push_to_hub: bool = True 60 | """Whether to push to hub""" 61 | 62 | 63 | parser = HfArgumentParser((Args, LLMSwarmConfig)) 64 | args, isc = parser.parse_args_into_dataclasses() 65 | # args used in wandb 66 | args_dict = asdict(args) 67 | args_dict.update( 68 | { 69 | "per_instance_max_parallel_requests": isc.per_instance_max_parallel_requests, 70 | "instances": isc.instances, 71 | "inference_engine": isc.inference_engine, 72 | "model": isc.model, 73 | } 74 | ) 75 | print(args_dict) 76 | 77 | tokenizer = AutoTokenizer.from_pretrained(isc.model) 78 | 79 | num_proc = 1 if args.debug else multiprocessing.cpu_count() 80 | ds = load_dataset( 81 | args.prompts_dataset, token=HF_TOKEN, split="train", num_proc=num_proc 82 | ) 83 | 84 | if args.shuffle_dataset: 85 | ds = ds.shuffle(seed=args.seed) 86 | 87 | if args.start_sample >= 0: 88 | 
end_sample = len(ds) if args.end_sample < 0 else args.end_sample 89 | print(f"Loading a defined range of samples: ({args.start_sample}, {end_sample})...") 90 | ds = ds.select(range(args.start_sample, end_sample)) 91 | elif args.max_samples > 0: 92 | print(f"Loading the first {args.max_samples} samples...") 93 | ds = ds.select(range(args.max_samples)) 94 | 95 | 96 | with LLMSwarm(isc) as llm_swarm: 97 | semaphore = asyncio.Semaphore(llm_swarm.suggested_max_parallel_requests) 98 | client = AsyncInferenceClient(model=llm_swarm.endpoint) 99 | STOP_SEQ = ["<|endoftext|>"] 100 | 101 | MAX_RETRIES = 6 # maximum number of retries 102 | RETRY_DELAY = 4 # delay in seconds between retries 103 | 104 | async def process_text(sample): 105 | token_length = 0 106 | attempt = 0 107 | while attempt < MAX_RETRIES: 108 | try: 109 | async with semaphore: 110 | completion = await client.text_generation( 111 | prompt=tokenizer.apply_chat_template( 112 | [{"role": "user", "content": sample[args.prompt_column]}], 113 | tokenize=False, 114 | ), 115 | max_new_tokens=args.max_new_tokens, 116 | stop_sequences=STOP_SEQ, 117 | temperature=args.temperature, 118 | top_p=args.top_p, 119 | top_k=args.top_k, 120 | repetition_penalty=args.repetition_penalty, 121 | ) 122 | for stop_seq in STOP_SEQ: 123 | if completion.endswith(stop_seq): 124 | completion = completion[: -len(stop_seq)].rstrip() 125 | token_length += len(tokenizer.encode(completion)) 126 | sample["completion"] = completion 127 | sample["token_length"] = token_length 128 | return sample 129 | 130 | except Exception as e: 131 | attempt += 1 132 | if attempt < MAX_RETRIES: 133 | print( 134 | f"Request failed, retrying in {RETRY_DELAY} seconds... (Attempt {attempt}/{MAX_RETRIES})" 135 | ) 136 | await asyncio.sleep(RETRY_DELAY) 137 | else: 138 | print( 139 | f"Max retries reached. Failed to process the request with error {str(e)}." 
140 | ) 141 | sample["completion"] = "" 142 | sample["token_length"] = 0 143 | return sample 144 | 145 | async def main(): 146 | start_time = time.time() 147 | total_tokens = 0 148 | saving_time = 0 149 | 150 | repo_id = ( 151 | f"{args.repo_id}_{args.prompt_column}" 152 | if args.prompt_column not in args.repo_id 153 | else args.repo_id 154 | ) 155 | wandb.init( 156 | project="synthetic_data", 157 | entity=args.wandb_username, 158 | name=repo_id.split("/")[1], 159 | ) 160 | wandb.config.update(args_dict) 161 | 162 | repo_id = ( 163 | f"{args.repo_id}_{args.prompt_column}" 164 | if args.prompt_column not in args.repo_id 165 | else args.repo_id 166 | ) 167 | checkpoint_dir = f"{args.checkpoint_path}/{repo_id.split('/')[1]}/data" 168 | os.makedirs(checkpoint_dir, exist_ok=True) 169 | print(f"Will be saving at {checkpoint_dir}") 170 | 171 | total_samples = len(ds) 172 | for i in range(0, total_samples, args.checkpoint_interval): 173 | batch_time = time.time() 174 | # Processing a chunk 175 | print( 176 | f"Processing chunk {int(i/args.checkpoint_interval)}/{int(total_samples/args.checkpoint_interval)}" 177 | ) 178 | end_index = min(i + args.checkpoint_interval, total_samples) 179 | chunk = ds.select(range(i, end_index)) 180 | chunk_results = await tqdm_asyncio.gather( 181 | *(process_text(sample) for sample in chunk) 182 | ) 183 | # Save the chunk results and log throughput 184 | temp_time = time.time() 185 | time_per_chunk = temp_time - batch_time 186 | checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_{i}.json") 187 | intermediate_ds = Dataset.from_list(chunk_results) 188 | intermediate_ds.to_json(checkpoint_path) 189 | batch_tokens = sum(intermediate_ds["token_length"]) 190 | total_tokens += batch_tokens 191 | saving_time += time.time() - temp_time 192 | print(f"💾 Checkpoint (samples {i}-{i + args.checkpoint_interval}) saved at {checkpoint_path}.") 193 | wandb.log( 194 | { 195 | "sample": i + args.checkpoint_interval, 196 | "batch": int(i / 
args.checkpoint_interval), 197 | "total_tokens (M)": total_tokens / 1e6, 198 | "tokens_per_batch": batch_tokens, 199 | "time_per_batch (s)": time_per_chunk, 200 | "generated_tokens_per_sec": int(batch_tokens / time_per_chunk), 201 | "generated_tokens_per_sec_per_node": int( 202 | batch_tokens / (time_per_chunk * isc.instances) 203 | ), 204 | } 205 | ) 206 | 207 | end_time = time.time() 208 | 209 | print( 210 | "Done processing and saving all chunks 🎉! Let's get some stats and push to hub..." 211 | ) 212 | total_duration = end_time - start_time 213 | overall_tokens_per_second = ( 214 | total_tokens / total_duration if total_duration > 0 else 0 215 | ) 216 | print( 217 | f"🏎️💨 Overall Tokens per Second: {overall_tokens_per_second:.2f}, per instance: {overall_tokens_per_second/isc.instances:.2f}" 218 | ) 219 | print(f"Generated {total_tokens / 1e6:.2f}M tokens") 220 | print( 221 | f"Total duration: {int(total_duration // 3600)}h{int((total_duration % 3600) // 60)}min " 222 | ) 223 | print(f"Saving time: {saving_time}s={saving_time/60}min ") 224 | 225 | # load dataset 226 | print("Load checkpoints...") 227 | output_ds = load_dataset(checkpoint_dir, split="train") 228 | # keep completions with at least min_token_length tokens (drops empty ones) 229 | final_data = output_ds.filter( 230 | lambda x: x["token_length"] >= args.min_token_length 231 | ) 232 | print(final_data) 233 | # everything shorter than the threshold counts as failed (strict <, so no sample lands in both sets) 233 | failed = output_ds.filter(lambda x: x["token_length"] < args.min_token_length) 234 | print(failed) 235 | if args.push_to_hub: 236 | print(f"📨 Pushing dataset to {repo_id}") 237 | final_data.push_to_hub(repo_id, private=True) 238 | print("Dataset pushed!") 239 | if len(failed) > 0: 240 | print(f"{len(failed)} generations failed") 241 | size = min(len(failed), 1000) 242 | failed = failed.select(range(size)) 243 | failed.push_to_hub(f"{repo_id}_failed", private=True) 244 | 245 | asyncio.run(main()) 246 | wandb.finish() -------------------------------------------------------------------------------- /plots/clusters_map.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/cosmopedia/653acad4f6146ce1043e2ca8792c671947e83ce0/plots/clusters_map.png -------------------------------------------------------------------------------- /plots/cover.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/cosmopedia/653acad4f6146ce1043e2ca8792c671947e83ce0/plots/cover.png -------------------------------------------------------------------------------- /plots/cover_01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/cosmopedia/653acad4f6146ce1043e2ca8792c671947e83ce0/plots/cover_01.png -------------------------------------------------------------------------------- /plots/educational_score.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/cosmopedia/653acad4f6146ce1043e2ca8792c671947e83ce0/plots/educational_score.png -------------------------------------------------------------------------------- /plots/topics_distpng.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/cosmopedia/653acad4f6146ce1043e2ca8792c671947e83ce0/plots/topics_distpng.png -------------------------------------------------------------------------------- /prompts/README.md: -------------------------------------------------------------------------------- 1 | # Building the prompts 2 | 3 | Here you can find the code for building the prompts for each `seed_data` in [Cosmopedia](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia). 4 | In Cosmopedia we standardized the column names and merged all the data sources together, which isn't the case for some scripts here. 
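As a rough illustration of that standardization step, the sketch below maps per-source prompt columns onto one shared `prompt` column and tags each row with its `seed_data` before merging. The source names, the column mapping, and the `standardize` and `merge_sources` helpers are assumptions made for this example; they are not the code used to build Cosmopedia.

```python
# Hypothetical mapping from each seed source to the column holding its prompt.
PROMPT_COLUMNS = {
    "auto_math_text": "prompt_college",
    "khanacademy": "prompt",
}


def standardize(rows, seed_data):
    """Return rows with a shared `prompt` column and a `seed_data` tag."""
    column = PROMPT_COLUMNS[seed_data]
    return [{"prompt": row[column], "seed_data": seed_data} for row in rows]


def merge_sources(sources):
    """Standardize every source and concatenate the results in order."""
    merged = []
    for seed_data, rows in sources.items():
        merged.extend(standardize(rows, seed_data))
    return merged


if __name__ == "__main__":
    sources = {
        "auto_math_text": [{"prompt_college": "Write an educational piece..."}],
        "khanacademy": [{"prompt": "Write a long and very detailed course unit..."}],
    }
    for row in merge_sources(sources):
        print(row["seed_data"], "->", row["prompt"])
```

Plain lists of dicts keep the sketch dependency-free; with the `datasets` library the same idea would be expressed by renaming columns and concatenating the datasets.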
5 | -------------------------------------------------------------------------------- /prompts/auto_math_text/README.md: -------------------------------------------------------------------------------- 1 | ## Synthetic generations from AutoMathText 2 | 3 | To build prompts from the web subset of AutoMathText for two audiences (college students and grade-school students), run: 4 | 5 | ``` 6 | python ./build_science_prompts.py --run_all_styles 7 | # for a specific generation style/audience: 8 | python ./build_science_prompts.py --generation_style college 9 | ``` 10 | -------------------------------------------------------------------------------- /prompts/auto_math_text/build_science_prompts.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from datasets import load_dataset 3 | 4 | 5 | STYLES = {"college": 6 | """Write an educational piece suited for college students related to the following text snippet: 7 | "" 8 | 9 | Do not just list concepts, but develop each one in detail before moving to the next, as we prioritize depth of understanding and comprehensive exploration of the subject matter over breadth. Focus on: 10 | 11 | - Rigor: Ensure in-depth coverage of the concepts/sections. 12 | - Engagement: Write with an academic, professional and engaging tone that captivates interest. 13 | - Application: Incorporate specific, practical examples, such as proofs in calculus or critical dates and figures in history. 14 | Do not include a title or an introduction, simply write the content without headlines and introductory phrases. Do not use images.""", 15 | 16 | "grade_school": 17 | """Here's an extract from a webpage: 18 | "" 19 | 20 | Create an educational piece related to the snippet above targeted at grade-school students. Complex college-like topics such Electromagnetism and Integration shouldn't be used, as they aren't usually taught at grade-school. 
If that's what the snippet is about, look for a much simpler scientific alternative to explain, and use everyday examples. For instance, if the topic is 'Linear Algebra' you might discuss how arranging objects in rows and columns can help solve puzzles. 21 | Avoid technical terms and LaTeX and only discuss simple grade-school level topics. Start the educational piece right away."""} 22 | 23 | EXTRACT_SIZE = 1000 24 | 25 | 26 | def get_args(): 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument("--repo_id", type=str, default="HuggingFaceTB/auto_math") 29 | parser.add_argument("--generation_style", type=str, default="college") 30 | parser.add_argument("--run_all_styles", action="store_true") 31 | return parser.parse_args() 32 | 33 | 34 | def build_prompt(x, style="college"): 35 | """Build the prompt based on the generation type""" 36 | snippet = x["text"].strip() 37 | snippet = snippet[:min(len(snippet), EXTRACT_SIZE)] 38 | prompt = STYLES[style].replace("", snippet) 39 | return {f"prompt_{style}": prompt} 40 | 41 | 42 | if __name__ == "__main__": 43 | args = get_args() 44 | 45 | print("Loading AutoMathText web data...") 46 | ds = load_dataset("math-ai/AutoMathText", "web-0.50-to-1.00")["train"] 47 | if args.run_all_styles: 48 | suffix = "" 49 | styles = list(STYLES) 50 | else: 51 | suffix = f"_{args.generation_style}" 52 | styles = [args.generation_style] 53 | for style in styles: 54 | print(f"📖 Building prompts for the {style} style...") 55 | ds = ds.map(build_prompt, num_proc=48, fn_kwargs={"style": style}) 56 | print(ds) 57 | # print one sample prompt for each style that was actually built 58 | for style in styles: 59 | print(ds[0][f"prompt_{style}"]) 60 | print("-"*100) 61 | ds.push_to_hub(f"{args.repo_id}{suffix}", private=True) 62 | print(f"✅ Data available at {args.repo_id}{suffix}!") -------------------------------------------------------------------------------- /prompts/khanacademy/README.md: 
-------------------------------------------------------------------------------- 1 | # KhanAcademy 2 | ## Code adapted from https://github.com/rand-net/khan-dl 3 | 4 | ## Run script to download list of courses 5 | ```bash 6 | # install requirements 7 | pip install -r khan_dl/requirements.txt 8 | # run downloader with all courses 9 | python khan_dl/main.py -a 10 | ``` 11 | 12 | The output is saved to `khan_courses.json`. 13 | 14 | You can then use `generate_textbooks.py` to build the textbook generation prompts. 15 | [TODO]: add code for updated prompts for Cosmopedia -------------------------------------------------------------------------------- /prompts/khanacademy/generate_textbooks.py: -------------------------------------------------------------------------------- 1 | import json 2 | from string import Template 3 | import pandas as pd 4 | 5 | TEMPLATE = Template("""Write a long and very detailed course unit for a textbook on "${unit_title}". 6 | ${previous_sections}\"${section}\". 7 | ${previous_sub_units} 8 | Write the new sub-unit titled \"${unit}\" while trying to be: 9 | - Rigorous - you create challenging textbooks that cover the material in depth. 10 | - Engaging - your textbooks have a narrative arc and engaging tone, like the writing of Michael Lewis. 11 | - Applied - you use specific and practical examples. For example, if the topic is integration in calculus, include equations and proofs of the concept you're teaching. As another example, if the topic is the history of the United States, include dates, names, and key events. 
12 | Model:""") 13 | 14 | 15 | # file created by khan_dl 16 | with open("khan_courses.json") as f: 17 | data = json.load(f) 18 | 19 | textbooks = [] 20 | total_courses = 0 21 | total_sections = 0 22 | for course in data: 23 | total_courses += 1 24 | units = course["subunits"] 25 | for ui, unit in enumerate(units): 26 | for sui, subunit in enumerate(unit["subunits"]): 27 | total_sections += 1 28 | for li, lesson_title in enumerate(subunit["lessons"]): 29 | # previous sections 30 | previous_subunits = [f"\"{sii + 1}. {s['title']}\"" for sii, s in enumerate(unit["subunits"][:sui])] 31 | previous_subunits_text = f"We have already covered chapter(s) {', '.join(previous_subunits)} and are now writing a chapter on " if previous_subunits else "We are currently writing the first chapter: " 32 | # previous lessons, each wrapped in double quotes 33 | previous_lessons = [f"\"{title}\"" for title in 34 | subunit["lessons"][:li]] 35 | previous_lessons_text = f"We have already covered the following lessons in the current chapter: {', '.join(previous_lessons)}." if previous_lessons else "You will be writing the first lesson for this chapter." 
36 | 37 | section_name = f"{unit['title']} - {subunit['title']}" 38 | # WIP 39 | sample = { 40 | "unit_title": course["title"].replace("_", " ") + " - " + unit['title'][ 41 | unit['title'].index(":") + 1:].strip(), 42 | "section": section_name, 43 | "unit": lesson_title, 44 | } 45 | sample["prompt"] = TEMPLATE.substitute(previous_sections=previous_subunits_text, 46 | previous_sub_units=previous_lessons_text, **sample) 47 | textbooks.append(sample) 48 | 49 | pd.DataFrame(textbooks).to_csv(f"khanacademy_prompts") 50 | -------------------------------------------------------------------------------- /prompts/khanacademy/khan_dl/khan_dl.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/rand-net/khan-dl 2 | 3 | import logging 4 | import os 5 | import platform 6 | import sys 7 | from typing import List, Tuple 8 | 9 | import requests 10 | from bs4 import BeautifulSoup 11 | from prompt_toolkit import prompt 12 | from prompt_toolkit.completion import FuzzyWordCompleter 13 | 14 | VIDEO_SITE_URL = "https://www.youtube.com/watch?v=" 15 | ROOT_URL = "https://www.khanacademy.org" 16 | DOMAINS = [ 17 | "math", 18 | "science", 19 | "computing", 20 | "humanities", 21 | "economics-finance-domain", 22 | "ela", 23 | ] 24 | 25 | # Tags and attributes for parsing HTML 26 | 27 | COURSE_HEAD = {"tag": "h2", "class": "_t2uf76"} 28 | COURSE_URL = {"tag": "a", "class": "_dwmetq"} 29 | COURSE_TITLE = {"data-test-id": "course-unit-title"} 30 | COURSE_UNIT_TITLE = {"data-test-id": "unit-header"} 31 | COURSE_SUBUNIT_TITLE_ATTRS = {"data-test-id": "lesson-card-link"} 32 | COURSE_SUBUNIT_BODY = {"tag": "ul", "class": "_37mhyh"} 33 | COURSE_LESSON_BODY = {"tag": "div", "class_i": "_10ct3cvu", "class_ii": "_1p9458yw"} 34 | COURSE_LESSON_SPAN = {"tag": "span", "class": "_e296pg"} 35 | COURSE_LESSON_LABEL = "aria-label" 36 | COURSE_LESSON_TITLE = {"tag": "span", "class": "_14hvi6g8"} 37 | 38 | """ 39 | 40 | Domain -> Course -> 
Unit Page -> Subunit Header + Subunit Block -> Lesson Block -> Lesson Title 41 | 42 | """ 43 | 44 | 45 | def clear_screen(): 46 | if platform.system() == "Linux" or platform.system() == "Darwin": 47 | os.system("clear") 48 | elif platform.system() == "Windows": 49 | os.system("cls") 50 | 51 | 52 | # Youtube-dl NoLogger 53 | class MyLogger(object): 54 | def debug(self, msg): 55 | pass 56 | 57 | def warning(self, msg): 58 | pass 59 | 60 | def error(self, msg): 61 | pass 62 | 63 | 64 | class KhanDL: 65 | def __init__(self): 66 | self.domain = "" 67 | self.course_url = "" 68 | self.course_title = "" 69 | self.course_page = "" 70 | self.course_unit_titles = [] 71 | self.course_unit_slugs = [] 72 | self.course_unit_urls = [] 73 | self.course_all_slugs = [] 74 | self.lesson_titles = [] 75 | self.lesson_youtube_ids = [] 76 | self.output_rel_path = os.getcwd() + "/" 77 | self.unit_ids_counter = {} 78 | self.unit_slugs_counter = {} 79 | self.nested_courses = [] 80 | self.course_subunits = [] 81 | self.selected_course = "" 82 | 83 | def get_courses(self, selected_domain_url: str) -> Tuple[List[str], List[str]]: 84 | """Returns the list of courses on a domain""" 85 | 86 | courses, courses_url = [], [] 87 | print("\nDownloading Courses...\n") 88 | try: 89 | selected_domain_page = BeautifulSoup( 90 | requests.get(selected_domain_url).text, "lxml" 91 | ) 92 | except requests.ConnectionError as e: 93 | print("Error Connecting!\n", e) 94 | sys.exit(1) 95 | except requests.exceptions.HTTPError as errh: 96 | print("Http Error:", errh) 97 | sys.exit(1) 98 | except requests.exceptions.ConnectionError as errc: 99 | print("Error Connecting:", errc) 100 | sys.exit(1) 101 | except requests.exceptions.Timeout as errt: 102 | print("Timeout Error:", errt) 103 | sys.exit(1) 104 | except requests.exceptions.RequestException as err: 105 | print("OOps: Something Else", err) 106 | sys.exit(1) 107 | 108 | for course_header in selected_domain_page.find_all( 109 | COURSE_HEAD["tag"], 
class_=COURSE_HEAD["class"] 110 | ): 111 | course = course_header.find( 112 | COURSE_URL["tag"], class_=COURSE_URL["class"] 113 | ).text 114 | courses.append(course) 115 | 116 | course_link = course_header.find( 117 | COURSE_URL["tag"], class_=COURSE_URL["class"] 118 | ) 119 | course_slug = course_link["href"] 120 | courses_url.append(ROOT_URL + course_slug) 121 | return courses, courses_url 122 | 123 | def domain_prompt(self): 124 | """Returns the selected domain""" 125 | 126 | # Domain selection prompt 127 | domain_completer = FuzzyWordCompleter( 128 | list(map(str.title, DOMAINS)) 129 | ) # Titlecase for aesthetics 130 | selected_domain = DOMAINS.index( 131 | prompt("Domain: ", completer=domain_completer).lower() 132 | ) 133 | 134 | print("Selected Domain: {}".format(DOMAINS[selected_domain])) 135 | self.domain = DOMAINS[selected_domain] 136 | logging.info("Domain Selected") 137 | 138 | def course_prompt(self): 139 | """Returns URL for the selected course""" 140 | 141 | selected_domain_url = ROOT_URL + "/" + self.domain 142 | courses, courses_url = self.get_courses(selected_domain_url) 143 | 144 | # Course Selection Prompt 145 | 146 | logging.debug(courses) 147 | courses_completer = FuzzyWordCompleter(courses) 148 | selected_course_index = courses.index( 149 | prompt("Course: ", completer=courses_completer) 150 | ) 151 | self.selected_course = courses[selected_course_index] 152 | print("Selected Course: {}".format(self.selected_course)) 153 | self.course_url = courses_url[selected_course_index] 154 | logging.info("Course Selected") 155 | 156 | def get_all_courses(self) -> List[str]: 157 | """Returns URL for all courses""" 158 | 159 | print("Downloading all Courses from all Domains...") 160 | all_courses_url = [] 161 | for domain in DOMAINS: 162 | print("Selected Domain: ", domain) 163 | selected_domain_url = ROOT_URL + "/" + domain 164 | courses, courses_url = self.get_courses(selected_domain_url) 165 | all_courses_url += courses_url 166 | 167 | return 
all_courses_url 168 | 169 | def get_course_page(self): 170 | """Retrieves course page html""" 171 | 172 | print("Course URL: {}".format(self.course_url)) 173 | try: 174 | self.course_page = BeautifulSoup(requests.get(self.course_url).text, "lxml") 175 | except requests.ConnectionError as e: 176 | print("Error Connecting!\n", e) 177 | sys.exit(1) 178 | except requests.exceptions.HTTPError as errh: 179 | print("Http Error:", errh) 180 | sys.exit(1) 181 | except requests.exceptions.ConnectionError as errc: 182 | print("Error Connecting:", errc) 183 | sys.exit(1) 184 | except requests.exceptions.Timeout as errt: 185 | print("Timeout Error:", errt) 186 | sys.exit(1) 187 | except requests.exceptions.RequestException as err: 188 | print("Oops: Something Else", err) 189 | sys.exit(1) 190 | 191 | def get_course_title(self): 192 | """Retrieves the course title""" 193 | 194 | course_title = self.course_page.find(attrs=COURSE_TITLE) 195 | if course_title and course_title.text: 196 | self.course_title = course_title.text.replace(" ", "_") 197 | logging.debug("course_title:{}".format(self.course_title)) 198 | logging.info("Course title retrieved") 199 | 200 | def get_course_unit_titles(self): 201 | """Retrieves course unit titles""" 202 | self.course_unit_titles = [] 203 | for title in self.course_page.find_all(attrs=COURSE_UNIT_TITLE): 204 | if "unit" in str(title.text).lower(): 205 | self.course_unit_titles.append(title.text) 206 | logging.debug("course_unit_titles:{}".format(self.course_unit_titles)) 207 | logging.info("Course unit titles retrieved") 208 | 209 | def get_course_unit_slugs(self): 210 | """Retrieves course unit slugs""" 211 | self.course_unit_slugs = [] 212 | counter = 0 213 | for title in self.course_unit_titles: 214 | self.course_unit_slugs.append( 215 | self.course_title + "/" + str(counter) + "_" + title.replace(" ", "_") 216 | ) 217 | counter += 1 218 | logging.debug("course_unit_slugs:{}".format(self.course_unit_slugs)) 219 | logging.info("Course unit 
slugs generated") 220 | 221 | def get_course_unit_urls(self): 222 | """Retrieves course unit urls""" 223 | self.course_unit_urls = [] 224 | self.nested_courses = [] 225 | for url in self.course_page.find_all(attrs=COURSE_UNIT_TITLE): 226 | if int(url["href"].count("/")) > 2: 227 | self.course_unit_urls.append(url["href"]) 228 | else: 229 | self.nested_courses.append(url["href"]) 230 | logging.debug("course_unit_urls:{}".format(self.course_unit_urls)) 231 | logging.debug("nested_courses:{}".format(self.nested_courses)) 232 | logging.info("Course unit urls retrieved") 233 | 234 | def get_course_all_slugs(self): 235 | """Generate slugs for all units""" 236 | 237 | unit_lessons_counter = 0 238 | # Unit Page -> Subunit Header + Subunit Block -> Lesson Block -> Lesson Title 239 | for course_unit_url, course_unit_slug, course_unit_title in zip( 240 | self.course_unit_urls, self.course_unit_slugs, self.course_unit_titles 241 | ): 242 | unit_lessons_counter = 0 243 | # -> Unit Page 244 | try: 245 | course_unit_page = BeautifulSoup( 246 | requests.get(ROOT_URL + course_unit_url).text, "lxml" 247 | ) 248 | except requests.ConnectionError as e: 249 | print("Error Connecting!\n", e) 250 | sys.exit(1) 251 | except requests.exceptions.HTTPError as errh: 252 | print("Http Error:", errh) 253 | sys.exit(1) 254 | except requests.exceptions.ConnectionError as errc: 255 | print("Error Connecting:", errc) 256 | sys.exit(1) 257 | except requests.exceptions.Timeout as errt: 258 | print("Timeout Error:", errt) 259 | sys.exit(1) 260 | except requests.exceptions.RequestException as err: 261 | print("OOps: Something Else", err) 262 | sys.exit(1) 263 | 264 | subunit_couter = 0 265 | 266 | subunits = [] 267 | # -> Subunit Header -> Subunit Block 268 | for course_subunit_title, course_subunit_body in zip( 269 | course_unit_page.find_all(attrs=COURSE_SUBUNIT_TITLE_ATTRS), 270 | course_unit_page.find_all( 271 | COURSE_SUBUNIT_BODY["tag"], class_=COURSE_SUBUNIT_BODY["class"] 272 | ), 273 | ): 274 | 
275 | logging.debug("course_subunit_title:{}".format(course_subunit_title)) 276 | lesson_counter = 0 277 | # -> Lesson Block 278 | lessons = [] 279 | for course_lesson_body in course_subunit_body.find_all( 280 | COURSE_LESSON_BODY["tag"], 281 | { 282 | "class": [ 283 | COURSE_LESSON_BODY["class_i"], 284 | COURSE_LESSON_BODY["class_ii"], 285 | ] 286 | }, 287 | ): 288 | course_lesson_span = course_lesson_body.find_all( 289 | COURSE_LESSON_SPAN["tag"], class_=COURSE_LESSON_SPAN["class"] 290 | ) 291 | course_lesson_aria_label = course_lesson_span[0][ 292 | COURSE_LESSON_LABEL 293 | ] 294 | logging.debug( 295 | "course_lesson_aria_label:{}".format(course_lesson_aria_label) 296 | ) 297 | # -> Lesson Title 298 | # Check whether lesson block is a video 299 | if course_lesson_aria_label == "Video": 300 | lesson_title = course_lesson_body.find( 301 | COURSE_LESSON_TITLE["tag"], 302 | class_=COURSE_LESSON_TITLE["class"], 303 | ) 304 | 305 | logging.debug( 306 | "course_lesson_title:{}".format(lesson_title.text) 307 | ) 308 | lessons.append(lesson_title.text.strip()) 309 | self.lesson_titles.append(lesson_title.text) 310 | self.course_all_slugs.append( 311 | self.output_rel_path 312 | + course_unit_slug 313 | + "/" 314 | + str(subunit_couter) 315 | + "_" 316 | + course_subunit_title.text.replace(" ", "_") 317 | + "/" 318 | + str(lesson_counter) 319 | + "_" 320 | + lesson_title.text.replace(" ", "_") 321 | ) 322 | 323 | lesson_counter += 1 324 | unit_lessons_counter += lesson_counter 325 | subunit_couter += 1 326 | subunits.append({ 327 | "title": course_subunit_title.text.strip(), 328 | "lessons": lessons 329 | }) 330 | self.course_subunits.append({ 331 | "title": course_unit_title, 332 | "subunits": subunits 333 | }) 334 | self.unit_slugs_counter[course_unit_url] = unit_lessons_counter 335 | 336 | logging.info(len(self.course_all_slugs)) 337 | logging.info("Course - All slugs generated") 338 | 339 | def get_course_youtube_ids(self): 340 | """Retrieves youtube id per unit""" 
341 | # 342 | # with ProgressBar() as pb: 343 | # for i, unit_url in zip( 344 | # pb(range(len(self.course_unit_urls)), label="Collecting Youtube IDs:"), 345 | # self.course_unit_urls, 346 | # ): 347 | # unit_url = ROOT_URL + unit_url 348 | # yt_dlp_opts = { 349 | # "logger": MyLogger(), 350 | # "retries": 20, 351 | # "ignoreerrors:": True, 352 | # "skip_download": True, 353 | # } 354 | # with yt_dlp.YoutubeDL(yt_dlp_opts) as ydl: 355 | # lessons_counter = 0 356 | # try: 357 | # logging.debug( 358 | # "Collecting youtube ids for unit:{}".format(unit_url) 359 | # ) 360 | # info_dict = ydl.extract_info(unit_url, download=False) 361 | # for video in info_dict["entries"]: 362 | # video_id = video.get("id", None) 363 | # self.lesson_youtube_ids.append(video_id) 364 | # lessons_counter += 1 365 | # except DownloadError as e: 366 | # logging.debug( 367 | # "Collecting youtube ids for unit:{}".format(unit_url) 368 | # ) 369 | # info_dict = ydl.extract_info( 370 | # unit_url, download=False, process=False 371 | # ) 372 | # for video in info_dict["entries"]: 373 | # video_id = video.get("url", None) 374 | # self.lesson_youtube_ids.append(video_id) 375 | # lessons_counter += 1 376 | # except Exception as e: 377 | # print("Youtube-dl: An error occured!", e) 378 | # sys.exit(1) 379 | # 380 | # self.unit_ids_counter[unit_url] = lessons_counter 381 | # 382 | # logging.info(self.lesson_youtube_ids) 383 | # logging.info(len(self.lesson_youtube_ids)) 384 | # logging.info("Course - Collected Youtube IDs") 385 | 386 | def download_course_videos(self): 387 | """Downloads Course Videos""" 388 | # 389 | # counter = 0 390 | # number_of_videos = len(self.course_all_slugs) 391 | # 392 | # with ProgressBar() as pb: 393 | # for i, lesson_output_file, lesson_video_id in zip( 394 | # pb(range(len(self.lesson_youtube_ids)), label="Downloading Videos:"), 395 | # self.course_all_slugs, 396 | # self.lesson_youtube_ids, 397 | # ): 398 | # lesson_youtube_url = VIDEO_SITE_URL + lesson_video_id 399 | # 
400 | # yt_dlp_opts = { 401 | # "logger": MyLogger(), 402 | # "outtmpl": lesson_output_file, 403 | # "retries": 20, 404 | # } 405 | # 406 | # with yt_dlp.YoutubeDL(yt_dlp_opts) as ydl: 407 | # logging.debug( 408 | # "Downloading video[{}] {} of {}:".format( 409 | # lesson_youtube_url, counter, number_of_videos 410 | # ) 411 | # ) 412 | # try: 413 | # ydl.download([lesson_youtube_url]) 414 | # counter += 1 415 | # except DownloadError: 416 | # error_log = open("error_private_videos.txt", "a") 417 | # error_log.write( 418 | # str( 419 | # lesson_output_file 420 | # + ", " 421 | # + VIDEO_SITE_URL 422 | # + lesson_video_id 423 | # ) 424 | # ) 425 | # error_log.close() 426 | # except Exception as e: 427 | # print("Youtube-dl: An error occured!", e) 428 | # sys.exit(1) 429 | # logging.info( 430 | # "Course lesson video[{}]downloaded".format(lesson_video_id) 431 | # ) 432 | # logging.info("All course videos downloaded") 433 | 434 | def reset_course(self): 435 | self.domain = "" 436 | self.course_url = "" 437 | self.course_title = "" 438 | self.course_page = "" 439 | self.course_unit_titles = [] 440 | self.course_unit_slugs = [] 441 | self.course_unit_urls = [] 442 | self.course_all_slugs = [] 443 | self.lesson_titles = [] 444 | self.lesson_youtube_ids = [] 445 | self.unit_ids_counter = {} 446 | self.unit_slugs_counter = {} 447 | self.selected_course = "" 448 | self.course_subunits = [] 449 | 450 | def download_nested_courses(self): 451 | self.reset_course() 452 | if self.nested_courses: 453 | print("\nDownloading nested courses...\n") 454 | for nested_course_url in self.nested_courses: 455 | self.download_course_given(ROOT_URL + nested_course_url) 456 | 457 | def download_course_interactive(self): 458 | """Downloads the chosen course""" 459 | self.domain_prompt() 460 | self.course_prompt() 461 | self.get_course_page() 462 | self.get_course_title() 463 | self.get_course_unit_titles() 464 | self.get_course_unit_slugs() 465 | self.get_course_unit_urls() 466 | 467 | 
print("\nGenerating Path Slugs...\n") 468 | self.get_course_all_slugs() 469 | self.get_course_youtube_ids() 470 | self.download_course_videos() 471 | self.download_nested_courses() 472 | 473 | def download_course_given(self, course_url: str): 474 | """Collects the outline (units, subunits, lessons) of the given course""" 475 | self.reset_course() 476 | self.course_url = course_url 477 | self.get_course_page() 478 | self.get_course_title() 479 | self.get_course_unit_titles() 480 | self.get_course_unit_slugs() 481 | self.get_course_unit_urls() 482 | 483 | self.get_course_all_slugs() 484 | # Video download steps are disabled in this adaptation: 485 | # self.get_course_youtube_ids() 486 | # self.download_course_videos() 487 | return { 488 | "domain": self.domain, 489 | "url": self.course_url, 490 | "title": self.course_title, 491 | # "page": self.course_page, 492 | "unit_titles": self.course_unit_titles, 493 | # "unit_slugs": self.course_unit_slugs, 494 | # "unit_urls": self.course_unit_urls, 495 | # "all_slugs": self.course_all_slugs, 496 | # "lesson_titles": self.lesson_titles, 497 | "subunits": self.course_subunits 498 | } 499 | -------------------------------------------------------------------------------- /prompts/khanacademy/khan_dl/main.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/rand-net/khan-dl 2 | 3 | import json 4 | import logging.handlers 5 | 6 | from tqdm import tqdm 7 | 8 | from khan_dl import * 9 | import argparse 10 | import sys 11 | from art import tprint 12 | 13 | __version__ = "1.2.8" 14 | 15 | 16 | def set_log_level(args): 17 | if not args.verbose: 18 | logging.basicConfig(level=logging.ERROR) 19 | elif int(args.verbose) == 1: 20 | logging.basicConfig(level=logging.WARNING) 21 | elif int(args.verbose) == 2: 22 | logging.basicConfig(level=logging.INFO) 23 | elif int(args.verbose) >= 3: 24 | logging.basicConfig(level=logging.DEBUG) 25 | 26 | 27 | def main(argv=None): 28 | argv = sys.argv if argv is None else argv 29 | 
argparser = argparse.ArgumentParser() 30 | argparser.add_argument( 31 | "-i", 32 | "--interactive", 33 | help="Enter Interactive Course Selection Mode", 34 | dest="interactive_prompt", 35 | action="store_true", 36 | ) 37 | argparser.add_argument( 38 | "-c", 39 | "--course_url", 40 | help="Enter Course URL", 41 | ) 42 | 43 | argparser.add_argument( 44 | "-a", 45 | "--all", 46 | help="Download all Courses from all Domains", 47 | action="store_true", 48 | ) 49 | 50 | argparser.add_argument( 51 | "-v", 52 | "--verbose", 53 | help="Verbose Levels of log. 1 = Warning; 2 = Info; 3 = Debug", 54 | ) 55 | 56 | args = argparser.parse_args() 57 | 58 | if args.interactive_prompt: 59 | set_log_level(args) 60 | tprint("KHAN-DL") 61 | khan_down = KhanDL() 62 | khan_down.download_course_interactive() 63 | 64 | elif args.course_url: 65 | set_log_level(args) 66 | tprint("KHAN-DL") 67 | print("Looking up " + args.course_url + "...") 68 | selected_course_url = args.course_url 69 | khan_down = KhanDL() 70 | khan_down.download_course_given(selected_course_url) 71 | 72 | elif args.all: 73 | set_log_level(args) 74 | tprint("KHAN-DL") 75 | khan_down = KhanDL() 76 | all_course_urls = khan_down.get_all_courses() 77 | courses = [khan_down.download_course_given(course_url) for course_url in tqdm(all_course_urls)] 78 | with open("khan_courses.json", "w") as outfile: 79 | outfile.write(json.dumps(courses, indent=4)) 80 | 81 | 82 | if __name__ == "__main__": 83 | main() 84 | -------------------------------------------------------------------------------- /prompts/khanacademy/khan_dl/requirements.txt: -------------------------------------------------------------------------------- 1 | art==5.5 2 | beautifulsoup4==4.11.1 3 | certifi==2021.10.8 4 | charset-normalizer==2.0.12 5 | idna==3.3 6 | lxml==4.8.0 7 | prompt-toolkit==3.0.29 8 | requests==2.27.1 9 | soupsieve==2.3.2.post1 10 | urllib3==1.26.9 11 | wcwidth==0.2.5 12 | yt-dlp==2022.5.18 13 | tqdm>=4.66.2 
-------------------------------------------------------------------------------- /prompts/openstax/README.md: -------------------------------------------------------------------------------- 1 | ## Synthetic generations from OpenStax 2 | 3 | To build prompts from OpenStax, use the dataset with the course outlines and introductions, available [here](https://huggingface.co/datasets/HuggingFaceTB/openstax_paragraphs). We generate textbooks for 4 different audiences: young children, middle school students, professionals and researchers, and college students. Each prompt is tailored to its target audience. 4 | ``` 5 | python ./build_openstax_prompts.py 6 | ``` 7 | -------------------------------------------------------------------------------- /prompts/openstax/build_openstax_prompts.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import random 3 | import argparse 4 | import numpy as np 5 | from datasets import Dataset, load_dataset, concatenate_datasets 6 | 7 | 8 | STYLES = {"young children": 9 | {"beginning": "Create a fun and simple e-learning module on {{X}}, tailored for 5 to 10 year-old children. Opt for a playful and imaginative approach, suitable for very young learners.\n", 10 | "criteria":"""In this module for young children, aim to: 11 | 12 | - Use very simple, everyday words and phrases that a 5-year-old would easily understand, avoiding any complex concepts or technical terms. 13 | - Tell a short, engaging story with colorful cartoon characters. For instance, to illustrate economic trade concepts use characters like animals or friendly creatures trading snacks or toys. Another example is addition and calculus, use apples to explain: '2 apples + 3 apples = 5 apples' . 14 | - Keep the tone light, cheerful, and encouraging. 
Do not use images."""}, 15 | 16 | "middle school students": 17 | {"beginning": "Create an engaging and accessible e-learning module on {{X}}, tailored for middle school students without prior knowledge on the topic.\n", 18 | "criteria": """Instead of a traditional textbook approach, use a story-based narrative to explain the concept. Try to: 19 | 20 | - Avoid technical jargon and present the ideas in a straightforward, conversational tone to spark curiosity and relate to the experiences of a younger audience. 21 | - Include interactive elements like thought experiments and real-life scenarios. The goal is to topic approachable and fun, sparking curiosity about how it applies to everyday life. 22 | - Do not use introductory phrases such as "welcome to this unit" at the beginning or conclusions the end. Do not use images."""}, 23 | 24 | "professionals and researchers": 25 | {"beginning": "Create an extract of a scientific journal article for {{X}}, tailored for professionals and researchers on the topic.\n", 26 | "criteria": """The style should mirror that of a scholarly publication, not school textbooks, aiming to engage a highly knowledgeable audience with very deep expertise. Try to: 27 | 28 | - Present advanced theories, using technical and academic language. 29 | - Include critical analysis of recent research findings and debates in the field, with a detailed examination of empirical data and statistical methodologies. 30 | - The article should reflect the depth and complexity of content found in top-tier economics journals, intended for a readership deeply entrenched in the field. 31 | - Do not add come up with references or add them at the end of the article. 
If there are mathematical expressions use a correct LateX formatting and do not use images."""}, 32 | 33 | "college students": 34 | {"beginning": "Write a comprehensive and in-depth textbook on {{X}}, tailored for college students.\n", 35 | "criteria": """Try to be: 36 | 37 | - Rigorous: Ensure very detailed and in-depth coverage of the concepts. 38 | - Engaging: Write with an academic and engaging tone that captivates interest. 39 | - Applied: Use specific and practical examples. For example, if the topic is integration in calculus, include equations and proofs of the concept you're teaching. As another example, if the topic is the history of the United States, include dates, names, and key events. 40 | If there are mathematical expressions use a correct LateX formatting. Do not use images and avoid introductory phrases such as "welcome to this unit" at the beginning or conclusions the end."""}} 41 | 42 | 43 | def get_args(): 44 | parser = argparse.ArgumentParser() 45 | parser.add_argument("--repo_id", type=str, default="HuggingFaceTB/openstax_prompts") 46 | return parser.parse_args() 47 | 48 | 49 | def parse_chapter(chapter, level=0, trail=[]): 50 | """Parse each chapter recursively and take into account sections""" 51 | trail_current = trail + [chapter["title"]] 52 | chapter_info = { 53 | "title": chapter["title"], 54 | "level": level, 55 | "trail": trail_current, 56 | "abstract": chapter.get("abstract", ""), 57 | "sections": [], 58 | "sub_chapters": [], 59 | } 60 | 61 | if chapter.get("sections"): 62 | for section in chapter.get("sections"): 63 | chapter_info["sections"].append( 64 | { 65 | "title": section["title"], 66 | "content": section.get("paragraph", ""), 67 | "trail": trail_current, 68 | "abstract": section.get("abstract", ""), 69 | } 70 | ) 71 | 72 | # Handle sub-chapters recursively 73 | if chapter.get("chapters"): 74 | for sub_chapter in chapter.get("chapters"): 75 | chapter_info["sub_chapters"].append( 76 | parse_chapter(sub_chapter, level + 1, 
trail_current) 77 | ) 78 | 79 | return chapter_info 80 | 81 | 82 | def parse_book(book): 83 | """Parse and rearrange a book""" 84 | book_info = {"title": book["book_title"], "chapters": []} 85 | 86 | for chapter in book["chapters"]: 87 | if "preface" in chapter["title"].lower(): 88 | continue 89 | book_info["chapters"].append(parse_chapter(chapter)) 90 | 91 | return book_info 92 | 93 | 94 | def build_prompts( 95 | parsed_book, style="college students", include_reference=True, refrence_size=500 96 | ): 97 | """Build prompts based on the (deepest) sections in each book""" 98 | prompts = [] 99 | target_units = [] 100 | chapters = [chap["title"] for chap in parsed_book["chapters"]] 101 | chosen_style = STYLES[style] 102 | start_prompt = chosen_style["beginning"] 103 | end_prompt = "\n" + chosen_style["criteria"] 104 | empty_content = 0 105 | 106 | for i, chapter in enumerate(parsed_book["chapters"]): 107 | chapter_prompt = start_prompt.replace("{{X}}", f"'{parsed_book['title']}'") 108 | chapter_prompt += f"We are writing on chapter '{chapter['title']}'. " 109 | 110 | for subchapter in chapter["sub_chapters"]: 111 | # Iterate over sections in subchapters 112 | if subchapter["sections"]: 113 | for i, unit in enumerate(subchapter["sections"]): 114 | subchapter_prompt = ( # reset per unit so "already covered" notes don't accumulate 115 | chapter_prompt + f"In particular, section '{subchapter['title']}'. " 116 | ) 117 | if i != 0: 118 | units = [s["title"] for s in subchapter["sections"][:i]] 119 | units = list(np.random.choice(units, random.randint(2, 5))) if len(units) > 5 else units 120 | prev_units = ", ".join([f"'{name}'" for name in units]) 121 | plural = "s" if i > 1 else "" 122 | subchapter_prompt += f"We have already covered the following unit{plural} in this section: {prev_units}. " 123 | if unit["content"] and include_reference: 124 | size = len(unit["content"]) 125 | ref = f" Here's some text for inspiration: {unit['content'][:min(refrence_size, size)]}".rstrip(".").rstrip() + "."
126 | else: 127 | empty_content += 1 128 | ref = "" 129 | new_prompt = ( 130 | subchapter_prompt 131 | + f"Write a new unit titled '{unit['title']}'.{ref}\n" 132 | ) 133 | prompts.append(new_prompt + end_prompt) 134 | target_units.append(unit['title']) 135 | else: 136 | # Handle nested subchapters 137 | for k, e in enumerate(subchapter["sub_chapters"]): 138 | if e["sections"]: 139 | subchapter_prompt = ( 140 | chapter_prompt 141 | + f"In particular, section '{e['title']}' of '{subchapter['title']}' part. " 142 | ) 143 | for i, unit in enumerate(e["sections"]): 144 | current_prompt = subchapter_prompt 145 | if i != 0: 146 | units = [s["title"] for s in e["sections"][:i]] 147 | units = list(np.random.choice(units, random.randint(2, 5))) if len(units) > 5 else units 148 | prev_units = ", ".join([f"'{name}'" for name in units]) 149 | plural = "s" if i > 1 else "" 150 | current_prompt += f"We have already covered the following unit{plural} in this section: {prev_units}. " 151 | if unit["content"] and include_reference: 152 | size = len(unit["content"]) 153 | ref = f" Here's some text for inspiration: {unit['content'][:min(refrence_size, size)]}".rstrip(".").rstrip() + "." 154 | else: 155 | empty_content += 1 156 | ref = "" 157 | new_prompt = ( 158 | current_prompt 159 | + f"Write a new unit titled '{unit['title']}'.{ref}\n" 160 | ) 161 | target_units.append(unit['title']) 162 | prompts.append(new_prompt + end_prompt) 163 | else: 164 | if "introduction" not in e['title'].lower() and e.get("abstract"): 165 | new_prompt = chapter_prompt 166 | if k != 0: 167 | subchapters = [s["title"] for s in subchapter["sub_chapters"][:k]] 168 | subchapters = list(np.random.choice(subchapters, random.randint(2, 5))) if len(subchapters) > 5 else subchapters 169 | prev_subchapters = ", ".join([f"'{name}'" for name in subchapters]) 170 | plural = "s" if k > 1 else "" 171 | new_prompt += f"We have already covered the following unit{plural} in this chapter: {prev_subchapters}. 
" 172 | new_prompt = new_prompt + f"Write a new unit titled {e['title']}." 173 | if include_reference: 174 | size = len(e["abstract"]) 175 | new_prompt += f" Here's some text for inspiration: {e['abstract'][:min(refrence_size, size)]}" 176 | else: 177 | empty_content += 1 178 | target_units.append(e['title']) 179 | prompts.append(new_prompt.rstrip('.').rstrip() + '.\n' + end_prompt) 180 | else: 181 | continue 182 | return prompts, target_units, empty_content 183 | 184 | 185 | if __name__ == "__main__": 186 | args = get_args() 187 | include_references = [True, False] 188 | refrence_size = 600 189 | ds = load_dataset("HuggingFaceTB/openstax_paragraphs", split="train") 190 | ds_en = ds.filter(lambda x: x["language"] == "en") 191 | 192 | print(f"English books dataset: {ds_en}") 193 | print("🔍 Parsing books...") 194 | parsed_books = [parse_book(e) for e in ds_en] 195 | datasets_list = [] 196 | for include_reference in include_references: 197 | ref = "with_ref" if include_reference else "no_ref" 198 | for style in STYLES: 199 | print(f"🧩 Building prompts for {style} {ref}...") 200 | outputs = [ 201 | build_prompts( 202 | book, 203 | style=style, 204 | include_reference=include_reference, 205 | refrence_size=refrence_size, 206 | ) 207 | for book in parsed_books 208 | ] 209 | prompts = [p[0] for p in outputs] 210 | target_units = [p[1] for p in outputs] 211 | empty_content = [p[2] for p in outputs] 212 | sizes = [len(p) for p in prompts] 213 | print( 214 | f"✅ Done building {sum(sizes)} prompts! 
({sum(empty_content)} without reference text)" 215 | ) 216 | 217 | print(f"🌟 Examples:") 218 | print(f"- {prompts[random.randint(0, 5)][0]}") 219 | print(f"\n- {prompts[random.randint(0, 5)][-1]}") 220 | 221 | print("Converting to HF dataset and pushing to Hub...") 222 | flattened_prompts = [] 223 | for book_prompts, book_units in zip(prompts, target_units): 224 | for prompt, unit in zip(book_prompts, book_units): 225 | book_title = prompt.split(", tailored for")[0].split("'")[1] 226 | flattened_prompts.append((prompt, unit, book_title)) 227 | 228 | df = pd.DataFrame(flattened_prompts, columns=["prompt", "unit", "book title"]) 229 | ds = Dataset.from_pandas(df) 230 | audience = "_".join(style.split(" ")) 231 | ds = ds.add_column("audience", [f"{audience}_{ref}" for _ in range(len(ds))]) 232 | print(ds) 233 | datasets_list.append(ds) 234 | 235 | final_ds = concatenate_datasets(datasets_list) 236 | print(final_ds) 237 | final_ds.push_to_hub(args.repo_id, private=True) 238 | 239 | 240 | 241 | -------------------------------------------------------------------------------- /prompts/stanford/1_scraper.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "source": [ 6 | "Jupyter notebook to scrape stanford's list of courses. 
Gets the following:\n", 7 | "- course title\n", 8 | "- course description\n", 9 | "- course numbers/ids" 10 | ], 11 | "metadata": { 12 | "collapsed": false 13 | }, 14 | "id": "aeed1babdd1ced9b" 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 2, 19 | "id": "initial_id", 20 | "metadata": { 21 | "collapsed": true, 22 | "ExecuteTime": { 23 | "end_time": "2023-09-27T10:58:34.215404466Z", 24 | "start_time": "2023-09-27T10:58:34.208962606Z" 25 | } 26 | }, 27 | "outputs": [], 28 | "source": [ 29 | "from time import sleep\n", 30 | "\n", 31 | "from tqdm import tqdm\n", 32 | "\n", 33 | "MAIN_INDEX_URL = \"https://explorecourses.stanford.edu/search?q=all%20courses\"" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 3, 39 | "outputs": [], 40 | "source": [ 41 | "import requests\n", 42 | "from bs4 import BeautifulSoup\n", 43 | "import re" 44 | ], 45 | "metadata": { 46 | "collapsed": false, 47 | "ExecuteTime": { 48 | "end_time": "2023-09-27T10:58:34.720087515Z", 49 | "start_time": "2023-09-27T10:58:34.640417536Z" 50 | } 51 | }, 52 | "id": "b4ca6254d01f8dc8" 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 99, 57 | "outputs": [], 58 | "source": [ 59 | "headers = {'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10_1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/39.0.2171.95 Safari/537.36'}" 60 | ], 61 | "metadata": { 62 | "collapsed": false, 63 | "ExecuteTime": { 64 | "end_time": "2023-09-25T10:03:23.322727830Z", 65 | "start_time": "2023-09-25T10:03:23.319789021Z" 66 | } 67 | }, 68 | "id": "cd55f81c3d1d42dd" 69 | }, 70 | { 71 | "cell_type": "markdown", 72 | "source": [ 73 | "Scrape all courses.\n", 74 | "Change the number of pages based on the footer on https://explorecourses.stanford.edu/search?q=all%20courses" 75 | ], 76 | "metadata": { 77 | "collapsed": false 78 | }, 79 | "id": "bbce7731b5ff3ad2" 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 103, 84 | "outputs": [ 85 | { 86 | "name": "stderr", 87 
| "output_type": "stream", 88 | "text": [ 89 | "100%|██████████| 1541/1541 [37:09<00:00, 1.45s/it]\n" 90 | ] 91 | } 92 | ], 93 | "source": [ 94 | "TOTAL_PAGES = 1541\n", 95 | "\n", 96 | "def formatt(text):\n", 97 | " text = text.strip(\"\\r\\n\\t\")\n", 98 | " if text.endswith(\"more »\"):\n", 99 | " text = text[:-6]\n", 100 | " return text.strip(\"\\r\\n\\t\")\n", 101 | "\n", 102 | "all_courses = []\n", 103 | "\n", 104 | "# sadly the api seems to be for students and faculty only\n", 105 | "\n", 106 | "for p in tqdm(range(TOTAL_PAGES), total=TOTAL_PAGES):\n", 107 | " r = requests.get(MAIN_INDEX_URL + f\"&page={p}\", headers=headers)\n", 108 | " soup = BeautifulSoup(r.content)\n", 109 | " courses = [{\n", 110 | " \"number\": x.find(\"span\", {\"class\": 'courseNumber'}).text.rstrip(\":\"),\n", 111 | " \"title\": x.find(\"span\", {\"class\": 'courseTitle'}).text,\n", 112 | " \"description\": formatt(x.find(\"div\", {\"class\": 'courseDescription'}).text),\n", 113 | " } for x in soup.find_all(\"div\", {\"class\": \"courseInfo\"})]\n", 114 | " all_courses.extend(courses)\n", 115 | " # don't spam their servers too much\n", 116 | " sleep(0.5)" 117 | ], 118 | "metadata": { 119 | "collapsed": false, 120 | "ExecuteTime": { 121 | "end_time": "2023-09-25T10:41:21.425942204Z", 122 | "start_time": "2023-09-25T10:04:11.761537359Z" 123 | } 124 | }, 125 | "id": "4523f23672936f0a" 126 | }, 127 | { 128 | "cell_type": "markdown", 129 | "source": [ 130 | "Deduplicate courses.\n", 131 | "Courses listed multiple times with different ids have the other ids inside brackets" 132 | ], 133 | "metadata": { 134 | "collapsed": false 135 | }, 136 | "id": "b140124d216e49b3" 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": 127, 141 | "outputs": [], 142 | "source": [ 143 | "course_ids = set()\n", 144 | "unique_courses = []\n", 145 | "for course in all_courses:\n", 146 | " # check if we already found a duplicate of this course\n", 147 | " if course[\"number\"] in course_ids:\n", 
148 | " continue\n", 149 | " ids = [course[\"number\"]]\n", 150 | " res = re.search(r\"\\((.*?)\\)\", course[\"title\"])\n", 151 | " if res:\n", 152 | " ids.extend(res.group(1).split(\", \"))\n", 153 | " course[\"title\"] = course[\"title\"][:course[\"title\"].rindex(\"(\") - 1] # strip the course ids from the title \"(...\"\n", 154 | " course_ids.update(ids)\n", 155 | " unique_courses.append({\n", 156 | " **course,\n", 157 | " \"number\": \", \".join(ids)\n", 158 | " })" 159 | ], 160 | "metadata": { 161 | "collapsed": false, 162 | "ExecuteTime": { 163 | "end_time": "2023-09-25T11:13:39.243448406Z", 164 | "start_time": "2023-09-25T11:13:39.189180954Z" 165 | } 166 | }, 167 | "id": "dc18a648c4fa5d50" 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": 140, 172 | "outputs": [], 173 | "source": [ 174 | "import pandas as pd\n", 175 | "df = pd.DataFrame(unique_courses)\n", 176 | "df.to_csv(\"stanford_courses_unique.csv\")" 177 | ], 178 | "metadata": { 179 | "collapsed": false, 180 | "ExecuteTime": { 181 | "end_time": "2023-09-25T11:46:00.367181772Z", 182 | "start_time": "2023-09-25T11:46:00.283586989Z" 183 | } 184 | }, 185 | "id": "e8e4338438a8e5ae" 186 | }, 187 | { 188 | "cell_type": "markdown", 189 | "source": [ 190 | "Clean descriptions:\n", 191 | "- remove urls\n", 192 | "- remove course ids\n", 193 | "- remove \"Continuation of\"" 194 | ], 195 | "metadata": { 196 | "collapsed": false 197 | }, 198 | "id": "78788009ded5f401" 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": 141, 203 | "outputs": [], 204 | "source": [ 205 | "cleaned_courses = []\n", 206 | "for course in unique_courses:\n", 207 | " desc = course[\"description\"]\n", 208 | " # urls\n", 209 | " desc = re.sub('http[s]?://\\S+', '', desc)\n", 210 | " # course names\n", 211 | " desc = re.sub('[A-Z]+ \\d+([A-Z]+)?', '', desc)\n", 212 | " cleaned_courses.append({\n", 213 | " **course,\n", 214 | " \"description\": desc\n", 215 | " })\n", 216 | " " 217 | ], 218 | "metadata": { 
219 | "collapsed": false, 220 | "ExecuteTime": { 221 | "end_time": "2023-09-25T12:20:44.838325706Z", 222 | "start_time": "2023-09-25T12:20:44.761095868Z" 223 | } 224 | }, 225 | "id": "a59836f513898f95" 226 | }, 227 | { 228 | "cell_type": "code", 229 | "execution_count": 4, 230 | "outputs": [], 231 | "source": [ 232 | "import pandas as pd" 233 | ], 234 | "metadata": { 235 | "collapsed": false, 236 | "ExecuteTime": { 237 | "end_time": "2023-09-27T10:58:39.102293877Z", 238 | "start_time": "2023-09-27T10:58:38.900613044Z" 239 | } 240 | }, 241 | "id": "c2524526f5e1bdf4" 242 | }, 243 | { 244 | "cell_type": "code", 245 | "execution_count": 143, 246 | "outputs": [], 247 | "source": [ 248 | "df = pd.DataFrame(cleaned_courses)" 249 | ], 250 | "metadata": { 251 | "collapsed": false, 252 | "ExecuteTime": { 253 | "end_time": "2023-09-25T12:20:45.440192733Z", 254 | "start_time": "2023-09-25T12:20:45.438743636Z" 255 | } 256 | }, 257 | "id": "ed912eb58e38574c" 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": 144, 262 | "outputs": [], 263 | "source": [ 264 | "df.to_csv(\"stanford_courses_cleaned.csv\")" 265 | ], 266 | "metadata": { 267 | "collapsed": false, 268 | "ExecuteTime": { 269 | "end_time": "2023-09-25T12:20:45.777534972Z", 270 | "start_time": "2023-09-25T12:20:45.702459885Z" 271 | } 272 | }, 273 | "id": "4d21646ef7ff9593" 274 | }, 275 | { 276 | "cell_type": "code", 277 | "execution_count": 5, 278 | "outputs": [], 279 | "source": [ 280 | "df = pd.read_csv(\"stanford_courses_cleaned.csv\", dtype=str, na_values='', keep_default_na=False)" 281 | ], 282 | "metadata": { 283 | "collapsed": false, 284 | "ExecuteTime": { 285 | "end_time": "2023-09-27T10:58:44.461354916Z", 286 | "start_time": "2023-09-27T10:58:44.414283703Z" 287 | } 288 | }, 289 | "id": "303ce35ad306db4b" 290 | }, 291 | { 292 | "cell_type": "markdown", 293 | "source": [ 294 | "Some preprocessing to remove generic descriptions" 295 | ], 296 | "metadata": { 297 | "collapsed": false 298 | }, 299 | 
"id": "8ae1b01a983d663f" 300 | }, 301 | { 302 | "cell_type": "code", 303 | "execution_count": 6, 304 | "outputs": [], 305 | "source": [ 306 | "import string\n", 307 | "\n", 308 | "def detect_generic(text):\n", 309 | " text = text.lower().translate(str.maketrans('', '', string.punctuation))\n", 310 | " if text in (\"tba\", \"tbd\", \"description tbd\"):\n", 311 | " return False\n", 312 | " for x in (\"prerequisite\", \"continuation of\", \"graduation\", \"prior arrangement\", \"consent of instructor\", \"doctoral practicum\", \"may be repeated\", \"required suprvised\", \"program consent required\", \"supervised experience\", \"students must obtain\", \"graduate\", \"research\", \"tutorial in\", \"independent study\", \"for credit\", \"for advanced\"):\n", 313 | " text = text.replace(x, \"\")\n", 314 | " return len(text) < 20" 315 | ], 316 | "metadata": { 317 | "collapsed": false, 318 | "ExecuteTime": { 319 | "end_time": "2023-09-27T10:58:45.230915973Z", 320 | "start_time": "2023-09-27T10:58:45.223450162Z" 321 | } 322 | }, 323 | "id": "ced2e95857c5d12b" 324 | }, 325 | { 326 | "cell_type": "code", 327 | "execution_count": 7, 328 | "outputs": [], 329 | "source": [ 330 | "non_generic_courses = []\n", 331 | "\n", 332 | "for a, b in df.iterrows():\n", 333 | " # no description\n", 334 | " if not isinstance(b[\"description\"], str):\n", 335 | " if len(b[\"title\"]) < 25: # no description + short title = unusable\n", 336 | " continue\n", 337 | " b[\"description\"] = \"TBD\"\n", 338 | " if detect_generic(b[\"description\"]):\n", 339 | " continue\n", 340 | " non_generic_courses.append(b)" 341 | ], 342 | "metadata": { 343 | "collapsed": false, 344 | "ExecuteTime": { 345 | "end_time": "2023-09-27T10:58:46.428860754Z", 346 | "start_time": "2023-09-27T10:58:45.871251538Z" 347 | } 348 | }, 349 | "id": "8dbb9403c43e07a9" 350 | }, 351 | { 352 | "cell_type": "code", 353 | "execution_count": 10, 354 | "outputs": [], 355 | "source": [ 356 | 
"pd.DataFrame(non_generic_courses).to_csv(\"stanford_courses_cleaned_non_generic.csv\", index=False)" 357 | ], 358 | "metadata": { 359 | "collapsed": false, 360 | "ExecuteTime": { 361 | "end_time": "2023-09-27T11:08:23.018882075Z", 362 | "start_time": "2023-09-27T11:08:20.205486903Z" 363 | } 364 | }, 365 | "id": "5a321cbac67d727f" 366 | }, 367 | { 368 | "cell_type": "code", 369 | "execution_count": null, 370 | "outputs": [], 371 | "source": [], 372 | "metadata": { 373 | "collapsed": false 374 | }, 375 | "id": "b122079355a393a0" 376 | } 377 | ], 378 | "metadata": { 379 | "kernelspec": { 380 | "display_name": "Python 3", 381 | "language": "python", 382 | "name": "python3" 383 | }, 384 | "language_info": { 385 | "codemirror_mode": { 386 | "name": "ipython", 387 | "version": 2 388 | }, 389 | "file_extension": ".py", 390 | "mimetype": "text/x-python", 391 | "name": "python", 392 | "nbconvert_exporter": "python", 393 | "pygments_lexer": "ipython2", 394 | "version": "2.7.6" 395 | } 396 | }, 397 | "nbformat": 4, 398 | "nbformat_minor": 5 399 | } 400 | -------------------------------------------------------------------------------- /prompts/stanford/2_generate_course_outlines.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "initial_id", 7 | "metadata": { 8 | "collapsed": true 9 | }, 10 | "outputs": [], 11 | "source": [ 12 | "\n", 13 | "import pandas as pd\n", 14 | "\n", 15 | "df = pd.read_csv(\"stanford_courses_cleaned_non_generic.csv\", dtype=str)" 16 | ] 17 | }, 18 | { 19 | "cell_type": "markdown", 20 | "source": [ 21 | "1-shot generation of course outlines from the title and description" 22 | ], 23 | "metadata": { 24 | "collapsed": false 25 | }, 26 | "id": "255265fa0caac0da" 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": null, 31 | "outputs": [], 32 | "source": [ 33 | "from string import Template\n", 34 | "\n", 35 | 
"OUTLINE_TEMPLATE = Template(\"\"\"Write a course outline for a textbook on \\\"The Global Positioning System: Where on Earth are We, and What Time is It?\\\" covering the following topics: \\\"Why people want to know where they are: answers include cross-Pacific trips of Polynesians, missile guidance, and distraught callers. How people determine where they are: navigation technology from dead-reckoning, sextants, and satellite navigation (GPS). Hands-on experience. How GPS works; when it does not work; possibilities for improving performance.\\\".\n", 36 | "Model: 1. Introduction\n", 37 | "- What is the Global Positioning System?\n", 38 | "- Importance of GPS\n", 39 | "- Overview of the course\n", 40 | "\n", 41 | "2. Navigation technology\n", 42 | "- Dead-reckoning\n", 43 | "- Sextants\n", 44 | "- Satellite navigation\n", 45 | "- Comparison of technologies\n", 46 | "- Hands-on experience with navigation technology\n", 47 | "\n", 48 | "3. GPS technology\n", 49 | "- How GPS works\n", 50 | " - Satellites\n", 51 | " - Ground receivers\n", 52 | " - Triangulation\n", 53 | "- When GPS does not work\n", 54 | " - Blockage\n", 55 | " - Multipath\n", 56 | "- Possibilities for improving performance\n", 57 | "\n", 58 | "4. Applications of GPS\n", 59 | "- Cross-Pacific trips of Polynesians\n", 60 | "- Missile guidance\n", 61 | "- Distraught callers\n", 62 | "- Other applications of GPS\n", 63 | "\n", 64 | "User: Write a course outline for a textbook on \\\"${COURSE_TITLE}\\\" covering the following topics: \\\"${COURSE_DESCRIPTION}\\\". 
Do not include assignments, exams or prerequisites.\n", 65 | "Model: \"\"\")" 66 | ], 67 | "metadata": { 68 | "collapsed": false 69 | }, 70 | "id": "52c53c272dffa9d2" 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "outputs": [], 76 | "source": [ 77 | "courses_to_generate = []\n", 78 | "for a, b in df.iterrows():\n", 79 | " prompt = OUTLINE_TEMPLATE.substitute({\"COURSE_TITLE\": b[\"title\"], \"COURSE_DESCRIPTION\": b[\"description\"]})\n", 80 | " courses_to_generate.append({\n", 81 | " \"course_title\": b[\"title\"],\n", 82 | " \"course_description\": b[\"description\"],\n", 83 | " \"prompt\": prompt,\n", 84 | " })" 85 | ], 86 | "metadata": { 87 | "collapsed": false 88 | }, 89 | "id": "de59641276702b38" 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | "outputs": [], 95 | "source": [ 96 | "generations = [...] # code to generate using the prompts here" 97 | ], 98 | "metadata": { 99 | "collapsed": false 100 | }, 101 | "id": "29965888236f2bdd" 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": null, 106 | "outputs": [], 107 | "source": [ 108 | "for course, generation in zip(courses_to_generate, generations):\n", 109 | " course[\"outline\"] = generation\n", 110 | "\n", 111 | "pd.DataFrame(courses_to_generate).to_csv(\"outlines_full.csv\")" 112 | ], 113 | "metadata": { 114 | "collapsed": false 115 | }, 116 | "id": "55c5350ab00b5d54" 117 | }, 118 | { 119 | "cell_type": "markdown", 120 | "source": [ 121 | "(very large) 2-shot prompt to have the model correct and clean up the generated outlines" 122 | ], 123 | "metadata": { 124 | "collapsed": false 125 | }, 126 | "id": "c0d4dd91d694584f" 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": null, 131 | "outputs": [], 132 | "source": [ 133 | "\n", 134 | "OUTLINE_FILTER_TEMPLATE = Template(\"\"\"The following is a course outline for a course on \\\"Anesthesia Operating Room Clerkship\\\". 
This outline needs to be anonymized and adapted to an online audience:\n", 135 | "1.1 Introduction: Overview of the Anesthesia Operating Room Clerkship\n", 136 | "1.2 Introduction: Objectives of the clerkship\n", 137 | "1.3 Introduction: Prerequisites for the clerkship\n", 138 | "2.1 Clinical settings: Sequoia Hospital in Redwood City\n", 139 | "2.2 Clinical settings: Outpatient surgery centers throughout the community\n", 140 | "2.3 Clinical settings: Exposure to general and regional anesthetic techniques\n", 141 | "2.4 Clinical settings: Adult and pediatric patients\n", 142 | "3.1 Personalized discussion: Applied physiology\n", 143 | "3.2 Personalized discussion: Pharmacology\n", 144 | "3.3 Personalized discussion: Pathophysiology of the surgical patient\n", 145 | "3.4 Personalized discussion: Daily basis\n", 146 | "3.5 Personalized discussion: Final paper to be submitted by the students\n", 147 | "4.1 Transportation: Students need to arrange transportation to the various workplaces\n", 148 | "5.1 Prerequisites: A major clerkship in medicine or surgery is strongly recommended\n", 149 | "6.1 Periods available: 1-12, full-time for 2 weeks\n", 150 | "6.2 Periods available: 1 student per period\n", 151 | "7.1 Clerkship director and coordinator: Kurt Fink, M.D.\n", 152 | "7.2 Clerkship director and coordinator: Yun Tao, 650-724-1706, yuntao@stanford.edu, Stanford Hospital\n", 153 | "8.1 Reporting instructions: Contact Dr. 
Kurt Fink one week prior\n", 154 | "8.2 Reporting instructions: Time: TBA\n", 155 | "8.3 Reporting instructions: Call code: 0\n", 156 | "9.1 Other faculty: Palo Alto Medical Clinic Anesthesiologist\n", 157 | "10.1 Location: Palo Alto Medical Foundation.\n", 158 | "\n", 159 | "Which of the sections of the outline contain: \n", 160 | "- private faculty members information (names or contact information)\n", 161 | "- prerequisites, requirements, application processes or other practical course information not related to the course content\n", 162 | "- assignments, final papers, exams, presentations or other student evaluation information\n", 163 | "Falcon:\n", 164 | "- private faculty members information (names or contact information): 7.1, 7.2., 8.1, 9.1\n", 165 | "- prerequisites, requirements, application processes, schedules or other practical course information not related to the course content: 1.3, 4.1, 5.1, 6.1, 6.2, 8.1, 8.2, 8.3, 10.1\n", 166 | "- assignments, final papers, exams, presentations or other student evaluation information: 3.5\n", 167 | "User: The following is a course outline for a course on \"Numerical Methods for Compressible Flows\". 
This outline needs to be anonymized and adapted to an online audience:\n", 168 | "1.1 Introduction: Overview of the course\n", 169 | "1.2 Introduction: Importance of numerical methods for compressible flows\n", 170 | "1.3 Introduction: Prerequisites for the course\n", 171 | "2.1 Mathematical models for compressible flows: Hierarchy of mathematical models\n", 172 | "2.2 Mathematical models for compressible flows: Ideal potential flow\n", 173 | "2.3 Mathematical models for compressible flows: Transonic potential flow\n", 174 | "3.1 Numerical methods for compressible flows: Finite difference methods\n", 175 | "3.2 Numerical methods for compressible flows: Finite volume methods\n", 176 | "3.3 Numerical methods for compressible flows: Finite element methods\n", 177 | "4.1 Representative model problems: Shocks\n", 178 | "4.2 Representative model problems: Expansions\n", 179 | "5.1 Treatment of boundary conditions: Dirichlet boundary conditions\n", 180 | "5.2 Treatment of boundary conditions: Neumann boundary conditions\n", 181 | "6.1 Applications of numerical methods for compressible flows: Aerospace engineering\n", 182 | "6.3 Applications of numerical methods for compressible flows: Other applications of numerical methods for compressible flows\n", 183 | "\n", 184 | "Which of the sections of the outline contain: \n", 185 | "- private faculty members information (names or contact information)\n", 186 | "- prerequisites, requirements, application processes or other practical course information not related to the course content\n", 187 | "- assignments, final papers, exams, presentations or other student evaluation information\n", 188 | "Falcon: \n", 189 | "- private faculty members information (names or contact information): None\n", 190 | "- prerequisites, requirements, application processes, schedules or other practical course information not related to the course content: 1.3\n", 191 | "- assignments, final papers, exams, presentations or other student evaluation 
information: None\n", 192 | "User: The following is a course outline for a course on \\\"${COURSE_TITLE}\\\". This outline needs to be anonymized and adapted to an online audience:\n", 193 | "${SECTIONS_LIST}\n", 194 | "\n", 195 | "Which of the sections of the outline contain: \n", 196 | "- private faculty members information (names or contact information)\n", 197 | "- prerequisites, requirements, application processes, schedules or other practical course information not related to the course content\n", 198 | "- assignments, final papers, exams, presentations or other student evaluation information\n", 199 | "Falcon: \"\"\")" 200 | ], 201 | "metadata": { 202 | "collapsed": false 203 | }, 204 | "id": "53702cd87a677315" 205 | }, 206 | { 207 | "cell_type": "markdown", 208 | "source": [ 209 | "Reformat cells into numbered format" 210 | ], 211 | "metadata": { 212 | "collapsed": false 213 | }, 214 | "id": "ca753f9d85a95139" 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": null, 219 | "outputs": [], 220 | "source": [ 221 | "import re\n", 222 | "\n", 223 | "FIND_SECTIONS_REGEX = re.compile(r\"\\d\\. .*(?:\\n\\s*- .*)+\")\n", 224 | "FIND_TITLES_REGEX = re.compile(r\"\\d\\. 
(.*)\")\n", 225 | "FIND_UNIT_TITLES_REGEX = re.compile(r\"\\n\\s*- (.*)\")\n", 226 | "\n", 227 | "def extract_sections(outline):\n", 228 | " sections = FIND_SECTIONS_REGEX.findall(outline)\n", 229 | " return [\n", 230 | " {\n", 231 | " \"section_nr\": si + 1,\n", 232 | " \"title\": FIND_TITLES_REGEX.search(section).group(1),\n", 233 | " \"unit_titles\": FIND_UNIT_TITLES_REGEX.findall(section),\n", 234 | " } for si, section in enumerate(sections)\n", 235 | " ]\n", 236 | "\n", 237 | "\n", 238 | "df = pd.read_csv(\"outlines_full.csv\", dtype=str)\n", 239 | "for a, b in df.iterrows():\n", 240 | " sections = extract_sections(b[\"outline\"])\n", 241 | " sections_list = '\\n'.join(\n", 242 | " [f\"{si + 1}.{ui + 1} {section['title']}: {unit_title}\" for si, section in enumerate(sections) for\n", 243 | " ui, unit_title in enumerate(section[\"unit_titles\"])])\n", 244 | " prompt = OUTLINE_FILTER_TEMPLATE.substitute({\"COURSE_TITLE\": b[\"course_title\"], \"SECTIONS_LIST\": sections_list})\n", 245 | " df.loc[a, 'filter_outline_prompt'] = prompt\n", 246 | " df.loc[a, 'filter_outline_result'] = generate... 
# actually generate the filter results" 247 | ], 248 | "metadata": { 249 | "collapsed": false 250 | }, 251 | "id": "1cfa8766032e5ee2" 252 | }, 253 | { 254 | "cell_type": "code", 255 | "execution_count": null, 256 | "outputs": [], 257 | "source": [ 258 | "df.to_csv(\"outlines_full_filtered.csv\", index=False)" 259 | ], 260 | "metadata": { 261 | "collapsed": false 262 | }, 263 | "id": "ea6a4080d3619d40" 264 | } 265 | ], 266 | "metadata": { 267 | "kernelspec": { 268 | "display_name": "Python 3", 269 | "language": "python", 270 | "name": "python3" 271 | }, 272 | "language_info": { 273 | "codemirror_mode": { 274 | "name": "ipython", 275 | "version": 2 276 | }, 277 | "file_extension": ".py", 278 | "mimetype": "text/x-python", 279 | "name": "python", 280 | "nbconvert_exporter": "python", 281 | "pygments_lexer": "ipython2", 282 | "version": "2.7.6" 283 | } 284 | }, 285 | "nbformat": 4, 286 | "nbformat_minor": 5 287 | } 288 | -------------------------------------------------------------------------------- /prompts/stanford/README.md: -------------------------------------------------------------------------------- 1 | # Synthetic textbooks from Stanford course outlines 2 | 3 | You can find the code for scraping Stanford courses in `1_scraper.ipynb`, and the code for generating course outlines in `2_generate_course_outlines.ipynb`. 4 | 5 | [TODO]: add code for updated prompts for Cosmopedia -------------------------------------------------------------------------------- /prompts/stories/README.md: -------------------------------------------------------------------------------- 1 | # Stories from UltraChat and OpenHermes 2 | 3 | We build several types of stories: educational stories for young children, stories involving morals and principles, stories involving problem solving, and posts found on forums and Reddit.
The prompts are based on seed samples from [UltraChat](https://huggingface.co/datasets/stingning/ultrachat) and [OpenHermes 2.5](https://huggingface.co/datasets/teknium/OpenHermes-2.5). 4 | 5 | We only use the "Questions about the world" [subset](https://huggingface.co/datasets/HuggingFaceTB/ultrachat_questions_about_world) of UltraChat. For OpenHermes, we filter out non-English instructions and remove categories and sources that wouldn't be suitable for stories; the filtered dataset is available [here](https://huggingface.co/datasets/HuggingFaceTB/openhermes_filtered). 6 | 7 | To run the filtering: 8 | ```bash 9 | python filter_openhermes.py 10 | ``` 11 | 12 | To build the prompts: 13 | ```bash 14 | python build_openhermes_stories_prompts.py --run_all_styles 15 | python build_ultrachat_stories_prompts.py --run_all_styles 16 | ``` 17 | -------------------------------------------------------------------------------- /prompts/stories/build_openhermes_stories_prompts.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from datasets import load_dataset 3 | 4 | 5 | STYLES = {"young_children_story": 6 | """Write an educational story (3-5 paragraphs) targeted at young children using simple words. The story should be inspired from this text snippet: 7 | “” 8 | 9 | The story doesn’t have to be addressing everything in the snippet, it is there just for inspiration. 10 | The story should have the following features: 11 | - Science integration: embed basic science concepts within the story, explaining them through the characters' adventures and discoveries. For example, if the story includes a scene where characters are looking at the sky, you could have them wonder why it's blue and explain the physics behind in grade school level. 12 | - Dialogue: include at least one dialogue and insightful conversation.
13 | - Unexpected twist: conclude with a twist that doesn't resolve as hoped, but leaves a clear lesson about life and science. 14 | Do not start with classic sentences like "Once upon a time", be creative.""", 15 | 16 | "problem_solving_story": 17 | """Write a story that explores a situation slightly related to this text snippet: 18 | “” 19 | 20 | The story should unfold through the characters interactions, decisions, and the consequences of their actions. Aim to weave in common sense lessons and social cues. The narrative should cater to a diverse age group, including at least one dialogue and presenting both positive and negative outcomes. 21 | Do not start with classic sentences like "Once upon a time", be creative.""", 22 | 23 | "reddit_post": 24 | """Write a real-life story shared by someone in a reddit forum. The story should be somehow related to this text snippet: 25 | “” 26 | 27 | The story should include: 28 | - Niche interests or humor: dive into specific hobbies, interests, or humorous situations 29 | - An unexpected plot twist or engaging conflict: introduce a relatable yet challenging situation or dilemma that the author faced. 30 | - Reflection and insight: end with a resolution that offers a new understanding, a sense of community, or a personal revelation, much like the conclusions drawn in forum discussions. 31 | Start the story right away. 
Do not start with sentences like "Once upon a time" as this is a reddit post and not a novel, you should also avoid starting with classic sentences like "A few years ago" or "A few years back", be creative."""} 32 | 33 | EXTRACT_SIZE = 1000 34 | 35 | 36 | def get_args(): 37 | parser = argparse.ArgumentParser() 38 | parser.add_argument("--repo_id", type=str, default="HuggingFaceTB/prompts_stories_openhermes") 39 | parser.add_argument("--generation_style", type=str, default="problem_solving_story") 40 | parser.add_argument("--run_all_styles", action="store_true") 41 | return parser.parse_args() 42 | 43 | 44 | def build_prompt(x, style="problem_solving_story"): 45 | """Build the prompt based on the generation type""" 46 | snippet = x["prompt"].strip() 47 | snippet = snippet[:min(len(snippet), EXTRACT_SIZE)] 48 | prompt = STYLES[style].replace("", snippet) 49 | return {f"prompt_{style}": prompt} 50 | 51 | 52 | if __name__ == "__main__": 53 | args = get_args() 54 | 55 | print(f"Loading openhermes data...") 56 | ds = load_dataset("HuggingFaceTB/openhermes_filtered", split="train", num_proc=36) 57 | if args.run_all_styles: 58 | suffix = "" 59 | for style in STYLES.keys(): 60 | print(f"📖 Building prompts with a {style}...") 61 | ds = ds.map(build_prompt, num_proc=48, fn_kwargs={"style": style}) 62 | else: 63 | suffix = f"_{args.generation_style}" 64 | print(f"📖 Building prompts with a {args.generation_style}...") 65 | ds = ds.map(build_prompt, num_proc=48, fn_kwargs={"style": args.generation_style}) 66 | print(ds) 67 | print(ds) 68 | print(ds[0]["prompt_young_children_story"]) 69 | print("-"*100) 70 | print(ds[1]["prompt_problem_solving_story"]) 71 | print("-"*100) 72 | print(ds[2]["prompt_reddit_post"]) 73 | ds.push_to_hub(f"{args.repo_id}{suffix}", private=True) 74 | print(f"✅ Data available at {args.repo_id}{suffix}") 75 | -------------------------------------------------------------------------------- /prompts/stories/build_ultrachat_stories_prompts.py:
-------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | from datasets import load_dataset 4 | 5 | 6 | STYLES = {"young_children_story": 7 | """Write an educational story (3-5 paragraphs) targeted at young children using simple words. The story should be inspired from this text snippet: 8 | “” 9 | 10 | The story doesn’t have to be addressing everything in the snippet, it is there just for inspiration. 11 | The story should have the following features: 12 | - Science integration: embed basic science concepts within the story, explaining them through the characters' adventures and discoveries. For example, if the story includes a scene where characters are looking at the sky, you could have them wonder why it's blue and explain the physics behind in grade-school level. 13 | - Dialogue: include at least one dialogue and insightful conversation. 14 | - Unexpected twist: conclude with a twist that doesn't resolve as hoped, but leaves a clear lesson about life and science. 15 | Do not start with classic sentences like "Once upon a time", be creative.""", 16 | 17 | "morality_story": 18 | """Write a compelling story related to the following text snippet: 19 | “” 20 | 21 | The story doesn’t need to mention everything in the snippet, use it just for inspiration and be creative! 22 | The story should incorporate the following elements: 23 | - Dialogue: the story must feature at least one meaningful dialogue that reveals character depth, advances the plot, or unravels a crucial piece of the mystery 24 | - Interesting themes: explore themes resonant with a mature audience, such as moral ambiguity, existential queries, personal transformation, or the consequences of past actions. 
25 | Do not start with classic sentences like "Once upon a time", "The sun hung low in the sky" or "In the dimly lit", be creative.""", 26 | 27 | "problem_solving_story": 28 | """Write a story that explores a situation slightly related to this text snippet: 29 | “” 30 | 31 | The story should unfold through the characters interactions, decisions, and the consequences of their actions. Aim to weave in common sense lessons and social cues, emphasizing the importance of problem-solving. The narrative should cater to a diverse age group, including at least one dialogue and presenting both positive and negative outcomes. 32 | Do not start with classic sentences like "Once upon a time", be creative.""", 33 | 34 | "forums_story": 35 | """Write a story in the style of real-life situations that people share in forums. The story should be somehow related to this text snippet: 36 | “” 37 | 38 | The story needs to include a compelling and unexpected plot twist. Your narrative should resonate with the authenticity and personal touch found in forum discussions. Include relatable events and emotional depth. 39 | Do not start with classic sentences like "Once upon a time", "A few years back" or "A few months ago", be creative.""", 40 | 41 | "reddit_post": 42 | """Write a real-life story shared by someone in a reddit forum. The story should be somehow related to this text snippet: 43 | “” 44 | 45 | The story should include: 46 | - Niche interests or humor: dive into specific hobbies, interests, or humorous situations 47 | - An unexpected plot twist or engaging conflict: introduce a relatable yet challenging situation or dilemma that the author faced. 48 | - Reflection and insight: end with a resolution that offers a new understanding, a sense of community, or a personal revelation, much like the conclusions drawn in forum discussions. 49 | Start the story right away.
Do not start with sentences like "Once upon a time" as this is a reddit post and not a novel, you should also avoid starting with classic sentences like "A few years ago" or "A few years back", be creative."""} 50 | 51 | EXTRACT_SIZE = 1000 52 | 53 | 54 | def get_args(): 55 | parser = argparse.ArgumentParser() 56 | parser.add_argument("--repo_id", type=str, default="HuggingFaceTB/ultrachat_stories_prompts") 57 | parser.add_argument("--generation_style", type=str, default="problem_solving_story") 58 | parser.add_argument("--run_all_styles", action="store_true") 59 | return parser.parse_args() 60 | 61 | 62 | def build_prompt(x, style="forums_story"): 63 | """Build the prompt based on the generation type""" 64 | snippet = x["first_turn"].strip() 65 | snippet = snippet[:min(len(snippet), EXTRACT_SIZE)] 66 | prompt = STYLES[style].replace("", snippet) 67 | return {f"prompt_{style}": prompt} 68 | 69 | 70 | if __name__ == "__main__": 71 | args = get_args() 72 | 73 | print(f"Loading ultrachat data...") 74 | ds = load_dataset("HuggingFaceTB/ultrachat_questions_about_world", split="train") 75 | if args.run_all_styles: 76 | suffix = "" 77 | for style in STYLES.keys(): 78 | print(f"📖 Building prompts with a {style}...") 79 | ds = ds.map(build_prompt, num_proc=48, fn_kwargs={"style": style}) 80 | else: 81 | suffix = f"_{args.generation_style}" 82 | print(f"📖 Building prompts with a {args.generation_style}...") 83 | ds = ds.map(build_prompt, num_proc=48, fn_kwargs={"style": args.generation_style}) 84 | print(ds) 85 | print(ds) 86 | print(ds[0]["prompt_young_children_story"]) 87 | print("-"*100) 88 | print(ds[1]["prompt_morality_story"]) 89 | print("-"*100) 90 | print(ds[2]["prompt_reddit_post"]) 91 | ds.push_to_hub(f"{args.repo_id}{suffix}", private=True) 92 | print(f"✅ Data available at {args.repo_id}{suffix}!") 93 | -------------------------------------------------------------------------------- /prompts/stories/filter_openhermes.py:
-------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | 3 | ds = load_dataset("teknium/OpenHermes-2.5", split="train", num_proc=36) 4 | drop_sources = ["camelai", "glaive-code-assist"] 5 | drop_categories = ["rp", "gtkm", "coding", "wordgame", "riddle"] 6 | 7 | def filter_files(x): 8 | if x["category"] and x["category"].lower() in drop_categories: 9 | return False 10 | if x["source"] and x["source"].lower() in drop_sources: 11 | return False 12 | return True 13 | 14 | def get_prompt(x): 15 | conversations = x["conversations"] 16 | prompt = "" 17 | for i in range(len(conversations)): 18 | if conversations[i]["from"] == "human": 19 | prompt += conversations[i]["value"] + "\n" 20 | assert conversations[i+1]["from"] == "gpt", f"role is {conversations[i+1]['from']} not 'gpt'!" 21 | prompt += conversations[i+1]["value"] 22 | break 23 | return {"prompt": prompt} 24 | 25 | print("Start...") 26 | print(ds) 27 | print("Language filter...") 28 | ds = ds.filter(lambda x: x["language"] in [None, "English"], num_proc=12) 29 | print(ds) 30 | print("Category & source filter...") 31 | ds_f = ds.filter(filter_files, num_proc=36) 32 | print(ds_f) 33 | ds_f = ds_f.map(get_prompt, num_proc=36) 34 | ds_f = ds_f.remove_columns([col for col in ds_f.column_names if col not in ["prompt", "source", "category"]]) 35 | ds_f.push_to_hub("HuggingFaceTB/openhermes_filtered") -------------------------------------------------------------------------------- /prompts/web_samples/README.md: -------------------------------------------------------------------------------- 1 | # Synthetic data from Web samples 2 | 3 | We built several types of synthetic content from seed web samples: textbooks (in narrative or academic tone), blogposts and WikiHow articles. 4 | 5 | To select the web samples, we initially clustered 100k samples from a web dataset like [RefinedWeb](https://huggingface.co/datasets/tiiuae/falcon-refinedweb).
This resulted in 145 clusters. You can inspect the clusters in this [demo](https://huggingface.co/spaces/HuggingFaceTB/inspect_clusters_free_topics). We then inferred the clusters of 15M other web samples and used them, along with their topics, in the prompts. 6 | 7 | The clustering code can be found in the [text-clustering](https://github.com/huggingface/text-clustering?tab=readme-ov-file#cosmopedia-experiments-clustering-of-web-samples-and-topic-labeling) repository. We then excluded 38 clusters deemed uneducational, based on the scores generated during clustering and on a manual inspection of each cluster. We noticed that mid-range scores didn't always reflect a topic's actual quality. 8 | 9 | We also tried to infer which generation style would best suit each topic: e.g. Mathematics is suitable for textbooks, Beauty & Lifestyle might be suitable for blogposts, and DIY for WikiHow articles. However, we did not enforce this classification, as the prompted LLM seemed to address each topic from interesting and different angles when prompted with different styles. 10 | 11 | Script for classification and filtering (this depends on the clusters you find in your dataset): 12 | ```bash 13 | python filter_and_classify_clusters.py 14 | ``` 15 | 16 | Script for building the web prompts: 17 | ```bash 18 | python build_web_prompts.py 19 | ``` 20 | -------------------------------------------------------------------------------- /prompts/web_samples/build_web_prompts.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | from datasets import load_dataset 4 | 5 | 6 | STYLES = {"wikihow": 7 | """Here is an extract from a webpage: "". 8 | 9 | Write a long and very detailed tutorial that could be part of WikiHow whose title is related to the extract above. Include in depth explanations for each step and how it helps achieve the desired outcome, including key tips and guidelines.
10 | Ensure clarity and practicality, allowing readers to easily follow and apply the instructions. Do not use images.""", 11 | 12 | "textbook_narrative": 13 | """Here is an extract from a webpage: "". 14 | 15 | Write an extensive and detailed course unit suitable for a textbook, related to the given extract. Do not just list concepts, but develop each one in detail before moving to the next, as we prioritize depth of understanding and comprehensive exploration of the subject matter over breadth. Focus on: 16 | 17 | - Rigor: Ensure in-depth coverage of the concepts. 18 | - Engagement: Use a narrative style akin to Michael Lewis, making it captivating and thought-provoking. 19 | - Relevance: Connect the topic with current trends, real-life examples, or recent studies. Do not use images. 20 | Do not include a title or an introduction, simply write the content without headlines and introductory phrases. Do not use images.""", 21 | 22 | "textbook_academic": 23 | """Here is an extract from a webpage: "". 24 | 25 | Write an extensive and detailed course unit suitable for a textbook targeted at college students, related to the given extract. Do not just list concepts, but develop each one in detail before moving to the next, as we prioritize depth of understanding and comprehensive exploration of the subject matter over breadth. Focus on: 26 | 27 | - Rigor: Ensure in-depth coverage of the concepts/sections. 28 | - Engagement: Write with an academic, professional and engaging tone that captivates interest. 29 | - Application: Incorporate specific, practical examples, such as proofs in calculus or critical dates and figures in history. 30 | Do not include a title or an introduction, simply write the content without headlines and introductory phrases. Do not use images.""", 31 | 32 | "blogpost": 33 | """Here is an extract from a webpage: "". 34 | 35 | Write an informative and insightful blog post that expands upon the extract above. 
Your post should delve into the nuances of the topic, offering fresh perspectives and deeper analysis. Aim to: 36 | 37 | - Inform: Provide valuable, well-researched information that educates the reader. 38 | - Engage: Write in a conversational tone that connects with the audience, making complex ideas accessible. 39 | - Illustrate: Use examples, anecdotes, or personal experiences to bring the topic to life. 40 | Do not give a title and do not start with sentences like "Have you ever..." or "Hello dear readers..", simply write the content without these introductory phrases.""" 41 | } 42 | 43 | EXTRACT_SIZE = 1000 44 | 45 | 46 | def get_args(): 47 | parser = argparse.ArgumentParser() 48 | parser.add_argument("--repo_id", type=str, default="HuggingFaceTB/web_prompts") 49 | parser.add_argument("--data_type", type=str, default="textbook") 50 | parser.add_argument("--generation_style", type=str, default="textbook_academic") 51 | parser.add_argument("--run_all_styles", action="store_true") 52 | return parser.parse_args() 53 | 54 | 55 | def build_prompt(x, style="textbook_academic"): 56 | """Build the prompt based on the generation type""" 57 | # web extract and topic 58 | web_sample = x["examples"] 59 | web_sample = web_sample[:min(EXTRACT_SIZE, len(web_sample))] 60 | topic = x["category"] 61 | add_topic = f', within the context of "{topic}"' if random.random() < 0.5 else "" 62 | # requested generation style 63 | prompt = STYLES[style].replace("", add_topic).replace("", web_sample) 64 | return {f"prompt_{style}": prompt} 65 | 66 | 67 | if __name__ == "__main__": 68 | # load data=data_type and generate content in style=style 69 | args = get_args() 70 | 71 | print(f"Loading data fw2_as_{args.data_type}...") 72 | ds = load_dataset(f"HuggingFaceTB/fw2_as_{args.data_type}", split="train", num_proc=48) 73 | if args.run_all_styles: 74 | suffix = "" 75 | for style in STYLES.keys(): 76 | print(f"📖 Building prompts with a {style}...") 77 | ds = ds.map(build_prompt, num_proc=48,
fn_kwargs={"style": style}) 78 | else: 79 | suffix = f"_{args.generation_style}" 80 | print(f"📖 Building prompts with a {args.generation_style}...") 81 | ds = ds.map(build_prompt, num_proc=48, fn_kwargs={"style": args.generation_style}) 82 | print(ds) 83 | print(ds) 84 | print(ds[0]["prompt_textbook_academic"]) 85 | print("-"*100) 86 | print(ds[1]["prompt_textbook_academic"]) 87 | print("-"*100) 88 | print(ds[2]["prompt_textbook_academic"]) 89 | ds.push_to_hub(f"{args.repo_id}_{args.data_type}{suffix}", private=True) 90 | print(f"✅ Data available at {args.repo_id}_{args.data_type}{suffix}!") -------------------------------------------------------------------------------- /prompts/web_samples/filter_and_classify_clusters.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pandas as pd 3 | from collections import Counter 4 | from datasets import load_dataset, Dataset 5 | 6 | 7 | # re-classifying each topic into the most adequate class: textbook, blogpost, or wikihow 8 | classifications = { 9 | "textbook": [ 10 | 'International Relations and Politics', 11 | 'Product Marketing and Design', 12 | 'Digital Imaging and Photography', 13 | 'Computer Science', 14 | 'Economics and Finance', 15 | 'Pharmaceutical manufacturing and technology', 16 | 'Real Estate & Investment', 17 | 'Business and Entrepreneurship', 18 | 'Astronomy and Astrophysics', 19 | 'Finance and Investment', 20 | 'Computer Antivirus Software and Security', 21 | 'Healthcare and Operations Management', 22 | 'Technology and Computer Science', 23 | 'Computer Programming and Web Development', 24 | 'Taxation and Finance', 25 | 'Human Resources / Organizational Management', 26 | 'Computer Hardware and Graphics Cards', 27 | 'Marketing and Business Strategies', 28 | 'Digital Marketing and Business', 29 | 'Audio Equipment and Home Theater Systems', 30 | 'HIV Treatment and Care', 31 | 'Legal Studies and Public Policy', 32 | 'Legal Studies / Law', 33 | 
'Jewelry Design and Manufacturing', 34 | 'Biochemistry and Molecular Biology', 35 | 'Insurance', 36 | 'Energy and Environmental Policy', 37 | 'Data Privacy and Protection', 38 | 'International Relations and Conflict', 39 | 'Entomology and Apiculture', 40 | 'Loans and Mortgages', 41 | 'Public Transit and Transportation', 42 | 'International Relations and Current Events', 43 | 'Politics and Government', 44 | 'Political Science', 45 | 'Genetics and Mental Health', 46 | 'Public Administration and Policy', 47 | 'Technology and Consumer Electronics', 48 | 'Computer Security & Privacy', 49 | 'Online Platforms & Web Technologies', 50 | 'Human Resources and Education', 51 | 'Sports and Education', 52 | 'Lighting Design and Technology', 53 | 'Medicine', 54 | 'Cryptocurrency and Blockchain Technology', 55 | 'Mental Health Counseling', 56 | 'Geography and Weather', 57 | 'Leadership and Education', 58 | 'Infant Feeding and Child Development', 59 | 'Molecular Biology and Genetics', 60 | 'Energy and Natural Resources', 61 | 'Mental Health and Therapy', 62 | 'Business and Management', 63 | 'Legal Services and Issues', 64 | 'Christian Theology and Spirituality', 65 | 'Personal Finance and Investments', 66 | 'Psychology', 67 | 'Healthcare & Medical Services', 68 | 'Watchmaking and Horology', 69 | 'Online Chat Platforms and Data Privacy', 70 | 'Waste Management and Recycling' 71 | ], 72 | "blogpost": [ 73 | 'Health and Lifestyle', 74 | 'Physical Fitness and Health', 75 | 'Music', 76 | 'Fiction and Fantasy Writing', 77 | 'Literature and Creative Writing', 78 | 'Arts and Crafts', 79 | 'Education', 80 | 'Education and Youth Development', 81 | 'Writing and Storytelling', 82 | 'Hair Care and Styling', 83 | 'Automotive Parts and Accessories', 84 | 'Astrology', 85 | 'Culinary Arts and Beverages', 86 | 'Events and Community Happenings', 87 | 'Cooking and Baking', 88 | 'Online Dating & Relationships', 89 | 'Career Development and Job Opportunities', 90 | 'Cosmetic Surgery and Body 
Modifications', 91 | 'Skincare and Beauty Products', 92 | 'Addiction and Mental Illness', 93 | 'Visual Arts and Art Appreciation', 94 | 'Pets and Pet Care', 95 | 'Personal Development and Empowerment', 96 | 'Video Games', 97 | 'Hair Care', 98 | 'Nutrition and Health', 99 | 'Fashion & Apparel', 100 | 'Travel', 101 | 'Performing Arts', 102 | 'Cannabis and CBD Products', 103 | 'Wine & Winemaking', 104 | 'Cooking and Recipes' 105 | ], 106 | "wikihow": [ 107 | 'Dentistry', 108 | 'Football/Soccer', 109 | 'Cleaning and Maintenance', 110 | 'American Football', 111 | 'Baseball', 112 | 'Recreational Fishing', 113 | 'Public Safety and Emergency Response', 114 | 'Ice Hockey', 115 | 'Professional Basketball/NBA', 116 | 'Home Improvement and Maintenance', 117 | 'Tennis', 118 | 'Professional Wrestling and Sports Entertainment', 119 | 'Cricket', 120 | 'Gun Control and Violence', 121 | 'Fire Incidents', 122 | 'Electric Vehicles and Battery Technology', 123 | 'Christianity and Theology' 124 | ] 125 | } 126 | remove = ['Online Gambling and Casinos', 127 | 'Reality Television and Celebrity Gossip', 128 | 'Explicit Adult Content', 129 | 'Fashion & Accessories', 130 | 'Solicitation and Prostitution', 131 | 'Adult Entertainment and Webcam Sites', 132 | 'Clothing & Fashion', 133 | 'Firearms and Accessories', 134 | 'Marketing and Sales Promotions', 135 | 'Cookies and Privacy Policies', 136 | 'Fragrances and Personal Care Products', 137 | 'Obituaries and Personal Profiles', 138 | 'Motor Sports', 139 | 'Death', 140 | 'Male Enhancement Products and Supplements', 141 | 'Political Science', 142 | 'Heating', 143 | 'Anabolic Steroids and', 144 | 'Combat Sports', 145 | 'Events and Community Happenings', 146 | 'Footwear and Fashion', 147 | 'Furniture Design and Sales', 148 | 'Entertainment & Media', 149 | 'Real Estate & Property Management', 150 | 'Weight Loss and Body Contouring', 151 | 'Moving Services and Logistics', 152 | 'Transportation and City Planning', 153 | 'Home Decoration and Furniture',
154 | 'Events and Conferences', 155 | 'Cosmetics and Beauty Products', 156 | 'Mattresses and Sleep', 157 | 'Soap & Skincare Products', 158 | 'E-commerce and Online Shopping', 159 | 'Optical Equipment and Accessories', 160 | ] 161 | 162 | new_dict = {v[i]: k for k, v in classifications.items() for i in range(len(v))} 163 | 164 | 165 | def get_args(): 166 | parser = argparse.ArgumentParser() 167 | parser.add_argument("--clusters_dataset", type=str, default="HuggingFaceTB/web_clusters") 168 | parser.add_argument("--user", type=str, default="HuggingFaceTB") 169 | return parser.parse_args() 170 | 171 | 172 | def extract_category(example): 173 | summary = example["summary"] 174 | category = summary.split(". Educational")[0].strip() 175 | score = summary.split(" Educational score: ")[1].strip() 176 | return {"category": category, "educational_score": score} 177 | 178 | 179 | def add_generation_type(x): 180 | topic = x["category"] 181 | try: 182 | generation_type = new_dict[topic] 183 | except KeyError: 184 | print(f"{topic} not in keep list") 185 | generation_type = "blogpost" 186 | return {"generation_type": generation_type} 187 | 188 | args = get_args() 189 | print("Loading web samples (after the clustering)...") 190 | ds_1 = load_dataset(args.clusters_dataset, split="train") 191 | print(ds_1) 192 | 193 | print("Converting to dataframe...") 194 | full_df = ds_1.to_pandas().explode("examples") 195 | 196 | full_df.sort_values(by=['cluster_id'], inplace=True) 197 | 198 | print("Full df info...") 199 | print(full_df.head()) 200 | print(full_df.info()) 201 | 202 | print("Convert to HF dataset...") 203 | final_ds = Dataset.from_pandas(full_df) 204 | final_ds = final_ds.map(extract_category) 205 | print("HF dataset:") 206 | print(final_ds) 207 | 208 | print("Filter out bad topics...") 209 | ds_keep = final_ds.filter(lambda x: x["category"] not in remove, num_proc=64) 210 | print(f"Size after dropping low quality clusters: {len(ds_keep)}={len(ds_keep)*100/len(final_ds):.2f}% of the
original dataset") 211 | 212 | print("Add generation type...") 213 | ds_keep = ds_keep.map(add_generation_type, num_proc=24) 214 | print(Counter(ds_keep["generation_type"])) 215 | 216 | print("Retrieve textbooks...") 217 | textbooks = ds_keep.filter(lambda x: x["generation_type"] == "textbook") 218 | print(textbooks) 219 | print("Retrieve wikihow...") 220 | wikihow = ds_keep.filter(lambda x: x["generation_type"] == "wikihow") 221 | print(wikihow) 222 | print("Retrieve blogpost...") 223 | blogpost = ds_keep.filter(lambda x: x["generation_type"] == "blogpost") 224 | print(blogpost) 225 | 226 | print("Pushing to hub ...") 227 | textbooks.push_to_hub(f"{args.user}/fw2_as_textbook", private=True) 228 | wikihow.push_to_hub(f"{args.user}/fw2_as_wikihow", private=True) 229 | blogpost.push_to_hub(f"{args.user}/fw2_as_blogpost", private=True) 230 | print("Done!") -------------------------------------------------------------------------------- /prompts/wikihow/README.md: -------------------------------------------------------------------------------- 1 | # Synthetic WikiHow articles from scraped WikiHow titles 2 | 3 | You can find the list of WikiHow titles we scraped in `wikihowcom-20231012-titles.txt`. 4 | An updated list of WikiHow titles can be extracted using https://github.com/mediawiki-client-tools/mediawiki-dump-generator 5 | 6 | [TODO] Add code for updated prompts of Cosmopedia --------------------------------------------------------------------------------
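Until the prompt-building code for the WikiHow titles is added, the scraped titles file could be turned into prompts roughly as follows. This is only a minimal sketch: it assumes `wikihowcom-20231012-titles.txt` holds one title per line, and the template text and the `load_titles` / `build_wikihow_prompt` helpers are hypothetical names introduced for illustration, not the prompts or code actually used for Cosmopedia.

```python
from pathlib import Path
from string import Template

# Hypothetical template, for illustration only (not the Cosmopedia prompt).
WIKIHOW_TEMPLATE = Template(
    'Write a long and very detailed tutorial on "$title", in the style of WikiHow. '
    "Include in-depth explanations for each step. Do not use images."
)


def load_titles(path):
    """Read one title per line (assumed format), dropping blanks and duplicates in order."""
    seen = set()
    titles = []
    for line in Path(path).read_text(encoding="utf-8").splitlines():
        title = line.strip()
        if title and title not in seen:
            seen.add(title)
            titles.append(title)
    return titles


def build_wikihow_prompt(title):
    """Fill the hypothetical template with a single title."""
    return WIKIHOW_TEMPLATE.substitute(title=title)
```

Calling `load_titles("wikihowcom-20231012-titles.txt")` and mapping `build_wikihow_prompt` over the result would give one prompt per unique title. A named `$title` placeholder is used here so that the substitution targets exactly one spot in the template.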