├── LICENSE
├── README.md
├── classification
│   ├── README.md
│   ├── run_edu_bert.py
│   ├── run_edu_bert.slurm
│   ├── train_edu_bert.py
│   └── train_edu_bert.slurm
├── decontamination
│   ├── README.md
│   └── decontaminate.py
├── deduplication
│   ├── README.md
│   └── deduplicate_dataset.py
├── evaluation
│   ├── README.md
│   ├── eval.slurm
│   └── lighteval_tasks.py
├── fulltext_search
│   ├── README.md
│   ├── index_docs.py
│   ├── index_docs.slurm
│   ├── manticore.conf
│   ├── search_sharded.py
│   └── search_sharded.slurm
├── generation
│   ├── README.md
│   ├── boilerplate_cleanup.py
│   └── llm_swarm_script.py
├── plots
│   ├── clusters_map.png
│   ├── cover.png
│   ├── cover_01.png
│   ├── educational_score.png
│   └── topics_distpng.png
└── prompts
    ├── README.md
    ├── auto_math_text
    │   ├── README.md
    │   └── build_science_prompts.py
    ├── khanacademy
    │   ├── README.md
    │   ├── generate_textbooks.py
    │   └── khan_dl
    │       ├── khan_dl.py
    │       ├── main.py
    │       └── requirements.txt
    ├── openstax
    │   ├── README.md
    │   └── build_openstax_prompts.py
    ├── stanford
    │   ├── 1_scraper.ipynb
    │   ├── 2_generate_course_outlines.ipynb
    │   └── README.md
    ├── stories
    │   ├── README.md
    │   ├── build_openhermes_stories_prompts.py
    │   ├── build_ultrachat_stories_prompts.py
    │   └── filter_openhermes.py
    ├── web_samples
    │   ├── README.md
    │   ├── build_web_prompts.py
    │   └── filter_and_classify_clusters.py
    └── wikihow
        ├── README.md
        └── wikihowcom-20231012-titles.txt
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Cosmopedia
2 |
3 |
4 | [Figure: Image generated by DALL-E; the prompt was generated by Mixtral-8x7B-Instruct-v0.1.]
5 |
6 |
7 |
8 | [🤗 Cosmopedia dataset] | [🤖 1B-LLM trained on Cosmopedia] | [📰 Blog post]
9 |
10 | blog post:
11 |
12 |
13 | ## Description
14 | Here you can find the code used to create [Cosmopedia](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia), a dataset of synthetic textbooks, blog posts, stories, posts, and WikiHow articles generated by Mixtral-8x7B-Instruct-v0.1. It contains over **30 million files and 25 billion tokens**, making it the largest open synthetic dataset to date.
15 |
16 | Cosmopedia covers a variety of topics; we tried to map the world knowledge present in web datasets such as RefinedWeb and RedPajama and to generate synthetic content that covers it. This is v0.1 of Cosmopedia, with ample room for improvement and for topics to be covered more comprehensively. We hope this dataset will help the community's research efforts in the increasingly intriguing domain of synthetic data.
17 |
18 |
19 | [Figure: The clusters of Cosmopedia.]
20 |
21 |
22 |
23 | You can also find a plot of file frequency for single-topic clusters in `plots/topics_distpng.png`.
24 |
25 | ## Code structure
26 | - `prompts`: the code for building the prompts for each `seed_data` source in Cosmopedia. In `web_samples`, you can also find pointers to the topic clustering we did.
27 | - `generation`: the code to run large-scale synthetic generation with [llm-swarm](https://github.com/huggingface/llm-swarm) using the prompts you built. Cosmopedia consists of 25B tokens and was generated in more than 10k H100 GPU hours.
28 | - `deduplication`: the script we used to run MinHash deduplication with [datatrove](https://github.com/huggingface/datatrove).
29 | - `decontamination`: the code we used to run n-gram decontamination against evaluation benchmarks when training models on the dataset, such as [cosmopedian-1b](https://huggingface.co/HuggingFaceTB/cosmopedian-1b).
30 |
--------------------------------------------------------------------------------
/classification/README.md:
--------------------------------------------------------------------------------
1 | # Educational value classifier
2 |
3 | ### 1. Finetune a model for educational value regression
4 |
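5 | * edit `train_edu_bert.slurm`, setting `--base_model_name` to a BERT-like base model and `--dataset_name` to a Llama3-annotated educational value dataset (a `# comment` placed after a trailing `\` would break the line continuation, so keep comments out of the command):
6 | ```bash
7 | --base_model_name="Snowflake/snowflake-arctic-embed-m" \
8 | --dataset_name="HuggingFaceTB/LLM_juries_fineweb_430k_annotations" \
9 | --target_column="score"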
10 | ```
11 | * run the training script on a SLURM cluster:
12 | ```bash
13 | sbatch train_edu_bert.slurm
14 | ```
15 |
16 | ### 2. Annotate a dataset with the educational scores predicted by the model
17 |
18 | ```bash
19 | sbatch run_edu_bert.slurm
20 | ```
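The annotation step converts the regressor's continuous output into an integer label and splits the input across the 128 SLURM array tasks launched by `run_edu_bert.slurm`. The following is a minimal pure-Python sketch of those two rules as they appear in `run_edu_bert.py`; the helper names `to_int_score` and `shard_indices` are hypothetical and exist only for illustration.

```python
from typing import List


def to_int_score(score: float) -> int:
    """Clamp a raw regression output to [0, 5] and round it, mirroring how
    run_edu_bert.py fills the `int_score` column."""
    return int(round(max(0.0, min(score, 5.0))))


def shard_indices(num_rows: int, shard: int, num_shards: int) -> List[int]:
    """Row indices processed by one SLURM array task: rows whose index modulo
    num_shards equals the task's shard id (the rule used in dataset.filter)."""
    return [i for i in range(num_rows) if i % num_shards == shard]


if __name__ == "__main__":
    raw_scores = [-0.7, 1.49, 2.51, 4.2, 6.3]
    print([to_int_score(s) for s in raw_scores])
    # Check that every row lands in exactly one of the 128 shards.
    covered = sorted(i for k in range(128) for i in shard_indices(1000, k, 128))
    print(covered == list(range(1000)))
```

Because Python's built-in `round` rounds halves to the nearest even integer, a raw output of exactly 2.5 maps to 2 rather than 3.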
--------------------------------------------------------------------------------
/classification/run_edu_bert.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import argparse
3 | from transformers import AutoTokenizer, AutoModelForSequenceClassification
4 | from datasets import load_dataset
5 |
6 |
7 | def main(args):
8 | tokenizer = AutoTokenizer.from_pretrained(args.model_name)
9 | model = AutoModelForSequenceClassification.from_pretrained(
10 | args.model_name, torch_dtype=torch.bfloat16
11 | )
12 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13 | model.to(device)
14 |
15 | dataset = load_dataset(
16 | args.dataset_name,
17 | args.dataset_config,
18 | split="train",
19 | cache_dir="/scratch/cosmo/cache/",
20 | num_proc=12,
21 | )
22 | dataset = dataset.filter(
23 | lambda x, i: i % args.num_shards == args.shard, with_indices=True, num_proc=12
24 | )
25 |
26 | def compute_scores(batch):
27 | inputs = tokenizer(
28 | batch[args.text_column],
29 | return_tensors="pt",
30 | padding="longest",
31 | truncation=True,
32 | ).to(device)
33 | with torch.no_grad():
34 | outputs = model(**inputs)
35 | logits = outputs.logits.squeeze(-1).float().cpu().numpy()
36 |
37 | batch["score"] = logits.tolist()
38 | batch["int_score"] = [int(round(max(0, min(score, 5)))) for score in logits]
39 | return batch
40 |
41 | dataset = dataset.map(compute_scores, batched=True, batch_size=512)
42 |
43 | while True:
44 | try:
45 | config_name = f"{args.output_dataset_config}_{args.shard}"
46 | dataset.push_to_hub(
47 | args.output_dataset_name,
48 | config_name=config_name,
49 | private=True,
50 | max_shard_size="4096MB",
51 | )
52 | break
53 | except Exception as e:
54 | print(e)
55 | continue
56 |
57 |
58 | if __name__ == "__main__":
59 | parser = argparse.ArgumentParser()
60 |
61 | parser.add_argument(
62 | "--model_name", type=str, default="HuggingFaceFW/fineweb-edu-classifier"
63 | )
64 | parser.add_argument("--dataset_name", type=str, default="HuggingFaceFW/fineweb")
65 | parser.add_argument("--dataset_config", type=str, default="default")
66 | parser.add_argument(
67 | "--output_dataset_name", type=str, default="HuggingFaceFW/fineweb-edu"
68 | )
69 | parser.add_argument("--output_dataset_config", type=str, default="default")
70 | parser.add_argument("--text_column", type=str, default="text")
71 | parser.add_argument("--shard", type=int, required=True)
72 | parser.add_argument("--num_shards", type=int, required=True)
73 |
74 | args = parser.parse_args()
75 | main(args)
76 |
--------------------------------------------------------------------------------
/classification/run_edu_bert.slurm:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH --job-name=run_edu_bert
3 | #SBATCH --partition hopper-prod
4 | #SBATCH --qos=normal
5 | #SBATCH --requeue
6 | #SBATCH --nodes=1
7 | #SBATCH --ntasks-per-node=1
8 | #SBATCH --cpus-per-task=12
9 | #SBATCH --mem-per-cpu=20G
10 | #SBATCH --gpus=1
11 | #SBATCH -o %x_%j.out
12 | #SBATCH -e %x_%j.err
13 | #SBATCH --time=7-00:00:00
14 | #SBATCH --array=0-127%128
15 |
16 | set -x -e
17 | source ~/.bashrc
18 | source "$CONDA_PREFIX/etc/profile.d/conda.sh"
19 | source activate pytorch
20 |
21 | python run_edu_bert.py \
22 | --model_name="HuggingFaceFW/fineweb-edu-classifier" \
23 | --dataset_name="HuggingFaceFW/fineweb" \
24 | --dataset_config="CC-MAIN-2019-04" \
25 | --output_dataset_name="HuggingFaceFW/fineweb-edu-annotations" \
26 | --output_dataset_config="CC-MAIN-2019-04" \
27 | --text_column="text" \
28 | --shard ${SLURM_ARRAY_TASK_ID} \
29 | --num_shards 128
30 |
--------------------------------------------------------------------------------
/classification/train_edu_bert.py:
--------------------------------------------------------------------------------
1 | from transformers import (
2 | AutoTokenizer,
3 | DataCollatorWithPadding,
4 | TrainingArguments,
5 | Trainer,
6 | AutoModelForSequenceClassification,
7 | )
8 | from datasets import load_dataset, ClassLabel
9 | import numpy as np
10 | import evaluate
11 | import argparse
12 | import os
13 | from sklearn.metrics import classification_report, confusion_matrix
14 |
15 |
16 | def compute_metrics(eval_pred):
17 | precision_metric = evaluate.load("precision")
18 | recall_metric = evaluate.load("recall")
19 | f1_metric = evaluate.load("f1")
20 | accuracy_metric = evaluate.load("accuracy")
21 |
22 | logits, labels = eval_pred
23 | preds = np.round(logits.squeeze()).clip(0, 5).astype(int)
24 | labels = np.round(labels.squeeze()).astype(int)
25 | precision = precision_metric.compute(
26 | predictions=preds, references=labels, average="macro"
27 | )["precision"]
28 | recall = recall_metric.compute(
29 | predictions=preds, references=labels, average="macro"
30 | )["recall"]
31 | f1 = f1_metric.compute(predictions=preds, references=labels, average="macro")["f1"]
32 | accuracy = accuracy_metric.compute(predictions=preds, references=labels)["accuracy"]
33 |
34 | report = classification_report(labels, preds)
35 | cm = confusion_matrix(labels, preds)
36 | print("Validation Report:\n" + report)
37 | print("Confusion Matrix:\n" + str(cm))
38 |
39 | return {
40 | "precision": precision,
41 | "recall": recall,
42 | "f1_macro": f1,
43 | "accuracy": accuracy,
44 | }
45 |
46 |
47 | def main(args):
48 | dataset = load_dataset(
49 | args.dataset_name, split="train", cache_dir="/scratch/cosmo/cache/", num_proc=8
50 | )
51 | dataset = dataset.map(
52 | lambda x: {args.target_column: np.clip(int(x[args.target_column]), 0, 5)},
53 | num_proc=8,
54 | )
55 |
56 | dataset = dataset.cast_column(
57 | args.target_column, ClassLabel(names=[str(i) for i in range(6)])
58 | )
59 | dataset = dataset.train_test_split(
60 | train_size=0.9, seed=42, stratify_by_column=args.target_column
61 | )
62 |
63 | model = AutoModelForSequenceClassification.from_pretrained(
64 | args.base_model_name,
65 | num_labels=1,
66 | classifier_dropout=0.0,
67 | hidden_dropout_prob=0.0,
68 | output_hidden_states=False,
69 | )
70 | tokenizer = AutoTokenizer.from_pretrained(
71 | args.base_model_name,
72 | model_max_length=min(model.config.max_position_embeddings, 512),
73 | )
74 | if not tokenizer.pad_token:
75 | tokenizer.pad_token = tokenizer.eos_token
76 |
77 | def preprocess(examples):
78 | batch = tokenizer(examples["text"], truncation=True)
79 | batch["labels"] = np.float32(examples[args.target_column])
80 | return batch
81 |
82 | dataset = dataset.map(preprocess, batched=True)
83 | data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
84 |
85 | for param in model.bert.embeddings.parameters():
86 | param.requires_grad = False
87 | for param in model.bert.encoder.parameters():
88 | param.requires_grad = False
89 |
90 | training_args = TrainingArguments(
91 | output_dir=args.checkpoint_dir,
92 | hub_model_id=args.output_model_name,
93 | eval_strategy="steps",
94 | save_strategy="steps",
95 | eval_steps=1000,
96 | save_steps=1000,
97 | logging_steps=100,
98 | learning_rate=3e-4,
99 | num_train_epochs=20,
100 | seed=0,
101 | per_device_train_batch_size=256,
102 | per_device_eval_batch_size=128,
103 | eval_on_start=True,
104 | load_best_model_at_end=True,
105 | metric_for_best_model="f1_macro",
106 | greater_is_better=True,
107 | bf16=True,
108 | push_to_hub=True,
109 | )
110 |
111 | trainer = Trainer(
112 | model=model,
113 | args=training_args,
114 | train_dataset=dataset["train"],
115 | eval_dataset=dataset["test"],
116 | tokenizer=tokenizer,
117 | data_collator=data_collator,
118 | compute_metrics=compute_metrics,
119 | )
120 |
121 | trainer.train()
122 | trainer.save_model(os.path.join(args.checkpoint_dir, "final"))
123 |
124 |
125 | if __name__ == "__main__":
126 | parser = argparse.ArgumentParser()
127 | parser.add_argument(
128 | "--base_model_name", type=str, default="Snowflake/snowflake-arctic-embed-m"
129 | )
130 | parser.add_argument(
131 | "--dataset_name",
132 | type=str,
133 | default="HuggingFaceFW/fineweb-edu-llama3-annotations",
134 | )
135 | parser.add_argument("--target_column", type=str, default="score")
136 | parser.add_argument(
137 | "--checkpoint_dir",
138 | type=str,
139 | default="/fsx/anton/cosmopedia/edu_score/bert_snowflake_regression",
140 | )
141 | parser.add_argument(
142 | "--output_model_name", type=str, default="HuggingFaceTB/fineweb-edu-scorer"
143 | )
144 | args = parser.parse_args()
145 |
146 | main(args)
147 |
--------------------------------------------------------------------------------
/classification/train_edu_bert.slurm:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH --job-name=train_edu_bert
3 | #SBATCH --partition hopper-prod
4 | #SBATCH --nodes=1
5 | #SBATCH --ntasks-per-node=1
6 | #SBATCH --cpus-per-task=16
7 | #SBATCH --mem-per-cpu=20G
8 | #SBATCH --gpus=1
9 | #SBATCH -o %x_%j.out
10 | #SBATCH -e %x_%j.err
11 | #SBATCH --time=1-00:00:00
12 |
13 | set -x -e
14 | source ~/.bashrc
15 | source "$CONDA_PREFIX/etc/profile.d/conda.sh"
16 | source activate pytorch
17 |
18 | python train_edu_bert.py \
19 | --base_model_name="Snowflake/snowflake-arctic-embed-m" \
20 | --dataset_name="HuggingFaceFW/fineweb-edu-llama3-annotations" \
21 | --target_column="score" \
22 | --checkpoint_dir="/fsx/anton/cosmopedia/edu_score/snowflake_regression_median_jury" \
23 | --output_model_name="HuggingFaceTB/fineweb-edu-scorer"
24 |
--------------------------------------------------------------------------------
/decontamination/README.md:
--------------------------------------------------------------------------------
1 | # Decontamination
2 |
3 | We use a 10-gram overlap to retrieve potentially contaminated samples, similar to [Phi-1](https://huggingface.co/papers/2306.11644).
4 | After retrieving the candidates, we run a diff between the dataset sample and the benchmark sample using `difflib.SequenceMatcher` and discard the sample if `len(matched_substrings)/len(benchmark_sample) > 0.5`.
5 | We run decontamination against all the benchmarks we evaluated the Cosmo-1B model on: MMLU, HellaSwag, PIQA, SIQA, Winogrande, OpenBookQA, ARC-easy, ARC-challenge.
6 |
7 | Usage:
8 | ```bash
9 | export HF_DATASETS_CACHE=/scratch/cosmo/cache
10 | export HUGGINGFACE_HUB_CACHE=/scratch/cosmo/cache
11 |
12 | python decontaminate.py --train_dataset "HuggingFaceTB/AMT_2M_Khanacademy_24k" --report_dataset_name "HuggingFaceTB/AMT_2M_Khanacademy_24k_decont_report" --save_decontaminated --decontaminated_dataset_name "HuggingFaceTB/AMT_2M_Khanacademy_24k_decont"
13 | ```
14 |
15 |
16 |
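As a rough illustration of the two-stage check described above (n-gram retrieval, then a `difflib.SequenceMatcher` overlap ratio), here is a simplified, self-contained sketch. It is not `decontaminate.py` itself: the function names are hypothetical, tokenization skips the diacritics normalization, and the ratio sums the sizes of matching blocks longer than 5 characters instead of re-joining the matched substrings.

```python
import difflib
import re
from typing import Dict, List, Optional, Set, Tuple

NGram = Tuple[str, ...]


def tokenize(text: str) -> List[str]:
    # Lowercased word tokens (decontaminate.py additionally strips diacritics).
    return re.findall(r"\w+", text.lower())


def get_ngrams(tokens: List[str], n: int) -> Set[NGram]:
    return set(zip(*[tokens[i:] for i in range(n)]))


def build_index(bench_texts: List[str], n: int) -> Dict[NGram, int]:
    """Map each benchmark n-gram to the index of a benchmark text containing it."""
    index: Dict[NGram, int] = {}
    for idx, text in enumerate(bench_texts):
        for gram in get_ngrams(tokenize(text), n):
            index[gram] = idx
    return index


def match_ratio(sample: str, bench: str, min_block: int = 5) -> float:
    """Share of the normalized benchmark text covered by matching blocks longer than min_block chars."""
    a, b = " ".join(tokenize(sample)), " ".join(tokenize(bench))
    if not b:
        return 0.0
    matcher = difflib.SequenceMatcher(None, a, b, autojunk=False)
    matched = sum(m.size for m in matcher.get_matching_blocks() if m.size > min_block)
    return matched / len(b)


def find_contamination(sample: str, bench_texts: List[str], index: Dict[NGram, int],
                       n: int = 10, threshold: float = 0.5) -> Optional[int]:
    """Return the index of the benchmark text `sample` overlaps with, or None if it looks clean."""
    for gram in get_ngrams(tokenize(sample), n):
        if gram in index:
            bench_idx = index[gram]
            # Like the script, only the first retrieved candidate is checked.
            return bench_idx if match_ratio(sample, bench_texts[bench_idx]) > threshold else None
    return None
```

Since n-grams are stored in a set, which shared n-gram is found first (and therefore which benchmark text is compared) is arbitrary when a sample overlaps several benchmark samples.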
--------------------------------------------------------------------------------
/decontamination/decontaminate.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import difflib
3 | import re
4 | import unicodedata
5 | from pathlib import Path
6 | from tqdm.auto import tqdm
7 | from datasets import load_dataset, Dataset
8 |
9 |
10 | def tokenize(text):
11 | """Normalize text by removing diacritics and tokenize."""
12 | text = "".join(c for c in unicodedata.normalize("NFD", text) if unicodedata.category(c) != "Mn")
13 | tokens = re.findall(r"\w+", text.lower())
14 | return tokens
15 |
16 |
17 | def get_ngrams(tokens, n):
18 | """Generate n-grams from tokens."""
19 | return set(zip(*[tokens[i:] for i in range(n)]))
20 |
21 |
22 | def retrieve_ngrams_batch(batch, eval_ngrams, eval_datasets, eval_texts, ngram_len):
23 | """Find contaminated samples based on n-grams."""
24 | new_batch = {"completion": [], "ngram": [], "bench_name": [], "bench_text": []}
25 | for completion in batch["completion"]:
26 | tokens = tokenize(completion)
27 | ngrams = get_ngrams(tokens, ngram_len)
28 | for ngram in ngrams:
29 | if ngram in eval_ngrams:
30 | idx = eval_ngrams[ngram]
31 | new_batch["completion"].append(completion)
32 | new_batch["ngram"].append(ngram)
33 | new_batch["bench_name"].append(eval_datasets[idx])
34 | new_batch["bench_text"].append(eval_texts[idx])
35 | break
36 | return new_batch
37 |
38 |
39 | def diff_strings(string1, string2):
40 | """Find matching parts between two strings."""
41 | matcher = difflib.SequenceMatcher(None, string1.lower(), string2.lower(), autojunk=False)
42 | matching_blocks = matcher.get_matching_blocks()
43 | matches = []
44 | for block in matching_blocks:
45 | start_a, start_b, length = block
46 | if length > 5:
47 | match = string1[start_a:start_a + length]
48 | matches.append(match)
49 | return matches
50 |
51 |
52 | def add_match_stats(example):
53 | gen_text = " ".join(tokenize(example["completion"]))
54 | bench_text = " ".join(tokenize(example["bench_text"]))
55 | matching_parts = diff_strings(gen_text, bench_text)
56 | match = " ".join("".join(matching_parts).split())
57 | example["diff"] = matching_parts
58 | example["diff_ratio"] = len(match) / len(bench_text) if len(bench_text) > 0 else 0
59 | example["diff_length"] = len(match)
60 | example["longest_diff_part"] = max(matching_parts, key=len, default="")
61 | example["longest_diff_part_length"] = len(example["longest_diff_part"])
62 | return example
63 |
64 |
65 | def main(args):
66 | # Load the evaluation data to build n-grams index
67 | eval_ngrams, eval_datasets, eval_texts = {}, [], []
68 | eval_data = load_dataset(args.eval_dataset, split="train", num_proc=args.num_proc)
69 | for example in tqdm(eval_data):
70 | tokens = tokenize(example["text"])
71 | ngrams = get_ngrams(tokens, args.ngram_length)
72 | if ngrams:
73 | idx = len(eval_texts)
74 | eval_ngrams.update(zip(ngrams, [idx] * len(ngrams)))
75 | eval_datasets.append(example.get("task_name", "unknown"))
76 | eval_texts.append(example["text"])
77 |
78 | train_dataset_path = Path(args.train_dataset)
79 | if train_dataset_path.exists() and train_dataset_path.suffix in [".json", ".csv"]:
80 | if train_dataset_path.suffix == ".json":
81 | train_data = Dataset.from_json(args.train_dataset)
82 | elif train_dataset_path.suffix == ".csv":
83 | train_data = Dataset.from_csv(args.train_dataset)
84 | else:
85 | train_data = load_dataset(args.train_dataset, split="train", num_proc=args.num_proc)
86 |
87 | contamination_report = train_data.map(
88 | lambda batch: retrieve_ngrams_batch(batch, eval_ngrams, eval_datasets, eval_texts, args.ngram_length),
89 | batched=True, batch_size=1000, num_proc=args.num_proc, remove_columns=train_data.column_names
90 | )
91 |
92 | contamination_report = contamination_report.map(
93 | lambda example: add_match_stats(example), num_proc=args.num_proc
94 | )
95 |
96 | contamination_report.push_to_hub(args.report_dataset_name, private=args.private)
97 |
98 | contamination_report = contamination_report.filter(lambda x: x["diff_ratio"] > args.diff_threshold)
99 |
100 | if args.save_decontaminated:
101 | contaminated_completions = set(contamination_report["completion"])
102 | filtered_data = train_data.filter(lambda x: x["completion"] not in contaminated_completions)
103 | filtered_data.push_to_hub(args.decontaminated_dataset_name, private=args.private)
104 |
105 |
106 | if __name__ == "__main__":
107 | parser = argparse.ArgumentParser(description="Generate a decontamination report for a dataset.")
108 | parser.add_argument("--eval_dataset", type=str,
109 | default="HuggingFaceTB/phi2_eval_data_for_decontamination",
110 | help="Name of the dataset with benchmark samples to use for decontamination.")
111 | parser.add_argument("--train_dataset", type=str, required=True,
112 | help="Path or name of the training dataset to process.")
113 | parser.add_argument("--report_dataset_name", type=str, required=True,
114 | help="Name for the output dataset with decontamination report.")
115 | parser.add_argument("--decontaminated_dataset_name", type=str, help="Name for the decontaminated dataset.")
116 | parser.add_argument("--private", action='store_true', help="Whether to make the output dataset private.")
117 | parser.add_argument("--ngram_length", type=int, default=10, help="Length of the n-grams to consider.")
118 | parser.add_argument("--diff_threshold", type=float, default=0.5,
119 | help="Threshold for filtering based on difference ratio.")
120 | parser.add_argument("--num_proc", type=int, default=16, help="Number of processes to use for map operations.")
121 | parser.add_argument("--save_decontaminated", action='store_true',
122 | help="Whether to save the decontaminated dataset.")
123 |
124 | args = parser.parse_args()
125 | main(args)
--------------------------------------------------------------------------------
/deduplication/README.md:
--------------------------------------------------------------------------------
1 | # Deduplication
2 |
3 | We run deduplication on the dataset using MinHash from [datatrove](https://github.com/huggingface/datatrove).
4 | Because the seed samples had already been deduplicated, and the prompts were carefully crafted to produce distinct outputs even from identical seeds, duplicates made up less than 1% of the files in Cosmopedia; these were subsequently removed.
5 |
6 | The deduplication script is available at `deduplicate_dataset.py`. Before running it, follow the installation guidelines in `datatrove` and change the paths in the file (the S3 bucket, dataset name, and Slurm partitions are specific to our cluster).
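For readers unfamiliar with MinHash, the pure-Python sketch below illustrates the idea behind the four datatrove stages in `deduplicate_dataset.py`: compute a signature per document, group signature bands into buckets, cluster documents that share a bucket, and keep one document per cluster. It is a minimal illustration under our own assumptions (word 5-grams, 64 seeded BLAKE2b hashes, 8 buckets of 8 rows, and hypothetical helper names), not datatrove's implementation or its default `MinhashConfig`.

```python
import hashlib
import re


def shingles(text: str, n: int = 5) -> set[str]:
    """Split a lowercased document into overlapping word n-grams."""
    words = re.findall(r"\w+", text.lower())
    return {" ".join(words[i : i + n]) for i in range(max(1, len(words) - n + 1))}


def minhash_signature(text: str, num_perm: int = 64, n: int = 5) -> list[int]:
    """Signature step: for each seeded hash function, keep the minimum hash over all n-grams."""
    grams = shingles(text, n)
    signature = []
    for seed in range(num_perm):
        salt = seed.to_bytes(8, "little")  # a different salt acts as a different hash function
        signature.append(
            min(
                int.from_bytes(hashlib.blake2b(g.encode(), digest_size=8, salt=salt).digest(), "little")
                for g in grams
            )
        )
    return signature


def estimated_jaccard(sig_a: list[int], sig_b: list[int]) -> float:
    """The fraction of matching signature slots estimates the Jaccard similarity of the n-gram sets."""
    return sum(a == b for a, b in zip(sig_a, sig_b)) / len(sig_a)


def bucket_keys(signature: list[int], num_buckets: int = 8) -> list[tuple]:
    """Bucket step: split the signature into bands; documents sharing any band become candidates."""
    rows = len(signature) // num_buckets
    return [(b, tuple(signature[b * rows : (b + 1) * rows])) for b in range(num_buckets)]


def deduplicate(docs: list[str], num_perm: int = 64, num_buckets: int = 8) -> list[str]:
    """Cluster + filter steps: union documents that share a bucket, keep the first of each cluster."""
    parent = list(range(len(docs)))

    def find(i: int) -> int:
        while parent[i] != i:
            parent[i] = parent[parent[i]]  # path halving
            i = parent[i]
        return i

    first_in_bucket: dict[tuple, int] = {}
    for i, doc in enumerate(docs):
        for key in bucket_keys(minhash_signature(doc, num_perm), num_buckets):
            if key not in first_in_bucket:
                first_in_bucket[key] = i
                continue
            root_a, root_b = find(i), find(first_in_bucket[key])
            if root_a != root_b:
                # the smaller index becomes the root, so the earliest document is the one kept
                parent[max(root_a, root_b)] = min(root_a, root_b)
    return [doc for i, doc in enumerate(docs) if find(i) == i]


docs = [
    "The mitochondria is the powerhouse of the cell and produces most of its energy.",
    "The mitochondria is the powerhouse of the cell and produces most of its energy.",
    "Photosynthesis converts light energy into chemical energy stored in glucose.",
]
print(len(deduplicate(docs)))  # 2: the exact duplicate is dropped
```

With a fixed number of hashes, splitting them into more buckets with fewer rows each makes near-duplicates more likely to collide in at least one bucket (higher recall), at the cost of more false candidate pairs to cluster.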
--------------------------------------------------------------------------------
/deduplication/deduplicate_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from datatrove.executor.slurm import SlurmPipelineExecutor
4 | from datatrove.pipeline.dedup import MinhashDedupSignature
5 | from datatrove.pipeline.dedup.minhash import (
6 | MinhashConfig,
7 | MinhashDedupBuckets,
8 | MinhashDedupCluster,
9 | MinhashDedupFilter,
10 | )
11 | from datatrove.pipeline.readers import HuggingFaceDatasetReader
12 | from datatrove.pipeline.tokens import TokensCounter
13 | from datatrove.pipeline.writers.jsonl import JsonlWriter
14 |
15 |
16 | # you can also change ngrams or the number of buckets and their size here
17 | minhash_config = MinhashConfig()
18 | HF_DATA = "cosmopedia-100k"
19 |
20 | S3_MINHASH_BASE_PATH = f"s3://synthetic-datasets-phi/{HF_DATA}/minhash"
21 | S3_LOGS_FOLDER = f"s3://synthetic-datasets-phi/{HF_DATA}/minhash_logs/"
22 | LOCAL_LOGS_FOLDER = f"./logs/dedup_extras/{HF_DATA}"
23 | os.makedirs(LOCAL_LOGS_FOLDER, exist_ok=True)
24 |
25 | TOTAL_TASKS = 120
26 |
27 | INPUT_READER = HuggingFaceDatasetReader(
28 | dataset=f"HuggingFaceTB/{HF_DATA}", # dataset name
29 | dataset_options={
30 | "split": "train"
31 | },
32 | text_key="completion"
33 | )
34 | # stage 1 computes minhash signatures for each task (each task gets a set of files)
35 | stage1 = SlurmPipelineExecutor(
36 | job_name="mh1",
37 | pipeline=[
38 | INPUT_READER,
39 | MinhashDedupSignature(output_folder=f"{S3_MINHASH_BASE_PATH}/signatures", config=minhash_config),
40 | ],
41 | tasks=TOTAL_TASKS,
42 | time="5:00:00",
43 | partition="hopper-cpu",
44 | logging_dir=f"{S3_LOGS_FOLDER}/signatures",
45 | slurm_logs_folder=f"{LOCAL_LOGS_FOLDER}/signatures/slurm_logs",
46 | qos="high",
47 | )
48 |
49 | # stage 2 finds matches between signatures in each bucket
50 | stage2 = SlurmPipelineExecutor(
51 | job_name="mh2",
52 | pipeline=[
53 | MinhashDedupBuckets(
54 | input_folder=f"{S3_MINHASH_BASE_PATH}/signatures",
55 | output_folder=f"{S3_MINHASH_BASE_PATH}/buckets",
56 | config=minhash_config,
57 | ),
58 | ],
59 | tasks=minhash_config.num_buckets,
60 | time="90:00:00",
61 | partition="hopper-prod",
62 | logging_dir=f"{S3_LOGS_FOLDER}/buckets",
63 | depends=stage1,
64 | slurm_logs_folder=f"{LOCAL_LOGS_FOLDER}/buckets/slurm_logs",
65 | qos="high",
66 | )
67 |
68 | # stage 3 creates clusters of duplicates using the results from all buckets
69 | stage3 = SlurmPipelineExecutor(
70 | job_name="mh3",
71 | pipeline=[
72 | MinhashDedupCluster(
73 | input_folder=f"{S3_MINHASH_BASE_PATH}/buckets",
74 | output_folder=f"{S3_MINHASH_BASE_PATH}/remove_ids",
75 | config=minhash_config,
76 | ),
77 | ],
78 | tasks=1,
79 | time="90:00:00",
80 | partition="hopper-prod",
81 | logging_dir=f"{S3_LOGS_FOLDER}/clusters",
82 | mem_per_cpu_gb=70,
83 | cpus_per_task=2,
84 | depends=stage2,
85 | slurm_logs_folder=f"{LOCAL_LOGS_FOLDER}/clusters/slurm_logs",
86 | )
87 |
88 | # stage 4 reads the original input data and removes all but 1 sample per duplicate cluster
89 | # the data must match exactly stage 1, so number of tasks and the input source must be the same
90 | stage4 = SlurmPipelineExecutor(
91 | job_name="mh4",
92 | pipeline=[
93 | INPUT_READER,
94 | TokensCounter(), # nice way to see how many tokens we had before and after deduplication
95 | MinhashDedupFilter(
96 | input_folder=f"{S3_MINHASH_BASE_PATH}/remove_ids",
97 | exclusion_writer=JsonlWriter(f"{S3_MINHASH_BASE_PATH}/removed"),
98 | ),
99 | JsonlWriter(output_folder=f"{S3_MINHASH_BASE_PATH}/deduplicated_output"), # output_folder="hf_stack"
100 | ],
101 | tasks=TOTAL_TASKS,
102 | time="50:00:00",
103 | partition="hopper-cpu",
104 | logging_dir=f"{S3_LOGS_FOLDER}/filter",
105 | depends=stage3,
106 | slurm_logs_folder=f"{LOCAL_LOGS_FOLDER}/filter/slurm_logs",
107 | )
108 |
109 |
110 | stage4.run()
111 |
--------------------------------------------------------------------------------
/evaluation/README.md:
--------------------------------------------------------------------------------
1 | # Benchmark evaluation
2 |
3 | This is an extended version of the [FineWeb-v1 evaluation script](https://huggingface.co/datasets/HuggingFaceFW/fineweb/blob/main/lighteval_tasks.py).
4 | In particular, we add the MMLU-Pro, TriviaQA, and GSM8K benchmarks.
5 |
6 | To run the script, please install the latest version of the `lighteval` library:
7 | ```bash
8 | git clone https://github.com/huggingface/lighteval.git
9 | cd lighteval
10 | conda create -n lighteval python=3.10 && conda activate lighteval
11 | pip install '.[accelerate,quantization,adapters]'
12 | ```
13 |
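14 | Then, from the directory containing both the `lighteval` checkout and `lighteval_tasks.py`, set `OUTPUT_DIR` to the directory where results should be written and run the evaluation script with the following command: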
15 | ```bash
16 | MODEL="openai-community/gpt2"
17 | accelerate launch --num_processes=1 --main_process_port=29600 "lighteval/run_evals_accelerate.py" --model_args="pretrained=$MODEL" \
18 | --custom_tasks "lighteval_tasks.py" --output_dir $OUTPUT_DIR --override_batch_size 16 \
19 | --tasks "custom|hellaswag|0|1,custom|winogrande|0|1,custom|piqa|0|1,custom|siqa|0|1,custom|openbookqa|0|1,custom|arc:easy|0|1,custom|arc:challenge|0|1,custom|commonsense_qa|0|1,custom|trivia_qa|0|1,custom|mmlu_pro_cloze|0|1,custom|gsm8k|5|1,custom|mmlu_cloze:abstract_algebra|0|1,custom|mmlu_cloze:anatomy|0|1,custom|mmlu_cloze:astronomy|0|1,custom|mmlu_cloze:business_ethics|0|1,custom|mmlu_cloze:clinical_knowledge|0|1,custom|mmlu_cloze:college_biology|0|1,custom|mmlu_cloze:college_chemistry|0|1,custom|mmlu_cloze:college_computer_science|0|1,custom|mmlu_cloze:college_mathematics|0|1,custom|mmlu_cloze:college_medicine|0|1,custom|mmlu_cloze:college_physics|0|1,custom|mmlu_cloze:computer_security|0|1,custom|mmlu_cloze:conceptual_physics|0|1,custom|mmlu_cloze:econometrics|0|1,custom|mmlu_cloze:electrical_engineering|0|1,custom|mmlu_cloze:elementary_mathematics|0|1,custom|mmlu_cloze:formal_logic|0|1,custom|mmlu_cloze:global_facts|0|1,custom|mmlu_cloze:high_school_biology|0|1,custom|mmlu_cloze:high_school_chemistry|0|1,custom|mmlu_cloze:high_school_computer_science|0|1,custom|mmlu_cloze:high_school_european_history|0|1,custom|mmlu_cloze:high_school_geography|0|1,custom|mmlu_cloze:high_school_government_and_politics|0|1,custom|mmlu_cloze:high_school_macroeconomics|0|1,custom|mmlu_cloze:high_school_mathematics|0|1,custom|mmlu_cloze:high_school_microeconomics|0|1,custom|mmlu_cloze:high_school_physics|0|1,custom|mmlu_cloze:high_school_psychology|0|1,custom|mmlu_cloze:high_school_statistics|0|1,custom|mmlu_cloze:high_school_us_history|0|1,custom|mmlu_cloze:high_school_world_history|0|1,custom|mmlu_cloze:human_aging|0|1,custom|mmlu_cloze:human_sexuality|0|1,custom|mmlu_cloze:international_law|0|1,custom|mmlu_cloze:jurisprudence|0|1,custom|mmlu_cloze:logical_fallacies|0|1,custom|mmlu_cloze:machine_learning|0|1,custom|mmlu_cloze:management|0|1,custom|mmlu_cloze:marketing|0|1,custom|mmlu_cloze:medical_genetics|0|1,custom|mmlu_cloze:miscellaneous|0|1,custom|mmlu_cloze:moral_disputes|0|1,custom|mmlu_cloze:moral_scenarios|0|1,custom|mmlu_cloze:nutrition|0|1,custom|mmlu_cloze:philosophy|0|1,custom|mmlu_cloze:prehistory|0|1,custom|mmlu_cloze:professional_accounting|0|1,custom|mmlu_cloze:professional_law|0|1,custom|mmlu_cloze:professional_medicine|0|1,custom|mmlu_cloze:professional_psychology|0|1,custom|mmlu_cloze:public_relations|0|1,custom|mmlu_cloze:security_studies|0|1,custom|mmlu_cloze:sociology|0|1,custom|mmlu_cloze:us_foreign_policy|0|1,custom|mmlu_cloze:virology|0|1,custom|mmlu_cloze:world_religions|0|1"
20 | ```
21 |
--------------------------------------------------------------------------------
/evaluation/eval.slurm:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH --job-name=eval_cosmo
3 | #SBATCH --partition hopper-prod
4 | #SBATCH --nodes=1
5 | #SBATCH --ntasks-per-node=1
6 | #SBATCH --cpus-per-task=48
7 | #SBATCH --mem-per-cpu=20G
8 | #SBATCH --gpus=8
9 | #SBATCH -o %x_%j.out
10 | #SBATCH -e %x_%j.err
11 | #SBATCH --time=1-00:00:00
12 |
13 | set -x -e
14 | source ~/.bashrc
15 | source "/admin/home/anton/miniforge3/etc/profile.d/conda.sh"
16 | source activate cosmolighteval
17 |
18 | export HF_HOME="/fsx/anton/cosmo/cache/"
19 | export HF_DATASETS_CACHE="/fsx/anton/cosmo/cache/"
20 |
21 | MODELS=(
22 | "openai-community/gpt2"
23 | "openai-community/gpt2-medium"
24 | "openai-community/gpt2-xl"
25 | "karpathy/gpt2_1558M_final4_hf"
26 | "EleutherAI/pythia-160m"
27 | "Qwen/Qwen2-0.5B"
28 | "HuggingFaceTB/cosmo2-1.7B-1T"
29 | "HuggingFaceTB/cosmo2-149M-600B-fp32"
30 | "HuggingFaceTB/cosmo2-362M-600B-fp32"
31 | "HuggingFaceTB/cosmo2-1.7B-900B"
32 | "HuggingFaceTB/mixture11-600B"
33 | "HuggingFaceTB/cosmo2-base-magpie-lr-5e-5"
34 | "HuggingFaceTB/cosmo-300B-with-decay-instruct-mixture-5-bis"
35 | "HuggingFaceTB/cosmo2-600B-tokens-base-mixture"
36 | "microsoft/phi-1_5"
37 | "microsoft/phi-2"
38 | "HuggingFaceTB/cosmo-1b"
39 | "HuggingFaceTB/cosmo2-test-classic"
40 | "HuggingFaceFW/ablation-model-fineweb-edu"
41 | "HuggingFaceFW/ablation-model-fineweb-v1"
42 | "Qwen/Qwen1.5-1.8B"
43 | "Qwen/Qwen2-1.5B"
44 | "Qwen/Qwen1.5-0.5B"
45 | "stabilityai/stablelm-2-1_6b"
46 | "allenai/OLMo-1B-hf"
47 | "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
48 | "Qwen/Qwen2-1.5B-Instruct"
49 | "Qwen/Qwen2-0.5B-Instruct"
50 | "HuggingFaceFW/ablation-model-refinedweb"
51 | "HuggingFaceFW/ablation-model-c4"
52 | "HuggingFaceFW/ablation-model-dolma-v1_6"
53 | "HuggingFaceFW/ablation-model-slimpajama"
54 | "HuggingFaceFW/ablation-model-the-pile"
55 | "HuggingFaceFW/ablation-model-redpajama2"
56 | )
57 | OUTPUT_DIR="/fsx/anton/cosmopedia/eval_results_cosmo2/"
58 | OUTPUT_DATASET="HuggingFaceTB/eval_results_cosmo2"
59 |
60 | for model in "${MODELS[@]}"
61 | do
62 | accelerate launch --num_processes=8 --main_process_port=29600 "/admin/home/anton/repos/lighteval/run_evals_accelerate.py" --model_args="pretrained=$model" \
63 | --custom_tasks "lighteval_tasks.py" --output_dir $OUTPUT_DIR --override_batch_size 16 \
64 | --tasks "custom|hellaswag|0|1,custom|winogrande|0|1,custom|piqa|0|1,custom|siqa|0|1,custom|openbookqa|0|1,custom|arc:easy|0|1,custom|arc:challenge|0|1,custom|commonsense_qa|0|1,custom|boolq|0|1,custom|trivia_qa|0|1,custom|trivia_qa|5|1,custom|mmlu_pro_cloze|0|1,custom|mmlu_stem_mc|0|1,custom|mmlu_stem_cloze|0|1,custom|gsm8k|5|1,custom|mmlu_mc:abstract_algebra|0|1,custom|mmlu_mc:anatomy|0|1,custom|mmlu_mc:astronomy|0|1,custom|mmlu_mc:business_ethics|0|1,custom|mmlu_mc:clinical_knowledge|0|1,custom|mmlu_mc:college_biology|0|1,custom|mmlu_mc:college_chemistry|0|1,custom|mmlu_mc:college_computer_science|0|1,custom|mmlu_mc:college_mathematics|0|1,custom|mmlu_mc:college_medicine|0|1,custom|mmlu_mc:college_physics|0|1,custom|mmlu_mc:computer_security|0|1,custom|mmlu_mc:conceptual_physics|0|1,custom|mmlu_mc:econometrics|0|1,custom|mmlu_mc:electrical_engineering|0|1,custom|mmlu_mc:elementary_mathematics|0|1,custom|mmlu_mc:formal_logic|0|1,custom|mmlu_mc:global_facts|0|1,custom|mmlu_mc:high_school_biology|0|1,custom|mmlu_mc:high_school_chemistry|0|1,custom|mmlu_mc:high_school_computer_science|0|1,custom|mmlu_mc:high_school_european_history|0|1,custom|mmlu_mc:high_school_geography|0|1,custom|mmlu_mc:high_school_government_and_politics|0|1,custom|mmlu_mc:high_school_macroeconomics|0|1,custom|mmlu_mc:high_school_mathematics|0|1,custom|mmlu_mc:high_school_microeconomics|0|1,custom|mmlu_mc:high_school_physics|0|1,custom|mmlu_mc:high_school_psychology|0|1,custom|mmlu_mc:high_school_statistics|0|1,custom|mmlu_mc:high_school_us_history|0|1,custom|mmlu_mc:high_school_world_history|0|1,custom|mmlu_mc:human_aging|0|1,custom|mmlu_mc:human_sexuality|0|1,custom|mmlu_mc:international_law|0|1,custom|mmlu_mc:jurisprudence|0|1,custom|mmlu_mc:logical_fallacies|0|1,custom|mmlu_mc:machine_learning|0|1,custom|mmlu_mc:management|0|1,custom|mmlu_mc:marketing|0|1,custom|mmlu_mc:medical_genetics|0|1,custom|mmlu_mc:miscellaneous|0|1,custom|mmlu_mc:moral_disputes|0|1,custom|mmlu_mc:moral_scenarios|0|1,custom|mmlu_mc:nutrition|0|1,custom|mmlu_mc:philosophy|0|1,custom|mmlu_mc:prehistory|0|1,custom|mmlu_mc:professional_accounting|0|1,custom|mmlu_mc:professional_law|0|1,custom|mmlu_mc:professional_medicine|0|1,custom|mmlu_mc:professional_psychology|0|1,custom|mmlu_mc:public_relations|0|1,custom|mmlu_mc:security_studies|0|1,custom|mmlu_mc:sociology|0|1,custom|mmlu_mc:us_foreign_policy|0|1,custom|mmlu_mc:virology|0|1,custom|mmlu_mc:world_religions|0|1,custom|mmlu_cloze:abstract_algebra|0|1,custom|mmlu_cloze:anatomy|0|1,custom|mmlu_cloze:astronomy|0|1,custom|mmlu_cloze:business_ethics|0|1,custom|mmlu_cloze:clinical_knowledge|0|1,custom|mmlu_cloze:college_biology|0|1,custom|mmlu_cloze:college_chemistry|0|1,custom|mmlu_cloze:college_computer_science|0|1,custom|mmlu_cloze:college_mathematics|0|1,custom|mmlu_cloze:college_medicine|0|1,custom|mmlu_cloze:college_physics|0|1,custom|mmlu_cloze:computer_security|0|1,custom|mmlu_cloze:conceptual_physics|0|1,custom|mmlu_cloze:econometrics|0|1,custom|mmlu_cloze:electrical_engineering|0|1,custom|mmlu_cloze:elementary_mathematics|0|1,custom|mmlu_cloze:formal_logic|0|1,custom|mmlu_cloze:global_facts|0|1,custom|mmlu_cloze:high_school_biology|0|1,custom|mmlu_cloze:high_school_chemistry|0|1,custom|mmlu_cloze:high_school_computer_science|0|1,custom|mmlu_cloze:high_school_european_history|0|1,custom|mmlu_cloze:high_school_geography|0|1,custom|mmlu_cloze:high_school_government_and_politics|0|1,custom|mmlu_cloze:high_school_macroeconomics|0|1,custom|mmlu_cloze:high_school_mathematics|0|1,custom|mmlu_cloze:high_school_microeconomics|0|1,custom|mmlu_cloze:high_school_physics|0|1,custom|mmlu_cloze:high_school_psychology|0|1,custom|mmlu_cloze:high_school_statistics|0|1,custom|mmlu_cloze:high_school_us_history|0|1,custom|mmlu_cloze:high_school_world_history|0|1,custom|mmlu_cloze:human_aging|0|1,custom|mmlu_cloze:human_sexuality|0|1,custom|mmlu_cloze:international_law|0|1,custom|mmlu_cloze:jurisprudence|0|1,custom|mmlu_cloze:logical_fallacies|0|1,custom|mmlu_cloze:machine_learning|0|1,custom|mmlu_cloze:management|0|1,custom|mmlu_cloze:marketing|0|1,custom|mmlu_cloze:medical_genetics|0|1,custom|mmlu_cloze:miscellaneous|0|1,custom|mmlu_cloze:moral_disputes|0|1,custom|mmlu_cloze:moral_scenarios|0|1,custom|mmlu_cloze:nutrition|0|1,custom|mmlu_cloze:philosophy|0|1,custom|mmlu_cloze:prehistory|0|1,custom|mmlu_cloze:professional_accounting|0|1,custom|mmlu_cloze:professional_law|0|1,custom|mmlu_cloze:professional_medicine|0|1,custom|mmlu_cloze:professional_psychology|0|1,custom|mmlu_cloze:public_relations|0|1,custom|mmlu_cloze:security_studies|0|1,custom|mmlu_cloze:sociology|0|1,custom|mmlu_cloze:us_foreign_policy|0|1,custom|mmlu_cloze:virology|0|1,custom|mmlu_cloze:world_religions|0|1"
65 | done
66 |
67 | huggingface-cli upload $OUTPUT_DATASET $OUTPUT_DIR / --repo-type dataset --delete="*"
68 | # huggingface-cli upload HuggingFaceTB/eval_results_cosmo2 /fsx/anton/cosmopedia/eval_results_cosmo2/ / --repo-type dataset --delete="*"
--------------------------------------------------------------------------------
/evaluation/lighteval_tasks.py:
--------------------------------------------------------------------------------
1 | # ruff: noqa: F405, F403, F401
2 | """
3 | Custom evaluation tasks for lighteval
4 |
5 | Do note that we ran the evals with `max_samples=1000` to speed up large evals.
6 | Most custom prompt changes were in an attempt to improve signal for small models in general.
7 |
8 | This file mainly defines TASKS_TABLE and TASKS_GROUPS, which are then imported by lighteval.
9 |
10 | Example usage (lighteval_tasks.py is the path to this file):
11 | ===================
12 | accelerate launch --num_processes=1 lighteval/run_evals_accelerate.py --model_args="pretrained=HuggingFaceTB/cosmo-1b" \
13 | --custom_tasks "lighteval_tasks.py" --output_dir [OUTPUTPATH] --max_samples 1000 \
14 | --tasks "custom|hellaswag|0|1,custom|winogrande|0|1,custom|piqa|0|1,custom|siqa|0|1,custom|openbookqa|0|1,custom|arc:easy|0|1,custom|arc:challenge|0|1,custom|commonsense_qa|0|1,custom|mmlu_cloze:abstract_algebra|0|1,custom|mmlu_cloze:anatomy|0|1,custom|mmlu_cloze:astronomy|0|1,custom|mmlu_cloze:business_ethics|0|1,custom|mmlu_cloze:clinical_knowledge|0|1,custom|mmlu_cloze:college_biology|0|1,custom|mmlu_cloze:college_chemistry|0|1,custom|mmlu_cloze:college_computer_science|0|1,custom|mmlu_cloze:college_mathematics|0|1,custom|mmlu_cloze:college_medicine|0|1,custom|mmlu_cloze:college_physics|0|1,custom|mmlu_cloze:computer_security|0|1,custom|mmlu_cloze:conceptual_physics|0|1,custom|mmlu_cloze:econometrics|0|1,custom|mmlu_cloze:electrical_engineering|0|1,custom|mmlu_cloze:elementary_mathematics|0|1,custom|mmlu_cloze:formal_logic|0|1,custom|mmlu_cloze:global_facts|0|1,custom|mmlu_cloze:high_school_biology|0|1,custom|mmlu_cloze:high_school_chemistry|0|1,custom|mmlu_cloze:high_school_computer_science|0|1,custom|mmlu_cloze:high_school_european_history|0|1,custom|mmlu_cloze:high_school_geography|0|1,custom|mmlu_cloze:high_school_government_and_politics|0|1,custom|mmlu_cloze:high_school_macroeconomics|0|1,custom|mmlu_cloze:high_school_mathematics|0|1,custom|mmlu_cloze:high_school_microeconomics|0|1,custom|mmlu_cloze:high_school_physics|0|1,custom|mmlu_cloze:high_school_psychology|0|1,custom|mmlu_cloze:high_school_statistics|0|1,custom|mmlu_cloze:high_school_us_history|0|1,custom|mmlu_cloze:high_school_world_history|0|1,custom|mmlu_cloze:human_aging|0|1,custom|mmlu_cloze:human_sexuality|0|1,custom|mmlu_cloze:international_law|0|1,custom|mmlu_cloze:jurisprudence|0|1,custom|mmlu_cloze:logical_fallacies|0|1,custom|mmlu_cloze:machine_learning|0|1,custom|mmlu_cloze:management|0|1,custom|mmlu_cloze:marketing|0|1,custom|mmlu_cloze:medical_genetics|0|1,custom|mmlu_cloze:miscellaneous|0|1,custom|mmlu_cloze:moral_disputes|0|1,custom|mmlu_cloze:moral_scenarios|0|1,custom|mmlu_cloze:nutrition|0|1,custom|mmlu_cloze:philosophy|0|1,custom|mmlu_cloze:prehistory|0|1,custom|mmlu_cloze:professional_accounting|0|1,custom|mmlu_cloze:professional_law|0|1,custom|mmlu_cloze:professional_medicine|0|1,custom|mmlu_cloze:professional_psychology|0|1,custom|mmlu_cloze:public_relations|0|1,custom|mmlu_cloze:security_studies|0|1,custom|mmlu_cloze:sociology|0|1,custom|mmlu_cloze:us_foreign_policy|0|1,custom|mmlu_cloze:virology|0|1,custom|mmlu_cloze:world_religions|0|1"
15 | ===================
16 |
17 | More info here: https://github.com/huggingface/lighteval?tab=readme-ov-file#evaluate-a-model-on-extended-community-or-custom-tasks
18 | For more info on differences between MMLU implementations: https://huggingface.co/blog/open-llm-leaderboard-mmlu#1001-flavors-of-mmlu
19 | In particular, the default leaderboard MMLU implementation (which uses "A", "B", etc. as answer targets) gives near-random results on small or non-instruction-tuned models.
20 | Instead, we use the full MMLU answer as the target.
21 | """
22 | import re
23 | from typing import List, Tuple
24 |
25 | from lighteval.metrics.metrics import Metrics
26 | from lighteval.tasks.lighteval_task import LightevalTaskConfig
27 | from lighteval.tasks.requests import Doc
28 | from lighteval.tasks.default_prompts import LETTER_INDICES
29 |
30 | _TASKS_STRINGS: List[Tuple[LightevalTaskConfig, str]] = []
31 | _TASKS: List[LightevalTaskConfig] = []
32 |
33 | ## COMMON_SENSE_REASONING_TASKS ##
34 | COMMON_SENSE_REASONING_TASKS = [
35 | LightevalTaskConfig(
36 | name="hellaswag",
37 | prompt_function="hellaswag_prompt",
38 | hf_repo="hellaswag",
39 | hf_subset="default",
40 | metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm_nospace],
41 | ),
42 | LightevalTaskConfig(
43 | name="winogrande",
44 | prompt_function="winogrande",
45 | hf_repo="winogrande",
46 | hf_subset="winogrande_xl",
47 | metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm_nospace],
48 | ),
49 | LightevalTaskConfig(
50 | name="piqa",
51 | prompt_function="piqa_harness",
52 | hf_repo="piqa",
53 | hf_subset="plain_text",
54 | metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm_nospace],
55 | ),
56 | LightevalTaskConfig(
57 | name="siqa",
58 | prompt_function="siqa_prompt",
59 | hf_repo="lighteval/siqa",
60 | hf_subset="default",
61 | hf_avail_splits=["train", "validation"],
62 | metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm_nospace],
63 | ),
64 | LightevalTaskConfig(
65 | name="openbookqa",
66 | prompt_function="openbookqa",
67 | hf_repo="openbookqa",
68 | hf_subset="main",
69 | metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm_nospace],
70 | ),
71 | LightevalTaskConfig(
72 | name="arc:easy",
73 | prompt_function="arc",
74 | hf_repo="ai2_arc",
75 | hf_subset="ARC-Easy",
76 | evaluation_splits=["test"],
77 | generation_size=1,
78 | metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm_nospace],
79 | ),
80 | LightevalTaskConfig(
81 | name="arc:challenge",
82 | prompt_function="arc",
83 | hf_repo="ai2_arc",
84 | hf_subset="ARC-Challenge",
85 | evaluation_splits=["test"],
86 | generation_size=1,
87 | metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm_nospace],
88 | ),
89 | LightevalTaskConfig(
90 | name="commonsense_qa",
91 | prompt_function="commonsense_qa_prompt",
92 | hf_repo="commonsense_qa",
93 | hf_subset="default",
94 | metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm_nospace],
95 | ),
96 | LightevalTaskConfig(
97 | name="mmlu_pro_cloze",
98 | prompt_function="mmlu_pro_cloze_prompt",
99 | hf_repo="TIGER-Lab/MMLU-Pro",
100 | hf_subset="default",
101 | metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm_nospace],
102 | evaluation_splits=["test"],
103 | few_shots_split="validation",
104 | few_shots_select=None,
105 | generation_size=-1,
106 | stop_sequence=None,
107 | output_regex=None,
108 | frozen=False,
109 | ),
110 | LightevalTaskConfig(
111 | name="mmlu_pro_mc",
112 | prompt_function="mmlu_pro_mc_prompt",
113 | hf_repo="TIGER-Lab/MMLU-Pro",
114 | hf_subset="default",
115 | metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm_nospace],
116 | evaluation_splits=["test"],
117 | few_shots_split="validation",
118 | few_shots_select=None,
119 | generation_size=1,
120 | stop_sequence=None,
121 | output_regex=None,
122 | frozen=False,
123 | ),
124 | LightevalTaskConfig(
125 | name="boolq",
126 | prompt_function="boolq_prompt",
127 | hf_repo="super_glue",
128 | hf_subset="boolq",
129 | metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm_nospace],
130 | trust_dataset=True,
131 | stop_sequence=["\n"],
132 | ),
133 | LightevalTaskConfig(
134 | name="trivia_qa",
135 | prompt_function="triviaqa",
136 | hf_repo="mandarjoshi/trivia_qa",
137 | hf_subset="rc.nocontext",
138 | hf_avail_splits=["train", "validation"],
139 | evaluation_splits=["validation"],
140 | metric=[Metrics.quasi_exact_match_triviaqa],
141 | generation_size=20,
142 | trust_dataset=True,
143 | stop_sequence=["\n", ".", ","],
144 | few_shots_select="random_sampling_from_train",
145 | ),
146 | ]
147 |
148 |
149 | def boolq_prompt(line, task_name: str = None):
150 | return Doc(
151 | task_name=task_name,
152 | query=f"{line['passage']}\nQuestion: {line['question'].capitalize()}?\nAnswer:",
153 | choices=[" No", " Yes"], # Only gold
154 | gold_index=int(line["label"]),
155 | )
156 |
157 |
158 | def mmlu_pro_cloze_prompt(line, task_name: str = None):
159 | """MMLU-Pro prompt without letters"""
160 | topic = line["category"]
161 | prompt = f"The following are questions about {topic.replace('_', ' ')}.\nQuestion: "
162 | prompt += line["question"] + "\nAnswer:"
163 |
164 | return Doc(
165 | task_name=task_name,
166 | query=prompt,
167 | choices=[f" {c}" for c in line["options"]],
168 | gold_index=line["answer_index"],
169 | instruction=f"The following are questions about {topic.replace('_', ' ')}.\n",
170 | )
171 |
172 |
173 | def mmlu_pro_mc_prompt(line, task_name: str = None):
174 | topic = line["category"]
175 | query = f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n"
176 | query += line["question"] + "\n"
177 | query += "".join([f"{key}. {choice}\n" for key, choice in zip(LETTER_INDICES, line["options"])])
178 | query += "Answer:"
179 |
180 | return Doc(
181 | task_name=task_name,
182 | query=query,
183 | choices=LETTER_INDICES[: len(line["options"])],
184 | gold_index=line["answer_index"],
185 | instruction=f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n",
186 | target_for_fewshot_sorting=LETTER_INDICES[line["answer_index"]],
187 | )
188 |
189 |
190 | def commonsense_qa_prompt(line, task_name: str = None):
191 | return Doc(
192 | task_name=task_name,
193 | query=line["question"],
194 | choices=[f" {c}" for c in line["choices"]["text"]],
195 | gold_index=LETTER_INDICES.index(line["answerKey"].strip()),
196 | instruction="",
197 | )
198 |
199 |
200 | def siqa_prompt(line, task_name: str = None):
201 | return Doc(
202 | task_name=task_name,
203 | query=line["context"] + " " + line["question"],
204 | choices=[f" {c}" for c in [line["answerA"], line["answerB"], line["answerC"]]],
205 | gold_index=int(line["label"]) - 1,
206 | instruction="",
207 | )
208 |
209 |
210 | def hellaswag_prompt(line, task_name: str = None):
211 | def preprocess(text):
212 | """Adapted from the EleutherAI lm-evaluation-harness."""
213 | # text = text.strip()
214 | # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag.
215 | text = text.replace(" [title]", ". ")
216 | text = re.sub("\\[.*?\\]", "", text)
217 | text = text.replace("  ", " ")  # collapse the double spaces left by the substitutions above
218 | return text
219 |
220 | ctx = f"{line['ctx_a']} {line['ctx_b'].capitalize()} "
221 | return Doc(
222 | task_name=task_name,
223 | query=preprocess(line["activity_label"] + ": " + ctx),
224 | choices=[" " + preprocess(ending) for ending in line["endings"]],
225 | gold_index=int(line["label"]) if line["label"] != "" else -1, # -1 for test
226 | # "metric": "choices_loglikelihood",
227 | )
228 |
229 |
230 | GSM8K = LightevalTaskConfig(
231 | name="gsm8k",
232 | prompt_function="gsm8k",
233 | hf_repo="gsm8k",
234 | hf_subset="main",
235 | hf_avail_splits=["train", "test"],
236 | evaluation_splits=["test"],
237 | metric=[Metrics.quasi_exact_match_gsm8k],
238 | generation_size=256,
239 | stop_sequence=["Question:", "Question"],
240 | few_shots_select="random_sampling_from_train",
241 | )
242 | MATH_TASKS = [
243 | LightevalTaskConfig(
244 | name=f"math:{subset}",
245 | prompt_function="math",
246 | hf_repo="lighteval/MATH",
247 | hf_subset=subset,
248 | hf_avail_splits=["train", "test"],
249 | evaluation_splits=["test"],
250 | metric=[Metrics.quasi_exact_match_math],
251 | generation_size=256,
252 | stop_sequence=["Problem:", "Problem"],
253 | few_shots_select="random_sampling_from_train",
254 | )
255 | for subset in [
256 | "algebra",
257 | "counting_and_probability",
258 | "geometry",
259 | "intermediate_algebra",
260 | "number_theory",
261 | "prealgebra",
262 | "precalculus",
263 | ]
264 | ]
265 |
266 | # 0-shot for common sense reasoning tasks
267 | COMMON_SENSE_REASONING_STRING = [(t, f"custom|{t.name}|0|1") for t in COMMON_SENSE_REASONING_TASKS]
268 | _TASKS_STRINGS.extend(COMMON_SENSE_REASONING_STRING)
269 | _TASKS_STRINGS.extend([(GSM8K, f"custom|{GSM8K.name}|5|1")])
270 | _TASKS_STRINGS.extend([(t, f"custom|{t.name}|4|1") for t in MATH_TASKS])
271 | _TASKS += COMMON_SENSE_REASONING_TASKS
272 | _TASKS += [GSM8K] + MATH_TASKS
273 |
274 | ## MMLU ##
275 | class CustomMMLUEvaluationTask(LightevalTaskConfig):
276 | def __init__(
277 | self,
278 | name,
279 | prompt_function="mmlu_prompt",
280 | hf_repo="lighteval/mmlu",
281 | hf_subset=None,
282 | # metric=[Metrics.loglikelihood_acc_single_token],
283 | metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm_nospace],
284 | hf_avail_splits=None,
285 | evaluation_splits=["test"],
286 | few_shots_split="dev",
287 | few_shots_select=None,
288 | generation_size=-1,
289 | stop_sequence=None,
290 | output_regex=None,
291 | frozen=False,
292 | ):
293 | super().__init__(
294 | name=name,
295 | prompt_function=prompt_function,
296 | hf_repo=hf_repo,
297 | hf_subset=hf_subset,
298 | metric=metric,
299 | hf_avail_splits=hf_avail_splits,
300 | evaluation_splits=evaluation_splits,
301 | few_shots_split=few_shots_split,
302 | few_shots_select=few_shots_select,
303 | generation_size=generation_size,
304 | stop_sequence=stop_sequence,
305 | output_regex=output_regex,
306 | frozen=frozen,
307 | )
308 |
309 | MMLU_TASKS = []
310 | mmlu_subsets = [
311 | "abstract_algebra",
312 | "anatomy",
313 | "astronomy",
314 | "business_ethics",
315 | "clinical_knowledge",
316 | "college_biology",
317 | "college_chemistry",
318 | "college_computer_science",
319 | "college_mathematics",
320 | "college_medicine",
321 | "college_physics",
322 | "computer_security",
323 | "conceptual_physics",
324 | "econometrics",
325 | "electrical_engineering",
326 | "elementary_mathematics",
327 | "formal_logic",
328 | "global_facts",
329 | "high_school_biology",
330 | "high_school_chemistry",
331 | "high_school_computer_science",
332 | "high_school_european_history",
333 | "high_school_geography",
334 | "high_school_government_and_politics",
335 | "high_school_macroeconomics",
336 | "high_school_mathematics",
337 | "high_school_microeconomics",
338 | "high_school_physics",
339 | "high_school_psychology",
340 | "high_school_statistics",
341 | "high_school_us_history",
342 | "high_school_world_history",
343 | "human_aging",
344 | "human_sexuality",
345 | "international_law",
346 | "jurisprudence",
347 | "logical_fallacies",
348 | "machine_learning",
349 | "management",
350 | "marketing",
351 | "medical_genetics",
352 | "miscellaneous",
353 | "moral_disputes",
354 | "moral_scenarios",
355 | "nutrition",
356 | "philosophy",
357 | "prehistory",
358 | "professional_accounting",
359 | "professional_law",
360 | "professional_medicine",
361 | "professional_psychology",
362 | "public_relations",
363 | "security_studies",
364 | "sociology",
365 | "us_foreign_policy",
366 | "virology",
367 | "world_religions",
368 | ]
369 |
370 | for answer_type in ("mc", "cloze"):
371 | prompt_function = f"mmlu_{answer_type}_prompt"
372 | generation_size = -1 if answer_type == "cloze" else 1
373 | for subset in mmlu_subsets:
374 | MMLU_TASKS.append(
375 | CustomMMLUEvaluationTask(
376 | name=f"mmlu_{answer_type}:{subset}",
377 | prompt_function=prompt_function,
378 | hf_subset=subset,
379 | generation_size=generation_size
380 | )
381 | )
382 |
383 | MMLU_TASKS += [
384 | CustomMMLUEvaluationTask(
385 | name="mmlu_stem_mc",
386 | hf_repo="TIGER-Lab/MMLU-STEM",
387 | prompt_function="mmlu_mc_prompt",
388 | hf_subset="default",
389 | generation_size=1
390 | ),
391 | CustomMMLUEvaluationTask(
392 | name="mmlu_stem_cloze",
393 | hf_repo="TIGER-Lab/MMLU-STEM",
394 | prompt_function="mmlu_cloze_prompt",
395 | hf_subset="default",
396 | generation_size=-1
397 | ),
398 | ]
399 |
400 |
401 | def mmlu_cloze_prompt(line, task_name: str = None):
402 | """MMLU prompt without letters"""
403 | topic = line["subject"]
404 | prompt = f"The following are questions about {topic.replace('_', ' ')}.\nQuestion: "
405 | prompt += line["question"] + "\nAnswer:"
406 |
407 | return Doc(
408 | task_name=task_name,
409 | query=prompt,
410 | choices=[f" {c}" for c in line["choices"]],
411 | gold_index=line["answer"],
412 | instruction=f"The following are questions about {topic.replace('_', ' ')}.\n",
413 | )
414 |
415 |
416 | def mmlu_mc_prompt(line, task_name: str = None):
417 | topic = line["subject"]
418 | query = f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n"
419 | query += line["question"] + "\n"
420 | query += "".join([f"{key}. {choice}\n" for key, choice in zip(LETTER_INDICES, line["choices"])])
421 | query += "Answer:"
422 |
423 | gold_ix = LETTER_INDICES.index(line["answer"]) if isinstance(line["answer"], str) else line["answer"]
424 |
425 | return Doc(
426 | task_name=task_name,
427 | query=query,
428 | choices=[" A", " B", " C", " D"],
429 | gold_index=gold_ix,
430 | instruction=f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n",
431 | target_for_fewshot_sorting=[" A", " B", " C", " D"][gold_ix],
432 | )
433 |
434 |
435 | MMLU_STRING = [(t, f"custom|{t.name}|0|1") for t in MMLU_TASKS]
436 | _TASKS_STRINGS.extend(MMLU_STRING)
437 | _TASKS += MMLU_TASKS
438 |
439 | # common sense reasoning + mmlu
440 | EARLY_SIGNAL_TASKS = ",".join([t[1] for t in COMMON_SENSE_REASONING_STRING] + [t[1] for t in MMLU_STRING])
441 |
442 | # Task list that lighteval loads from this module
443 | TASKS_TABLE = _TASKS
444 | # You can have a few pre-organised groups of tasks
445 | TASKS_GROUPS = {
446 | "early-signal": EARLY_SIGNAL_TASKS,
447 | "math": f"custom|{GSM8K.name}|5|1" + "," + ",".join([f"custom|{t.name}|4|1" for t in MATH_TASKS]),
448 | }
449 |
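The two MMLU variants registered above differ only in how a question is rendered and scored: in the cloze variant the candidate continuations are the answer texts themselves, while the MC variant lists lettered options and the candidates are ` A` to ` D`. The sketch below reproduces both prompt builders outside lighteval so the formats can be inspected directly. `SimpleDoc` is a hypothetical stand-in for lighteval's `Doc`, and the four-letter `LETTER_INDICES` list and the sample question are assumptions made for illustration.

```python
from dataclasses import dataclass
from typing import List

# Assumption: 4-way MMLU questions, so only the first four letters are needed.
LETTER_INDICES = ["A", "B", "C", "D"]


@dataclass
class SimpleDoc:
    """Hypothetical stand-in for lighteval's Doc, keeping only the fields used here."""
    query: str
    choices: List[str]
    gold_index: int


def cloze_prompt(line: dict) -> SimpleDoc:
    # Cloze format: no letters; each candidate is the full answer text.
    topic = line["subject"].replace("_", " ")
    query = f"The following are questions about {topic}.\nQuestion: {line['question']}\nAnswer:"
    return SimpleDoc(query=query, choices=[f" {c}" for c in line["choices"]], gold_index=line["answer"])


def mc_prompt(line: dict) -> SimpleDoc:
    # MC format: options are listed with letters and the candidates are " A".." D".
    topic = line["subject"].replace("_", " ")
    query = f"The following are multiple choice questions (with answers) about {topic}.\n\n"
    query += line["question"] + "\n"
    query += "".join(f"{key}. {choice}\n" for key, choice in zip(LETTER_INDICES, line["choices"]))
    query += "Answer:"
    answer = line["answer"]
    # The gold answer may be stored as a letter or as an integer index.
    gold = LETTER_INDICES.index(answer) if isinstance(answer, str) else answer
    return SimpleDoc(query=query, choices=[f" {k}" for k in LETTER_INDICES], gold_index=gold)


line = {
    "subject": "high_school_physics",
    "question": "What is the SI unit of force?",
    "choices": ["joule", "newton", "watt", "pascal"],
    "answer": 1,
}
print(mc_prompt(line).query)
```

This also explains the `generation_size` split above: the MC variant only needs a single letter token, while the cloze variant scores answers of arbitrary length.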
--------------------------------------------------------------------------------
/fulltext_search/README.md:
--------------------------------------------------------------------------------
1 | # Fulltext search with BISAC topics
2 |
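3 | This folder shows how we used Manticore for full-text search over CommonCrawl pages: `index_docs.py` indexes two FineWeb dumps (CC-MAIN-2024-10 and CC-MAIN-2023-50) into 64 sharded tables behind one distributed table, and `search_sharded.py` retrieves pages for each BISAC topic.
4 |
5 | Because of the corpus size and the specifics of the Hugging Face cluster, these scripts are provided as a reference only.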
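A minimal sketch of one search round trip against the distributed `fineweb` table, assuming a Manticore server listening on port 9308 as configured in `manticore.conf`. The request fields (`index`, `size`, `query`, `max_matches`) are the ones `search_sharded.py` sends; the helper names are hypothetical, and the stubbed response used to exercise `extract_hits` is hand-written rather than real server output.

```python
import json
from typing import Any, Dict, List


def build_search_body(index: str, text: str, size: int, max_matches: int = 4000) -> str:
    """Serialize a JSON body for the HTTP /search endpoint (fields mirror search_sharded.py)."""
    return json.dumps(
        {
            "index": index,
            "size": size,
            "query": {"match": {"content": text}},
            "max_matches": max_matches,
        }
    )


def extract_hits(response_json: Dict[str, Any]) -> List[Dict[str, Any]]:
    # search_sharded.py reads matching documents from hits.hits.
    return response_json["hits"]["hits"]


body = build_search_body("fineweb", "photosynthesis in plants", size=10)
# With the server running, the body would be sent as in search_sharded.py:
# requests.post("http://127.0.0.1:9308/search", data=body, timeout=1000)
```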
--------------------------------------------------------------------------------
/fulltext_search/index_docs.py:
--------------------------------------------------------------------------------
1 | import json
2 | import time
3 | import sys
4 | import random
5 |
6 | import requests
7 | from datasets import load_dataset
8 |
9 |
10 | def insert_batch(batch):
11 | ndjson = ""
12 |
13 | index_name = f"fineweb{random.randint(0, 63)}"
14 |
15 | for text, _id, url, language_score, token_count in zip(
16 | batch["text"],
17 | batch["id"],
18 | batch["url"],
19 | batch["language_score"],
20 | batch["token_count"],
21 | ):
22 | doc = {
23 | "insert": {
24 | "index": index_name,
25 | "_id": _id.split(":")[-1].strip(">"),
26 | "doc": {
27 | "content": text,
28 | "fw_id": _id.split(":")[-1].strip(">"),
29 | "url": url,
30 | "language_score": language_score,
31 | "token_count": token_count,
32 | },
33 | }
34 | }
35 | ndjson += json.dumps(doc) + "\n"
36 |
37 | response = None
38 | while response is None:
39 | try:
40 | response = requests.post(
41 | "http://127.0.0.1:9308/bulk",
42 | headers={"Content-Type": "application/x-ndjson"},
43 | data=ndjson,
44 | )
45 | except requests.exceptions.ConnectionError as e:
46 | print(e, file=sys.stderr)
47 | time.sleep(1)
48 | pass
49 |
50 | return {"response": [response.status_code]}
51 |
52 |
53 | def main():
54 | sql_url = "http://127.0.0.1:9308/sql?mode=raw"
55 |
56 | print("Removing table", file=sys.stderr)
57 | while True:
58 | try:
59 | requests.post(sql_url, data={"query": "drop table if exists fineweb"})
60 | break
61 | except requests.exceptions.ConnectionError as e:
62 | print(e, file=sys.stderr)
63 | time.sleep(5)
64 | pass
65 |
66 | print("Creating table", file=sys.stderr)
67 | for i in range(64):
68 | response = requests.post(
69 | sql_url, data={"query": f"drop table if exists fineweb{i}"}
70 | )
71 | print(response.text, file=sys.stderr)
72 | local_query = f"create table fineweb{i}(content text, fw_id string, url string, language_score float, token_count int) charset_table='non_cjk' stopwords='en' morphology='stem_en'"
73 | response = requests.post(sql_url, data={"query": local_query})
74 | print(response.text, file=sys.stderr)
75 |
76 | distributed_query = "create table fineweb type='distributed'"
77 | for i in range(64):
78 | distributed_query += f" local='fineweb{i}'"
79 | response = requests.post(sql_url, data={"query": distributed_query})
80 | print(response.text, file=sys.stderr)
81 |
82 | for dump in ["CC-MAIN-2024-10", "CC-MAIN-2023-50"]:
83 | print("Loading dataset", file=sys.stderr)
84 | dataset = load_dataset(
85 | "HuggingFaceFW/fineweb",
86 | dump,
87 | split="train",
88 | num_proc=64,
89 | cache_dir="/scratch/cosmo/.cache",
90 | )
91 | dataset = dataset.select_columns(
92 | ["text", "id", "url", "language_score", "token_count"]
93 | )
94 | dataset = dataset.map(
95 | insert_batch,
96 | batched=True,
97 | batch_size=10000,
98 | remove_columns=["text", "id", "url", "language_score", "token_count"],
99 | num_proc=64,
100 | )
101 | for _ in dataset:
102 | pass
103 |
104 | time.sleep(30)
105 | for i in range(64):
106 | print(f"Optimizing table fineweb{i}", file=sys.stderr)
107 | response = requests.post(
108 | sql_url,
109 | data={"query": f"FLUSH TABLE fineweb{i}"},
110 | timeout=600,
111 | )
112 | print(response.text, file=sys.stderr)
113 | response = requests.post(
114 | sql_url,
115 | data={"query": f"OPTIMIZE TABLE fineweb{i} OPTION cutoff=16, sync=1"},
116 | timeout=600,
117 | )
118 | print(response.text, file=sys.stderr)
119 | response = requests.post(
120 | sql_url,
121 | data={"query": f"FREEZE fineweb{i}"},
122 | timeout=600,
123 | )
124 | print(response.text, file=sys.stderr)
125 |
126 | response = requests.post(
127 | "http://127.0.0.1:9308/search",
128 | data='{"index":"fineweb","query":{"match":{"*":"hello world"}}}',
129 | )
130 | print(response.text, file=sys.stderr)
131 |
132 | # print("Backing up the index", file=sys.stderr)
133 | # time.sleep(30)
134 | # response = requests.post(
135 | # sql_url,
136 | # data={"query": "BACKUP TO /tmp/backups"},
137 | # )
138 | # print(response.text, file=sys.stderr)
139 |
140 |
141 | if __name__ == "__main__":
142 | main()
143 |
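`insert_batch` sends each batch of 10,000 documents to one of 64 local tables chosen at random, and the distributed `fineweb` table then fans queries out to all of them. The pure functions below rebuild the NDJSON bulk body and the `create table ... type='distributed'` statement so their shape can be checked without a running server. The function names are hypothetical, the field layout is copied from `index_docs.py`, and only three of the five document fields are kept for brevity.

```python
import json
import random
from typing import Dict, List, Optional


def fineweb_doc_id(raw_id: str) -> str:
    # FineWeb ids look like "<urn:uuid:...>"; keep the part after the last ":" and drop ">".
    return raw_id.split(":")[-1].strip(">")


def build_bulk_ndjson(batch: Dict[str, List[str]], num_shards: int = 64,
                      rng: Optional[random.Random] = None) -> str:
    """Return one NDJSON "insert" line per document, all targeting one random local table."""
    if rng is None:
        rng = random.Random()
    index_name = f"fineweb{rng.randint(0, num_shards - 1)}"
    lines = []
    for text, raw_id, url in zip(batch["text"], batch["id"], batch["url"]):
        doc_id = fineweb_doc_id(raw_id)
        doc = {"insert": {"index": index_name, "_id": doc_id,
                          "doc": {"content": text, "fw_id": doc_id, "url": url}}}
        lines.append(json.dumps(doc))
    # As in insert_batch, every line (including the last) ends with a newline.
    return "\n".join(lines) + "\n"


def distributed_table_sql(name: str = "fineweb", num_shards: int = 64) -> str:
    # One local='...' entry per shard, in the same order index_docs.py uses.
    locals_ = " ".join(f"local='{name}{i}'" for i in range(num_shards))
    return f"create table {name} type='distributed' {locals_}"
```

Picking the table per batch rather than per document keeps each bulk request on a single table while still spreading the corpus roughly evenly across the 64 shards.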
--------------------------------------------------------------------------------
/fulltext_search/index_docs.slurm:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH --job-name=index_fineweb
3 | #SBATCH --partition hopper-prod
4 | #SBATCH --nodes=1
5 | #SBATCH --ntasks-per-node=1
6 | #SBATCH --cpus-per-task=96
7 | #SBATCH --mem-per-cpu=20G
8 | #SBATCH -o %x_%j.out
9 | #SBATCH -e %x_%j.err
10 | #SBATCH --time=7-00:00:00
11 |
12 | set -x -e
13 | source ~/.bashrc
14 | source "$CONDA_PREFIX/etc/profile.d/conda.sh"
15 | source activate pyspark
16 |
17 | ulimit -n 99999
18 |
19 | mkdir -p /scratch/cosmo/manticore_idx
20 | rm -rf /scratch/cosmo/manticore_idx/*
21 | srun --container-image='manticoresearch/manticore:6.2.12' \
22 | --container-env=EXTRA=1 \
23 | --container-mounts="/scratch/cosmo/manticore_idx:/var/lib/manticore:z,$(pwd)/manticore.conf:/etc/manticoresearch/manticore.conf" \
24 | --no-container-mount-home \
25 | --qos high \
26 | /bin/bash -c 'mkdir -p /var/run/manticore && chown manticore:manticore /var/run/manticore && mkdir -p /var/run/mysqld && chown manticore:manticore /var/run/mysqld && export EXTRA=1 && source /entrypoint.sh && docker_setup_env && /entrypoint.sh searchd -c /etc/manticoresearch/manticore.conf --nodetach' &
27 |
28 | python index_docs.py
29 |
30 | sleep 1000
31 |
32 | rclone copy -P --transfers 32 /scratch/cosmo/manticore_idx/ s3:cosmopedia-data/manticore_idx/CC-MAIN-2024-10-2023-50/
33 |
34 | sleep 1000000000
--------------------------------------------------------------------------------
/fulltext_search/manticore.conf:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | ip=`hostname -i|rev|cut -d\ -f 1|rev`
3 | cat << EOF
4 | searchd {
5 |
6 | # https://manual.manticoresearch.com/Server_settings/Searchd#access_plain_attrs
7 | # access_plain_attrs = mmap_preread
8 |
9 | # https://manual.manticoresearch.com/Server_settings/Searchd#access_blob_attrs
10 | # access_blob_attrs = mmap_preread
11 |
12 | # https://manual.manticoresearch.com/Server_settings/Searchd#access_doclists
13 | # access_doclists = file
14 |
15 | # https://manual.manticoresearch.com/Server_settings/Searchd#access_hitlists
16 | # access_hitlists = file
17 |
18 | # https://manual.manticoresearch.com/Server_settings/Searchd#agent_connect_timeout
19 | # agent_connect_timeout =
20 |
21 | # https://manual.manticoresearch.com/Server_settings/Searchd#agent_query_timeout
22 | # agent_query_timeout =
23 |
24 | # https://manual.manticoresearch.com/Server_settings/Searchd#agent_retry_count
25 | # agent_retry_count = 0
26 |
27 | # https://manual.manticoresearch.com/Server_settings/Searchd#agent_retry_delay
28 | # agent_retry_delay = 500
29 |
30 | # https://manual.manticoresearch.com/Server_settings/Searchd#attr_flush_period
31 | # attr_flush_period = 0
32 |
33 | # https://manual.manticoresearch.com/Server_settings/Searchd#binlog_flush
34 | # binlog_flush = 2
35 |
36 | # https://manual.manticoresearch.com/Server_settings/Searchd#binlog_max_log_size
37 | # binlog_max_log_size = 268435456
38 |
39 | # https://manual.manticoresearch.com/Server_settings/Searchd#binlog_path
40 | # binlog_path =
41 |
42 | # https://manual.manticoresearch.com/Server_settings/Searchd#client_timeout
43 | # client_timeout = 300
44 |
45 | # https://manual.manticoresearch.com/Server_settings/Searchd#collation_libc_locale
46 | # collation_libc_locale = C
47 |
48 | # https://manual.manticoresearch.com/Server_settings/Searchd#collation_server
49 | # collation_server = libc_ci
50 |
51 | # https://manual.manticoresearch.com/Server_settings/Searchd#data_dir
52 | data_dir = /var/lib/manticore
53 |
54 | # https://manual.manticoresearch.com/Server_settings/Searchd#docstore_cache_size
55 | # docstore_cache_size = 16m
56 |
57 | # https://manual.manticoresearch.com/Server_settings/Searchd#expansion_limit
58 | # expansion_limit = 0
59 |
60 | # https://manual.manticoresearch.com/Server_settings/Searchd#grouping_in_utc
61 | # grouping_in_utc = 0
62 |
63 | # https://manual.manticoresearch.com/Server_settings/Searchd#ha_period_karma
64 | # ha_period_karma = 60
65 |
66 | # https://manual.manticoresearch.com/Server_settings/Searchd#ha_ping_interval
67 | # ha_ping_interval = 1000
68 |
69 | # https://manual.manticoresearch.com/Server_settings/Searchd#hostname_lookup
70 | # hostname_lookup =
71 |
72 | # https://manual.manticoresearch.com/Server_settings/Searchd#jobs_queue_size
73 | # jobs_queue_size =
74 |
75 | # https://manual.manticoresearch.com/Server_settings/Searchd#listen_backlog
76 | # listen_backlog = 5
77 |
78 | # https://manual.manticoresearch.com/Server_settings/Searchd#listen
79 | # listen_env = this directive lets you append listeners from environment variables
80 |
81 | listen = 9306:mysql41
82 | listen = /var/run/mysqld/mysqld.sock:mysql41
83 | listen = $ip:9312
84 | listen = 9308:http
85 | listen = $ip:9315-9325:replication
86 |
87 | # https://manual.manticoresearch.com/Server_settings/Searchd#listen_tfo
88 | # listen_tfo = 0
89 |
90 | # https://manual.manticoresearch.com/Server_settings/Searchd#log
91 | log = /var/log/manticore/searchd.log
92 |
93 | # https://manual.manticoresearch.com/Server_settings/Searchd#max_batch_queries
94 | # max_batch_queries = 32
95 |
96 | # https://manual.manticoresearch.com/Server_settings/Searchd#threads
97 | threads = 64
98 |
99 | # https://manual.manticoresearch.com/Server_settings/Searchd#max_filters
100 | # max_filters = 256
101 |
102 | # https://manual.manticoresearch.com/Server_settings/Searchd#max_filter_values
103 | # max_filter_values = 4096
104 |
105 | # https://manual.manticoresearch.com/Server_settings/Searchd#max_open_files
106 | max_open_files = 65535
107 |
108 | optimize_cutoff = 2
109 |
110 | # https://manual.manticoresearch.com/Server_settings/Searchd#max_packet_size
111 | max_packet_size = 128M
112 |
113 | # https://manual.manticoresearch.com/Server_settings/Searchd#mysql_version_string
114 | # mysql_version_string =
115 |
116 | # https://manual.manticoresearch.com/Server_settings/Searchd#net_workers
117 | # net_workers = 1
118 |
119 | # https://manual.manticoresearch.com/Server_settings/Searchd#net_wait_tm
120 | # net_wait_tm = -1
121 |
122 | # https://manual.manticoresearch.com/Server_settings/Searchd#net_throttle_accept
123 | # net_throttle_accept = 0
124 |
125 | # https://manual.manticoresearch.com/Server_settings/Searchd#net_throttle_action
126 | # net_throttle_action = 0
127 |
128 | # https://manual.manticoresearch.com/Server_settings/Searchd#node_address
129 | # node_address =
130 |
131 | # https://manual.manticoresearch.com/Server_settings/Searchd#ondisk_attrs_default
132 | # ondisk_attrs_default = 0
133 |
134 | # https://manual.manticoresearch.com/Server_settings/Searchd#persistent_connections_limit
135 | # persistent_connections_limit =
136 |
137 | # https://manual.manticoresearch.com/Server_settings/Searchd#pid_file
138 | pid_file = /var/run/manticore/searchd.pid
139 |
140 | # https://manual.manticoresearch.com/Server_settings/Searchd#predicted_time_costs
141 | # predicted_time_costs = doc=64, hit=48, skip=2048, match=64
142 |
143 | # https://manual.manticoresearch.com/Server_settings/Searchd#preopen_indexes
144 | # preopen_indexes = 1
145 |
146 | # https://manual.manticoresearch.com/Server_settings/Searchd#qcache_max_bytes
147 | # qcache_max_bytes = 16Mb
148 |
149 | # https://manual.manticoresearch.com/Server_settings/Searchd#qcache_thresh_msec
150 | # qcache_thresh_msec = 3000
151 |
152 | # https://manual.manticoresearch.com/Server_settings/Searchd#qcache_ttl_sec
153 | # qcache_ttl_sec = 60
154 |
155 | # https://manual.manticoresearch.com/Server_settings/Searchd#query_log_format
156 | query_log_format = sphinxql
157 |
158 | # https://manual.manticoresearch.com/Server_settings/Searchd#query_log_min_msec
159 | # query_log_min_msec = 0
160 |
161 | # https://manual.manticoresearch.com/Server_settings/Searchd#query_log
162 | # query_log = /var/log/manticore/query.log
163 |
164 | # https://manual.manticoresearch.com/Server_settings/Searchd#query_log_mode
165 | # query_log_mode = 600
166 |
167 | # https://manual.manticoresearch.com/Server_settings/Searchd#max_connections
168 | # max_connections =
169 |
170 | # https://manual.manticoresearch.com/Server_settings/Searchd#network_timeout
171 | # network_timeout = 5
172 |
173 | # https://manual.manticoresearch.com/Server_settings/Searchd#read_buffer
174 | # read_buffer = 256K
175 |
176 | # https://manual.manticoresearch.com/Server_settings/Searchd#read_buffer_docs
177 | # read_buffer_docs = 256K
178 |
179 | # https://manual.manticoresearch.com/Server_settings/Searchd#read_buffer_hits
180 | # read_buffer_hits = 256K
181 |
182 | # https://manual.manticoresearch.com/Server_settings/Searchd#read_unhinted
183 | # read_unhinted = 32K
184 |
185 | # https://manual.manticoresearch.com/Server_settings/Searchd#rt_flush_period
186 | # rt_flush_period =
187 |
188 | # https://manual.manticoresearch.com/Server_settings/Searchd#rt_merge_iops
189 | # rt_merge_iops = 0
190 |
191 | # https://manual.manticoresearch.com/Server_settings/Searchd#rt_merge_maxiosize
192 | # rt_merge_maxiosize = 0
193 |
194 | # https://manual.manticoresearch.com/Server_settings/Searchd#seamless_rotate
195 | # seamless_rotate = 1
196 |
197 | # https://manual.manticoresearch.com/Server_settings/Searchd#server_id
198 | # server_id =
199 |
200 | # https://manual.manticoresearch.com/Server_settings/Searchd#shutdown_timeout
201 | # shutdown_timeout = 3
202 |
203 | # https://manual.manticoresearch.com/Server_settings/Searchd#shutdown_token
204 | # shutdown_token =
205 |
206 | # https://manual.manticoresearch.com/Server_settings/Searchd#snippets_file_prefix
207 | # snippets_file_prefix =
208 |
209 | # https://manual.manticoresearch.com/Server_settings/Searchd#sphinxql_state
210 | # sphinxql_state =
211 |
212 | # https://manual.manticoresearch.com/Server_settings/Searchd#sphinxql_timeout
213 | # sphinxql_timeout = 900
214 |
215 | # https://manual.manticoresearch.com/Server_settings/Searchd#ssl_ca
216 | # ssl_ca =
217 |
218 | # https://manual.manticoresearch.com/Server_settings/Searchd#ssl_cert
219 | # ssl_cert =
220 |
221 | # https://manual.manticoresearch.com/Server_settings/Searchd#ssl_key
222 | # ssl_key =
223 |
224 | # https://manual.manticoresearch.com/Server_settings/Searchd#subtree_docs_cache
225 | # subtree_docs_cache = 0
226 |
227 | # https://manual.manticoresearch.com/Server_settings/Searchd#subtree_hits_cache
228 | # subtree_hits_cache = 0
229 |
230 | # https://manual.manticoresearch.com/Server_settings/Searchd#thread_stack
231 | # thread_stack =
232 |
233 | # https://manual.manticoresearch.com/Server_settings/Searchd#unlink_old
234 | # unlink_old = 1
235 |
236 | # https://manual.manticoresearch.com/Server_settings/Searchd#watchdog
237 | # watchdog = 1
238 |
239 | # https://manual.manticoresearch.com/Server_settings/Searchd#secondary_indexes
240 | secondary_indexes = 0
241 | }
242 |
243 | common {
244 |
245 | # https://manual.manticoresearch.com/Server_settings/Common#lemmatizer_base
246 | # lemmatizer_base = /usr/local/share
247 |
248 | # https://manual.manticoresearch.com/Server_settings/Common#progressive_merge
249 | # progressive_merge =
250 |
251 | # https://manual.manticoresearch.com/Server_settings/Common#json_autoconv_keynames
252 | # json_autoconv_keynames =
253 |
254 | # https://manual.manticoresearch.com/Server_settings/Common#json_autoconv_numbers
255 | # json_autoconv_numbers = 0
256 |
257 | # https://manual.manticoresearch.com/Server_settings/Common#on_json_attr_error
258 | # on_json_attr_error = ignore_attr
259 |
260 | # https://manual.manticoresearch.com/Server_settings/Common#plugin_dir
261 | # plugin_dir =
262 |
263 | }
264 |
265 | EOF
--------------------------------------------------------------------------------
/fulltext_search/search_sharded.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import sys
4 | import time
5 |
6 | import requests
7 | from datasets import load_dataset
8 |
9 |
10 | def get_args():
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument(
13 | "--input_dataset", type=str, default="HuggingFaceTB/bisac_expanded_final"
14 | )
15 | parser.add_argument("--n_pages", type=int, default=2000)
16 | parser.add_argument(
17 | "--output_dataset",
18 | type=str,
19 | default="HuggingFaceTB/bisac_boosted_new_index_2000",
20 | )
21 | parser.add_argument("--shard", type=int, required=True)
22 | parser.add_argument("--num_shards", type=int, required=True)
23 | return parser.parse_args()
24 |
25 |
26 | # wait until the server is up
27 | while True:
28 | try:
29 | requests.post(
30 | "http://127.0.0.1:9308/search",
31 | data='{"index": "fineweb", "query": {"match": {"content": "ping"}}}',
32 | )
33 | break
34 | except requests.exceptions.ConnectionError:
35 | time.sleep(10)
36 | pass
37 |
38 |
39 | args = get_args()
40 | data = load_dataset(
41 | args.input_dataset, split="train", cache_dir="/scratch/cosmo/.cache"
42 | )
43 | data = data.filter(lambda x, i: i % args.num_shards == args.shard, with_indices=True)
44 | data = data.select_columns(["top_category", "subcategory", "subtopic"])
45 |
46 |
47 | def run_query(query, n_pages):
48 | while True:
49 | try:
50 | max_pages = 4_000
51 | response = requests.post(
52 | "http://127.0.0.1:9308/search",
53 | data=json.dumps(
54 | {
55 | "index": "fineweb",
56 | "size": n_pages,
57 | "query": query,
58 | "max_matches": max_pages,
59 | }
60 | ),
61 | timeout=1000,
62 | )
63 | if response.status_code != 200:
64 | print(response.text, file=sys.stderr)
65 | time.sleep(5)
66 | continue
67 | else:
68 | hits = response.json()["hits"]["hits"]
69 | return hits
70 | except requests.exceptions.ConnectionError as e:
71 | print(e, file=sys.stderr)
72 | time.sleep(5)
73 | continue
74 |
75 |
76 | def search_topic(sample):
77 | top_category = sample["top_category"][0].strip()
78 | subcategory = sample["subcategory"][0].strip()
79 | subtopic = sample["subtopic"][0].strip()
80 | for c in ["!", '"', "$", "'", "(", ")", "/", "<", "@", "\\", "^", "|", "~"]:
81 | top_category = top_category.replace(c, " ")
82 | subcategory = subcategory.replace(c, " ")
83 | subtopic = subtopic.replace(c, " ")
84 | # boosting the IDF score of subtopic tokens
85 | boosted_subtopic = " ".join([w + "^2" for w in subtopic.split()])
86 | match_query = " ".join([top_category, subcategory, subtopic])
87 | boosted_query = " ".join([top_category, subcategory, boosted_subtopic])
88 |
89 | boosted_hits = run_query({"query_string": boosted_query}, args.n_pages)
90 | print(f"Boosted hits: {len(boosted_hits)} for {boosted_query}", file=sys.stderr)
91 | if len(boosted_hits) < args.n_pages:
92 | match_hits = run_query(
93 | {"match": {"content": match_query}}, args.n_pages + len(boosted_hits)
94 | )
95 | print(f"Match hits: {len(match_hits)} for {match_query}", file=sys.stderr)
96 | else:
97 | match_hits = []
98 |
99 | hit_ids = set()
100 | hits = []
101 | for hit in boosted_hits + match_hits:
102 | if hit["_id"] not in hit_ids:
103 | hits.append(hit)
104 | hit_ids.add(hit["_id"])
105 | hits = hits[: args.n_pages]
106 |
107 | results = {
108 | "top_category": sample["top_category"] * len(hits),
109 | "subcategory": sample["subcategory"] * len(hits),
110 | "subtopic": sample["subtopic"] * len(hits),
111 | "topic_hits": hits,
112 | "num_hits": [len(hits)] * len(hits),
113 | }
114 | return results
115 |
116 |
117 | data = data.map(search_topic, batched=True, batch_size=1, num_proc=2)
118 | data.push_to_hub(
119 | f"{args.output_dataset}_{args.shard}", private=True, max_shard_size="4096MB"
120 | )
121 |
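`search_topic` builds two queries per BISAC topic: a `query_string` query in which every subtopic word carries a `^2` boost, and, when that returns fewer than `n_pages` hits, a plain `match` query on `content`. The results are merged with boosted hits first and duplicates dropped by `_id`. The sketch below isolates that logic as pure functions (names hypothetical) so it can be exercised without a server; the sample topic is made up.

```python
from typing import Dict, List, Tuple

# Characters replaced by spaces before querying (copied from search_topic).
SPECIAL_CHARS = ["!", '"', "$", "'", "(", ")", "/", "<", "@", "\\", "^", "|", "~"]


def sanitize(text: str) -> str:
    text = text.strip()
    for c in SPECIAL_CHARS:
        text = text.replace(c, " ")
    return text


def build_queries(top_category: str, subcategory: str, subtopic: str) -> Tuple[str, str]:
    """Return (match_query, boosted_query) for one BISAC topic."""
    top_category, subcategory, subtopic = (sanitize(t) for t in (top_category, subcategory, subtopic))
    # "^" was stripped above, so the only boosts are the ones added here.
    boosted_subtopic = " ".join(w + "^2" for w in subtopic.split())
    match_query = " ".join([top_category, subcategory, subtopic])
    boosted_query = " ".join([top_category, subcategory, boosted_subtopic])
    return match_query, boosted_query


def merge_hits(boosted: List[Dict], fallback: List[Dict], n_pages: int) -> List[Dict]:
    """Boosted hits first, then fallback hits, skipping repeated _id values; cap at n_pages."""
    seen = set()
    merged = []
    for hit in boosted + fallback:
        if hit["_id"] not in seen:
            merged.append(hit)
            seen.add(hit["_id"])
    return merged[:n_pages]


match_q, boosted_q = build_queries("Science", "Physics", "quantum optics")
print(boosted_q)  # Science Physics quantum^2 optics^2
```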
--------------------------------------------------------------------------------
/fulltext_search/search_sharded.slurm:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH --job-name=cosmo_search_sharded
3 | #SBATCH --partition hopper-prod
4 | #SBATCH --nodes=1
5 | #SBATCH --ntasks-per-node=1
6 | #SBATCH --cpus-per-task=96
7 | #SBATCH --mem-per-cpu=20G
8 | #SBATCH -o %x_%j.out
9 | #SBATCH -e %x_%j.err
10 | #SBATCH --array=0-15%8
11 |
12 | set -x -e
13 | source ~/.bashrc
14 | source "$CONDA_PREFIX/etc/profile.d/conda.sh"
15 | source activate pyspark
16 |
17 | ulimit -n 99999
18 |
19 | mkdir -p /scratch/cosmo/manticore_idx
20 | rm -rf /scratch/cosmo/manticore_idx/*
21 | rclone copy -P --transfers 32 s3:cosmopedia-data/manticore_idx/CC-MAIN-2024-10-2023-50/ /scratch/cosmo/manticore_idx/
22 |
23 | srun --container-image='manticoresearch/manticore:6.2.12' \
24 | --container-mounts="/scratch/cosmo/manticore_idx:/var/lib/manticore:z,$(pwd)/manticore.conf:/etc/manticoresearch/manticore.conf" \
25 | --no-container-mount-home \
26 | --qos high \
27 | /bin/bash -c 'mkdir -p /var/run/manticore && chown manticore:manticore /var/run/manticore && mkdir -p /var/run/mysqld && chown manticore:manticore /var/run/mysqld && /entrypoint.sh searchd -c /etc/manticoresearch/manticore.conf --nodetach' &
28 |
29 | python search_sharded.py --shard ${SLURM_ARRAY_TASK_ID} --num_shards 16
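`--array=0-15%8` starts 16 array tasks (IDs 0 to 15, at most 8 running at a time); each restores the index from S3, starts its own Manticore server, and passes its `SLURM_ARRAY_TASK_ID` as `--shard`. Inside `search_sharded.py`, row `i` of the topic dataset is kept only by shard `i % num_shards`, so the 16 tasks split the topics without overlap. A small sketch of that rule, using a hypothetical `shard_indices` helper and a made-up row count:

```python
from typing import List


def shard_indices(num_rows: int, shard: int, num_shards: int) -> List[int]:
    # Same rule as the dataset filter in search_sharded.py: row i goes to shard i % num_shards.
    return [i for i in range(num_rows) if i % num_shards == shard]


num_rows, num_shards = 100, 16
shards = [shard_indices(num_rows, s, num_shards) for s in range(num_shards)]

# Every row lands in exactly one shard.
covered = sorted(i for s in shards for i in s)
print(covered == list(range(num_rows)))
```

Round-robin assignment keeps shard sizes within one row of each other, which matters here because every task pays the same fixed cost of downloading the index.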
--------------------------------------------------------------------------------
/generation/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Synthetic data generation
3 |
4 | If you have a large dataset of prompts and want to generate content with an open-source LLM such as [Mixtral-8x7B-Instruct-v0.1](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1), you can use `llm-swarm`, which spins up TGI or vLLM instances on `slurm` clusters. We used it to generate Cosmopedia, which consists of 25B tokens; the full generation took more than 10k H100 GPU hours.
5 |
6 | You can find the instructions for running the generation in `llm-swarm` here: https://github.com/huggingface/llm-swarm/tree/loubna/examples/textbooks
7 |
8 | The generation script is also available in this folder (it is meant to be run from `examples/textbooks` in the library).
9 |
10 | ```bash
11 | # after following all the installation guidelines in llm-swarm and installing wandb
12 | # 100k subset
13 | python generate_synthetic_textbooks.py \
14 | --model mistralai/Mixtral-8x7B-Instruct-v0.1 \
15 | --instances 2 \
16 | --prompts_dataset "HuggingFaceTB/cosmopedia-100k" \
17 | --prompt_column prompt \
18 | --max_samples 2000 \
19 | --checkpoint_path "./synthetic_data" \
20 | --checkpoint_interval 1000
21 | ```
22 |
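As a rough sanity check on the figures above, the implied average generation throughput can be computed directly. This treats 25B tokens and 10k H100 GPU hours as exact; since the full run took more than 10k GPU hours, the true per-GPU average is at most this value.

```python
# Back-of-the-envelope throughput implied by the README's figures.
total_tokens = 25e9        # size of Cosmopedia
gpu_hours = 10_000         # lower bound on H100 GPU hours
tokens_per_gpu_second = total_tokens / (gpu_hours * 3600)
print(f"{tokens_per_gpu_second:.0f} tokens/s per GPU")
```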
--------------------------------------------------------------------------------
/generation/boilerplate_cleanup.py:
--------------------------------------------------------------------------------
1 | import re
2 | import argparse
3 | from datasets import load_dataset
4 |
5 |
6 | patterns = [
7 | # alien stories
8 | r"^Hello.*?[.!]\s+",
9 | #r"^I'm( so)? excited to.*?[.!]\s+",
10 | r"^My name is.*?[.!]\s+",
11 | r"^You've just arrived.*?[.!]\s+",
12 |
13 | # wikihow
14 | r"^\*\*Welcome, .*?[.!]\*\*\s+",
15 | r"^(\*\*)?Warning:.*?[.!]\s+",
16 | r"^We're thrilled.*?[.!]\s+",
17 |
18 | # middle school
19 | r"^Welcome, .*?[.!]\s+",
20 | ]
21 | patterns = [re.compile(p, flags=re.IGNORECASE|re.MULTILINE) for p in patterns]
22 |
23 | def clean_text(sample):
24 | sample['completion_unfiltered'] = sample['completion']
25 | for pattern in patterns:
26 | sample['completion'] = pattern.sub('', sample['completion'].strip())
27 | return sample
28 |
29 |
30 | if __name__ == "__main__":
31 | parser = argparse.ArgumentParser()
32 | parser.add_argument("--dataset", type=str, default="HuggingFaceTB/alien_stories_0_1M_llama3")
33 | args = parser.parse_args()
34 |
35 | data = load_dataset(args.dataset, split="train", cache_dir="/scratch/cosmo/cache", num_proc=32)
36 | data = data.map(clean_text, num_proc=32)
37 | data.push_to_hub(args.dataset, private=True)
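The cleanup patterns are compiled with `re.MULTILINE`, so `^` matches at the start of every line rather than only at the start of the completion: a greeting such as "Hello ...!" is removed wherever a line begins with it. A minimal sketch using two of the patterns above (the sample texts are made up) shows both the intended case and the mid-text case:

```python
import re

# Two of the patterns from boilerplate_cleanup.py, compiled with the same flags.
patterns = [
    re.compile(p, flags=re.IGNORECASE | re.MULTILINE)
    for p in [r"^Hello.*?[.!]\s+", r"^Welcome, .*?[.!]\s+"]
]


def clean(text: str) -> str:
    # Same loop as clean_text: strip, then apply each pattern in turn.
    for pattern in patterns:
        text = pattern.sub("", text.strip())
    return text


print(clean("Hello there, young explorers! Today we learn about stars."))
print(clean("Intro line.\nhello again! More text."))
```

The first call drops the opening greeting; the second shows that, because of `re.MULTILINE` and `re.IGNORECASE`, a line starting with "hello" in the middle of a completion is also trimmed.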
--------------------------------------------------------------------------------
/generation/llm_swarm_script.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import multiprocessing
3 | import os
4 | import time
5 | from dataclasses import asdict, dataclass
6 |
7 | from datasets import Dataset, load_dataset
8 | from huggingface_hub import AsyncInferenceClient
9 | from llm_swarm import LLMSwarm, LLMSwarmConfig
10 | from tqdm.asyncio import tqdm_asyncio
11 | from transformers import AutoTokenizer, HfArgumentParser
12 |
13 | import wandb
14 |
15 | HF_TOKEN = os.environ.get("HF_TOKEN", None)
16 |
17 |
18 | @dataclass
19 | class Args:
20 | # generation parameters
21 | max_new_tokens: int = 2500
22 | """Max new tokens"""
23 | temperature: float = 0.6
24 | """Generation temperature"""
25 | top_p: float = 0.95
26 | """Generation top_p"""
27 | top_k: int = 50
28 | """Generation top_k"""
29 | repetition_penalty: float = 1.2
30 | """Generation repetition_penalty"""
31 | # prompts dataset parameters
32 | prompts_dataset: str = "HuggingFaceTB/cosmopedia-100k"
33 | """Dataset containing the prompts"""
34 | max_samples: int = 5000
35 | """The maximum number of samples to generate (use -1 for all)"""
36 | start_sample: int = -1
37 | """First sample to process"""
38 | end_sample: int = -1
39 | """Last sample to process"""
40 | seed: int = 42
41 | """Seed for shuffling"""
42 | prompt_column: str = "prompt"
43 | """Name of the column containing the prompt"""
44 | shuffle_dataset: bool = False
45 | """Whether to shuffle the prompts"""
46 | debug: bool = False
47 | """Debugging mode"""
48 | # logging parameters
49 | repo_id: str = "HuggingFaceTB/synthetic_data_test"
50 | """The repo id to push to"""
51 | checkpoint_path: str = "./synthetic_data"
52 | """Path for saving intermediate generations"""
53 | checkpoint_interval: int = 1_000
54 | """Interval for saving intermediate generations"""
55 | wandb_username: str = "loubnabnl"
56 | """Wandb username"""
57 | min_token_length: int = 150
58 | """Minimum number of tokens in a generation to be kept in the final dataset"""
59 | push_to_hub: bool = True
60 | """Whether to push to hub"""
61 |
62 |
63 | parser = HfArgumentParser((Args, LLMSwarmConfig))
64 | args, isc = parser.parse_args_into_dataclasses()
65 | # args used in wandb
66 | args_dict = asdict(args)
67 | args_dict.update(
68 | {
69 | "per_instance_max_parallel_requests": isc.per_instance_max_parallel_requests,
70 | "instances": isc.instances,
71 | "inference_engine": isc.inference_engine,
72 | "model": isc.model,
73 | }
74 | )
75 | print(args_dict)
76 |
77 | tokenizer = AutoTokenizer.from_pretrained(isc.model)
78 |
79 | num_proc = 1 if args.debug else multiprocessing.cpu_count()
80 | ds = load_dataset(
81 | args.prompts_dataset, token=HF_TOKEN, split="train", num_proc=num_proc
82 | )
83 |
84 | if args.shuffle_dataset:
85 | ds = ds.shuffle(seed=args.seed)
86 |
87 | if args.start_sample >= 0:
88 | end_sample = len(ds) if args.end_sample < 0 else args.end_sample
89 | print(f"Loading a defined range of samples: ({args.start_sample}, {end_sample})...")
90 | ds = ds.select(range(args.start_sample, end_sample))
91 | elif args.max_samples > 0:
92 | print(f"Loading the first {args.max_samples} samples...")
93 | ds = ds.select(range(args.max_samples))
94 |
95 |
96 | with LLMSwarm(isc) as llm_swarm:
97 | semaphore = asyncio.Semaphore(llm_swarm.suggested_max_parallel_requests)
98 | client = AsyncInferenceClient(model=llm_swarm.endpoint)
99 | STOP_SEQ = ["<|endoftext|>"]
100 |
101 | MAX_RETRIES = 6 # maximum number of retries
102 | RETRY_DELAY = 4 # delay in seconds between retries
103 |
104 | async def process_text(sample):
105 | token_length = 0
106 | attempt = 0
107 | while attempt < MAX_RETRIES:
108 | try:
109 | async with semaphore:
110 | completion = await client.text_generation(
111 | prompt=tokenizer.apply_chat_template(
112 | [{"role": "user", "content": sample[args.prompt_column]}],
113 | tokenize=False,
114 | ),
115 | max_new_tokens=args.max_new_tokens,
116 | stop_sequences=STOP_SEQ,
117 | temperature=args.temperature,
118 | top_p=args.top_p,
119 | top_k=args.top_k,
120 | repetition_penalty=args.repetition_penalty,
121 | )
122 | for stop_seq in STOP_SEQ:
123 | if completion.endswith(stop_seq):
124 | completion = completion[: -len(stop_seq)].rstrip()
125 | token_length += len(tokenizer.encode(completion))
126 | sample["completion"] = completion
127 | sample["token_length"] = token_length
128 | return sample
129 |
130 | except Exception as e:
131 | attempt += 1
132 | if attempt < MAX_RETRIES:
133 | print(
134 | f"Request failed, retrying in {RETRY_DELAY} seconds... (Attempt {attempt}/{MAX_RETRIES})"
135 | )
136 | await asyncio.sleep(RETRY_DELAY)
137 | else:
138 | print(
139 | f"Max retries reached. Failed to process the request with error {str(e)}."
140 | )
141 | sample["completion"] = ""
142 | sample["token_length"] = 0
143 | return sample
144 |
145 | async def main():
146 | start_time = time.time()
147 | total_tokens = 0
148 | saving_time = 0
149 |
150 | repo_id = (
151 | f"{args.repo_id}_{args.prompt_column}"
152 | if args.prompt_column not in args.repo_id
153 | else args.repo_id
154 | )
155 | wandb.init(
156 | project="synthetic_data",
157 | entity=args.wandb_username,
158 | name=repo_id.split("/")[1],
159 | )
160 | wandb.config.update(args_dict)
161 |
162 | repo_id = (
163 | f"{args.repo_id}_{args.prompt_column}"
164 | if args.prompt_column not in args.repo_id
165 | else args.repo_id
166 | )
167 | checkpoint_dir = f"{args.checkpoint_path}/{repo_id.split('/')[1]}/data"
168 | os.makedirs(checkpoint_dir, exist_ok=True)
169 | print(f"Will be saving at {checkpoint_dir}")
170 |
171 | total_samples = len(ds)
172 | for i in range(0, total_samples, args.checkpoint_interval):
173 | batch_time = time.time()
174 | # Processing a chunk
175 | print(
176 | f"Processing chunk {int(i/args.checkpoint_interval)}/{int(total_samples/args.checkpoint_interval)}"
177 | )
178 | end_index = min(i + args.checkpoint_interval, total_samples)
179 | chunk = ds.select(range(i, end_index))
180 | chunk_results = await tqdm_asyncio.gather(
181 | *(process_text(sample) for sample in chunk)
182 | )
183 | # Save the chunk results and log throughput
184 | temp_time = time.time()
185 | time_per_chunk = temp_time - batch_time
186 | checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_{i}.json")
187 | intermediate_ds = Dataset.from_list(chunk_results)
188 | intermediate_ds.to_json(checkpoint_path)
189 | batch_tokens = sum(intermediate_ds["token_length"])
190 | total_tokens += batch_tokens
191 | saving_time += time.time() - temp_time
192 | print(f"💾 Checkpoint (samples {i}-{end_index}) saved at {checkpoint_path}.")
193 | wandb.log(
194 | {
195 | "sample": end_index,
196 | "batch": int(i / args.checkpoint_interval),
197 | "total_tokens (M)": total_tokens / 1e6,
198 | "tokens_per_batch": batch_tokens,
199 | "time_per_batch (s)": time_per_chunk,
200 | "generated_tokens_per_sec": int(batch_tokens / time_per_chunk),
201 | "generated_tokens_per_sec_per_node": int(
202 | batch_tokens / (time_per_chunk * isc.instances)
203 | ),
204 | }
205 | )
206 |
207 | end_time = time.time()
208 |
209 | print(
210 | "Done processing and saving all chunks 🎉! Let's get some stats and push to hub..."
211 | )
212 | total_duration = end_time - start_time
213 | overall_tokens_per_second = (
214 | total_tokens / total_duration if total_duration > 0 else 0
215 | )
216 | print(
217 | f"🏎️💨 Overall Tokens per Second: {overall_tokens_per_second:.2f}, per instance: {overall_tokens_per_second/isc.instances:.2f}"
218 | )
219 | print(f"Generated {total_tokens / 1e6:.2f}M tokens")
220 | print(
221 | f"Total duration: {int(total_duration // 3600)}h{int((total_duration % 3600) // 60)}min"
222 | )
223 | print(f"Saving time: {saving_time:.1f}s = {saving_time / 60:.1f}min")
224 |
225 | # load dataset
226 | print("Load checkpoints...")
227 | output_ds = load_dataset(checkpoint_dir, split="train")
228 | # remove empty completions
229 | final_data = output_ds.filter(
230 | lambda x: x["token_length"] >= args.min_token_length
231 | )
232 | print(final_data)
233 | failed = output_ds.filter(lambda x: x["token_length"] < args.min_token_length)
234 | print(failed)
235 | if args.push_to_hub:
236 | print(f"📨 Pushing dataset to {repo_id}")
237 | final_data.push_to_hub(repo_id, private=True)
238 | print("Dataset pushed!")
239 | if len(failed) > 0:
240 | print(f"{len(failed)} generations failed")
241 | size = min(len(failed), 1000)
242 | failed = failed.select(range(size))
243 | failed.push_to_hub(f"{repo_id}_failed", private=True)
244 |
245 | asyncio.run(main())
246 | wandb.finish()
--------------------------------------------------------------------------------
/plots/clusters_map.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/cosmopedia/653acad4f6146ce1043e2ca8792c671947e83ce0/plots/clusters_map.png
--------------------------------------------------------------------------------
/plots/cover.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/cosmopedia/653acad4f6146ce1043e2ca8792c671947e83ce0/plots/cover.png
--------------------------------------------------------------------------------
/plots/cover_01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/cosmopedia/653acad4f6146ce1043e2ca8792c671947e83ce0/plots/cover_01.png
--------------------------------------------------------------------------------
/plots/educational_score.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/cosmopedia/653acad4f6146ce1043e2ca8792c671947e83ce0/plots/educational_score.png
--------------------------------------------------------------------------------
/plots/topics_distpng.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/cosmopedia/653acad4f6146ce1043e2ca8792c671947e83ce0/plots/topics_distpng.png
--------------------------------------------------------------------------------
/prompts/README.md:
--------------------------------------------------------------------------------
1 | # Building the prompts
2 |
3 | Here you can find the code for building the prompts for each `seed_data` in [Cosmopedia](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia).
4 | In Cosmopedia we standardized the column names and merged all the data sources together, which isn't the case for some scripts here.
5 |
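As a rough illustration of what "standardized the column names and merged all the data sources" means, the sketch below maps each source's prompt column onto a single `prompt` column and tags every row with its origin in `seed_data`. The source names, the column mapping and the rows are hypothetical placeholders, not the actual Cosmopedia schema.

```python
from typing import Dict, List

# Hypothetical rows from two prompt-building scripts that name their prompt
# column differently (placeholders, not real Cosmopedia data).
SOURCES: Dict[str, List[dict]] = {
    "auto_math_text": [{"prompt_college": "Write an educational piece ..."}],
    "khanacademy": [{"prompt": "Write a long and very detailed course unit ..."}],
}

# Assumed mapping: which column holds the prompt in each source.
PROMPT_COLUMN = {"auto_math_text": "prompt_college", "khanacademy": "prompt"}


def merge_sources(sources: Dict[str, List[dict]], prompt_column: Dict[str, str]) -> List[dict]:
    """Rename every source's prompt column to "prompt" and tag rows with their seed_data."""
    merged = []
    for seed_data, rows in sources.items():
        column = prompt_column[seed_data]
        for row in rows:
            merged.append({"prompt": row[column], "seed_data": seed_data})
    return merged


if __name__ == "__main__":
    for row in merge_sources(SOURCES, PROMPT_COLUMN):
        print(row["seed_data"], "->", row["prompt"][:30])
```

With the `datasets` library, the same idea would typically be expressed with `Dataset.rename_column` on each source followed by `concatenate_datasets`.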
--------------------------------------------------------------------------------
/prompts/auto_math_text/README.md:
--------------------------------------------------------------------------------
1 | ## Synthetic generations from AutoMathText
2 |
3 | To build prompts from the web subset of AutoMathText for two audiences, college students and grade-school students, run:
4 |
5 | ```
6 | python ./build_science_prompts.py --run_all_styles
7 | # for a specific generation style/audience (college or grade_school):
8 | python ./build_science_prompts.py --generation_style college
9 | ```
10 |
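A minimal, dependency-free sketch of what `build_prompt` does for each row: strip the web text, keep at most `EXTRACT_SIZE` characters, and substitute the snippet into the chosen style's template, yielding one `prompt_<style>` column per style. The templates below are shortened, and the `<EXTRACT>` placeholder name is an assumption.

```python
# Shortened templates; "<EXTRACT>" is an assumed placeholder name.
EXTRACT_SIZE = 1000

TEMPLATES = {
    "college": 'Write an educational piece suited for college students '
               'related to the following text snippet:\n"<EXTRACT>"\n',
    "grade_school": 'Here\'s an extract from a webpage:\n"<EXTRACT>"\n',
}


def build_prompt(text: str, style: str = "college") -> dict:
    """Return {"prompt_<style>": prompt} for one web snippet."""
    snippet = text.strip()[:EXTRACT_SIZE]  # slicing already caps the length
    prompt = TEMPLATES[style].replace("<EXTRACT>", snippet)
    return {f"prompt_{style}": prompt}


def build_all_styles(text: str) -> dict:
    """Mimic --run_all_styles: one prompt column per style."""
    row = {}
    for style in TEMPLATES:
        row.update(build_prompt(text, style))
    return row


if __name__ == "__main__":
    row = build_all_styles("  The derivative of x^2 is 2x.  ")
    print(sorted(row))  # ['prompt_college', 'prompt_grade_school']
```

Note that `snippet[:EXTRACT_SIZE]` is equivalent to `snippet[:min(len(snippet), EXTRACT_SIZE)]`, because Python slicing never runs past the end of a string.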
--------------------------------------------------------------------------------
/prompts/auto_math_text/build_science_prompts.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from datasets import load_dataset
3 |
4 |
5 | STYLES = {"college":
6 | """Write an educational piece suited for college students related to the following text snippet:
7 | "<EXTRACT>"
8 |
9 | Do not just list concepts, but develop each one in detail before moving to the next, as we prioritize depth of understanding and comprehensive exploration of the subject matter over breadth. Focus on:
10 |
11 | - Rigor: Ensure in-depth coverage of the concepts/sections.
12 | - Engagement: Write with an academic, professional and engaging tone that captivates interest.
13 | - Application: Incorporate specific, practical examples, such as proofs in calculus or critical dates and figures in history.
14 | Do not include a title or an introduction, simply write the content without headlines and introductory phrases. Do not use images.""",
15 |
16 | "grade_school":
17 | """Here's an extract from a webpage:
18 | "<EXTRACT>"
19 |
20 | Create an educational piece related to the snippet above targeted at grade-school students. Complex college-like topics such Electromagnetism and Integration shouldn't be used, as they aren't usually taught at grade-school. If that's what the snippet is about, look for a much simpler scientific alternative to explain, and use everyday examples. For instance, if the topic is 'Linear Algebra' you might discuss how arranging objects in rows and columns can help solve puzzles.
21 | Avoid technical terms and LaTeX and only discuss simple grade-school level topics. Start the educational piece right away."""}
22 |
23 | EXTRACT_SIZE = 1000
24 |
25 |
26 | def get_args():
27 | parser = argparse.ArgumentParser()
28 | parser.add_argument("--repo_id", type=str, default="HuggingFaceTB/auto_math")
29 | parser.add_argument("--generation_style", type=str, default="college", choices=list(STYLES.keys()))
30 | parser.add_argument("--run_all_styles", action="store_true")
31 | return parser.parse_args()
32 |
33 |
34 | def build_prompt(x, style="college"):
35 | """Build the prompt based on the generation type"""
36 | snippet = x["text"].strip()
37 | snippet = snippet[:min(len(snippet), EXTRACT_SIZE)]
38 | prompt = STYLES[style].replace("<EXTRACT>", snippet)
39 | return {f"prompt_{style}": prompt}
40 |
41 |
42 | if __name__ == "__main__":
43 | args = get_args()
44 |
45 | print(f"Loading AutoMathText web data...")
46 | ds = load_dataset("math-ai/AutoMathText", "web-0.50-to-1.00")["train"]
47 | if args.run_all_styles:
48 | suffix = ""
49 | for style in STYLES.keys():
50 | print(f"📖 Building prompts for the {style} style...")
51 | ds = ds.map(build_prompt, num_proc=48, fn_kwargs={"style": style})
52 | else:
53 | suffix = f"_{args.generation_style}"
54 | print(f"📖 Building prompts for the {args.generation_style} style...")
55 | ds = ds.map(build_prompt, num_proc=48, fn_kwargs={"style": args.generation_style})
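56 | print(ds)
57 | # Preview one prompt per built style; only those prompt columns exist
58 | for style in (STYLES.keys() if args.run_all_styles else [args.generation_style]):
59 | print(ds[0][f"prompt_{style}"])
60 | print("-"*100)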
61 | ds.push_to_hub(f"{args.repo_id}{suffix}", private=True)
62 | print(f"✅ Data available at {args.repo_id}{suffix}!")
--------------------------------------------------------------------------------
/prompts/khanacademy/README.md:
--------------------------------------------------------------------------------
1 | # KhanAcademy
2 | ## Code adapted from https://github.com/rand-net/khan-dl
3 |
4 | ## Run script to download list of courses
5 | ```bash
6 | # install requirements
7 | pip install -r khan_dl/requirements.txt
8 | # run downloader with all courses
9 | python khan_dl/main.py -a
10 | ```
11 |
12 | The output will be saved to `khan_courses.json`.
13 |
14 | You can then use `generate_textbooks.py` to build the textbook generation prompts.
15 | [TODO]: add code for updated prompts for Cosmopedia
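The prompts built by `generate_textbooks.py` depend on where each lesson sits in the nested `khan_courses.json` structure (course, then unit, then chapter, then lessons). The sketch below walks that nesting and collects, for every lesson, the chapters and lessons that precede it, which is the context the textbook prompt mentions. The sample JSON is made up; only its nesting follows the dictionaries returned by `KhanDL.download_course_given`.

```python
import json

# Made-up stand-in for khan_courses.json; the nesting mirrors the
# "subunits" -> "subunits" -> "lessons" layout produced by khan_dl.
SAMPLE = json.loads("""
[{"title": "Algebra_1",
  "subunits": [{"title": "Unit 1: Linear equations",
                "subunits": [{"title": "Intro", "lessons": ["What is x?", "Solving for x"]},
                             {"title": "Practice", "lessons": ["Word problems"]}]}]}]
""")


def iter_lessons(courses):
    """Yield one record per lesson, with the chapters/lessons that precede it."""
    for course in courses:
        for unit in course["subunits"]:
            chapters = unit["subunits"]
            for ci, chapter in enumerate(chapters):
                previous_chapters = [c["title"] for c in chapters[:ci]]
                for li, lesson in enumerate(chapter["lessons"]):
                    yield {
                        "course": course["title"].replace("_", " "),
                        "unit": unit["title"],
                        "chapter": chapter["title"],
                        "lesson": lesson,
                        "previous_chapters": previous_chapters,
                        "previous_lessons": chapter["lessons"][:li],
                    }


if __name__ == "__main__":
    for record in iter_lessons(SAMPLE):
        print(record["lesson"], "| after:", record["previous_lessons"])
```

Each record carries everything needed to fill a textbook template, so prompt formatting can stay separate from the traversal.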
--------------------------------------------------------------------------------
/prompts/khanacademy/generate_textbooks.py:
--------------------------------------------------------------------------------
1 | import json
2 | from string import Template
3 | import pandas as pd
4 |
5 | TEMPLATE = Template("""Write a long and very detailed course unit for a textbook on "${unit_title}".
6 | ${previous_sections}\"${section}\".
7 | ${previous_sub_units}
8 | Write the new sub-unit titled \"${unit}\" while trying to be:
9 | - Rigorous - you create challenging textbooks that cover the material in depth.
10 | - Engaging - your textbooks have a narrative arc and engaging tone, like the writing of Michael Lewis.
11 | - Applied - you use specific and practical examples. For example, if the topic is integration in calculus, include equations and proofs of the concept you're teaching. As another example, if the topic is the history of the United States, include dates, names, and key events.
12 | Model:""")
13 |
14 |
15 | # file created by khan_dl
16 | with open("khan_courses.json") as f:
17 | data = json.load(f)
18 |
19 | textbooks = []
20 | total_courses = 0
21 | total_sections = 0
22 | for course in data:
23 | total_courses += 1
24 | units = course["subunits"]
25 | for ui, unit in enumerate(units):
26 | for sui, subunit in enumerate(unit["subunits"]):
27 | total_sections += 1
28 | for li, lesson_title in enumerate(subunit["lessons"]):
29 | # previous sections
30 | previous_subunits = [f"\"{sii + 1}. {s['title']}\"" for sii, s in enumerate(unit["subunits"][:sui])]
31 | previous_subunits_text = f"We have already covered chapter(s) {', '.join(previous_subunits)} and are now writing a chapter on " if previous_subunits else "We are currently writing the first chapter: "
32 | # previous lessons
33 | previous_lessons = [f"\"{prev_title}\"" for prev_title in
34 | subunit["lessons"][:li]]
35 | previous_lessons_text = f"We have already covered the following lessons in the current chapter: {', '.join(previous_lessons)}." if previous_lessons else "You will be writing the first lesson for this chapter."
36 |
37 | section_name = f"{unit['title']} - {subunit['title']}"
38 | # WIP
39 | sample = {
40 | "unit_title": course["title"].replace("_", " ") + " - " + unit['title'][
41 | unit['title'].index(":") + 1:].strip(),
42 | "section": section_name,
43 | "unit": lesson_title,
44 | }
45 | sample["prompt"] = TEMPLATE.substitute(previous_sections=previous_subunits_text,
46 | previous_sub_units=previous_lessons_text, **sample)
47 | textbooks.append(sample)
48 |
49 | pd.DataFrame(textbooks).to_csv("khanacademy_prompts.csv")
50 |
--------------------------------------------------------------------------------
/prompts/khanacademy/khan_dl/khan_dl.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/rand-net/khan-dl
2 |
3 | import logging
4 | import os
5 | import platform
6 | import sys
7 | from typing import List, Tuple
8 |
9 | import requests
10 | from bs4 import BeautifulSoup
11 | from prompt_toolkit import prompt
12 | from prompt_toolkit.completion import FuzzyWordCompleter
13 |
14 | VIDEO_SITE_URL = "https://www.youtube.com/watch?v="
15 | ROOT_URL = "https://www.khanacademy.org"
16 | DOMAINS = [
17 | "math",
18 | "science",
19 | "computing",
20 | "humanities",
21 | "economics-finance-domain",
22 | "ela",
23 | ]
24 |
25 | # Tags and attributes for parsing HTML
26 |
27 | COURSE_HEAD = {"tag": "h2", "class": "_t2uf76"}
28 | COURSE_URL = {"tag": "a", "class": "_dwmetq"}
29 | COURSE_TITLE = {"data-test-id": "course-unit-title"}
30 | COURSE_UNIT_TITLE = {"data-test-id": "unit-header"}
31 | COURSE_SUBUNIT_TITLE_ATTRS = {"data-test-id": "lesson-card-link"}
32 | COURSE_SUBUNIT_BODY = {"tag": "ul", "class": "_37mhyh"}
33 | COURSE_LESSON_BODY = {"tag": "div", "class_i": "_10ct3cvu", "class_ii": "_1p9458yw"}
34 | COURSE_LESSON_SPAN = {"tag": "span", "class": "_e296pg"}
35 | COURSE_LESSON_LABEL = "aria-label"
36 | COURSE_LESSON_TITLE = {"tag": "span", "class": "_14hvi6g8"}
37 |
38 | """
39 |
40 | Domain -> Course -> Unit Page -> Subunit Header + Subunit Block -> Lesson Block -> Lesson Title
41 |
42 | """
43 |
44 |
45 | def clear_screen():
46 | if platform.system() == "Linux" or platform.system() == "Darwin":
47 | os.system("clear")
48 | elif platform.system() == "Windows":
49 | os.system("cls")
50 |
51 |
52 | # Youtube-dl NoLogger
53 | class MyLogger(object):
54 | def debug(self, msg):
55 | pass
56 |
57 | def warning(self, msg):
58 | pass
59 |
60 | def error(self, msg):
61 | pass
62 |
63 |
64 | class KhanDL:
65 | def __init__(self):
66 | self.domain = ""
67 | self.course_url = ""
68 | self.course_title = ""
69 | self.course_page = ""
70 | self.course_unit_titles = []
71 | self.course_unit_slugs = []
72 | self.course_unit_urls = []
73 | self.course_all_slugs = []
74 | self.lesson_titles = []
75 | self.lesson_youtube_ids = []
76 | self.output_rel_path = os.getcwd() + "/"
77 | self.unit_ids_counter = {}
78 | self.unit_slugs_counter = {}
79 | self.nested_courses = []
80 | self.course_subunits = []
81 | self.selected_course = ""
82 |
83 | def get_courses(self, selected_domain_url: str) -> Tuple[List[str], List[str]]:
84 | """Returns the list of courses on a domain"""
85 |
86 | courses, courses_url = [], []
87 | print("\nDownloading Courses...\n")
88 | try:
89 | selected_domain_page = BeautifulSoup(
90 | requests.get(selected_domain_url).text, "lxml"
91 | )
92 | except requests.ConnectionError as e:
93 | print("Error Connecting!\n", e)
94 | sys.exit(1)
95 | except requests.exceptions.HTTPError as errh:
96 | print("Http Error:", errh)
97 | sys.exit(1)
98 | except requests.exceptions.ConnectionError as errc:
99 | print("Error Connecting:", errc)
100 | sys.exit(1)
101 | except requests.exceptions.Timeout as errt:
102 | print("Timeout Error:", errt)
103 | sys.exit(1)
104 | except requests.exceptions.RequestException as err:
105 | print("Oops: Something Else", err)
106 | sys.exit(1)
107 |
108 | for course_header in selected_domain_page.find_all(
109 | COURSE_HEAD["tag"], class_=COURSE_HEAD["class"]
110 | ):
111 | course = course_header.find(
112 | COURSE_URL["tag"], class_=COURSE_URL["class"]
113 | ).text
114 | courses.append(course)
115 |
116 | course_link = course_header.find(
117 | COURSE_URL["tag"], class_=COURSE_URL["class"]
118 | )
119 | course_slug = course_link["href"]
120 | courses_url.append(ROOT_URL + course_slug)
121 | return courses, courses_url
122 |
123 | def domain_prompt(self):
124 | """Returns the selected domain"""
125 |
126 | # Domain selection prompt
127 | domain_completer = FuzzyWordCompleter(
128 | list(map(str.title, DOMAINS))
129 | ) # Titlecase for aesthetics
130 | selected_domain = DOMAINS.index(
131 | prompt("Domain: ", completer=domain_completer).lower()
132 | )
133 |
134 | print("Selected Domain: {}".format(DOMAINS[selected_domain]))
135 | self.domain = DOMAINS[selected_domain]
136 | logging.info("Domain Selected")
137 |
138 | def course_prompt(self):
139 | """Returns URL for the selected course"""
140 |
141 | selected_domain_url = ROOT_URL + "/" + self.domain
142 | courses, courses_url = self.get_courses(selected_domain_url)
143 |
144 | # Course Selection Prompt
145 |
146 | logging.debug(courses)
147 | courses_completer = FuzzyWordCompleter(courses)
148 | selected_course_index = courses.index(
149 | prompt("Course: ", completer=courses_completer)
150 | )
151 | self.selected_course = courses[selected_course_index]
152 | print("Selected Course: {}".format(self.selected_course))
153 | self.course_url = courses_url[selected_course_index]
154 | logging.info("Course Selected")
155 |
156 | def get_all_courses(self) -> List[str]:
157 | """Returns URL for all courses"""
158 |
159 | print("Downloading all Courses from all Domains...")
160 | all_courses_url = []
161 | for domain in DOMAINS:
162 | print("Selected Domain: ", domain)
163 | selected_domain_url = ROOT_URL + "/" + domain
164 | courses, courses_url = self.get_courses(selected_domain_url)
165 | all_courses_url += courses_url
166 |
167 | return all_courses_url
168 |
169 | def get_course_page(self):
170 | """Retrieves course page html"""
171 |
172 | print("Course URL: {}".format(self.course_url))
173 | try:
174 | self.course_page = BeautifulSoup(requests.get(self.course_url).text, "lxml")
175 | except requests.ConnectionError as e:
176 | print("Error Connecting!\n", e)
177 | sys.exit(1)
178 | except requests.exceptions.HTTPError as errh:
179 | print("Http Error:", errh)
180 | sys.exit(1)
181 | except requests.exceptions.ConnectionError as errc:
182 | print("Error Connecting:", errc)
183 | sys.exit(1)
184 | except requests.exceptions.Timeout as errt:
185 | print("Timeout Error:", errt)
186 | sys.exit(1)
187 | except requests.exceptions.RequestException as err:
188 | print("Oops: Something Else", err)
189 | sys.exit(1)
190 |
191 | def get_course_title(self):
192 | """Retrieves the course title"""
193 |
194 | course_title = self.course_page.find(attrs=COURSE_TITLE)
195 | if course_title and course_title.text:
196 | self.course_title = course_title.text.replace(" ", "_")
197 | logging.debug("course_title:{}".format(self.course_title))
198 | logging.info("Course title retrieved")
199 |
200 | def get_course_unit_titles(self):
201 | """Retrieves course unit titles"""
202 | self.course_unit_titles = []
203 | for title in self.course_page.find_all(attrs=COURSE_UNIT_TITLE):
204 | if "unit" in str(title.text).lower():
205 | self.course_unit_titles.append(title.text)
206 | logging.debug("course_unit_titles:{}".format(self.course_unit_titles))
207 | logging.info("Course unit titles retrieved")
208 |
209 | def get_course_unit_slugs(self):
210 | """Retrieves course unit slugs"""
211 | self.course_unit_slugs = []
212 | counter = 0
213 | for title in self.course_unit_titles:
214 | self.course_unit_slugs.append(
215 | self.course_title + "/" + str(counter) + "_" + title.replace(" ", "_")
216 | )
217 | counter += 1
218 | logging.debug("course_unit_slugs:{}".format(self.course_unit_slugs))
219 | logging.info("Course unit slugs generated")
220 |
221 | def get_course_unit_urls(self):
222 | """Retrieves course unit urls"""
223 | self.course_unit_urls = []
224 | self.nested_courses = []
225 | for url in self.course_page.find_all(attrs=COURSE_UNIT_TITLE):
226 | if int(url["href"].count("/")) > 2:
227 | self.course_unit_urls.append(url["href"])
228 | else:
229 | self.nested_courses.append(url["href"])
230 | logging.debug("course_unit_urls:{}".format(self.course_unit_urls))
231 | logging.debug("nested_courses:{}".format(self.nested_courses))
232 | logging.info("Course unit urls retrieved")
233 |
234 | def get_course_all_slugs(self):
235 | """Generate slugs for all units"""
236 |
237 | unit_lessons_counter = 0
238 | # Unit Page -> Subunit Header + Subunit Block -> Lesson Block -> Lesson Title
239 | for course_unit_url, course_unit_slug, course_unit_title in zip(
240 | self.course_unit_urls, self.course_unit_slugs, self.course_unit_titles
241 | ):
242 | unit_lessons_counter = 0
243 | # -> Unit Page
244 | try:
245 | course_unit_page = BeautifulSoup(
246 | requests.get(ROOT_URL + course_unit_url).text, "lxml"
247 | )
248 | except requests.ConnectionError as e:
249 | print("Error Connecting!\n", e)
250 | sys.exit(1)
251 | except requests.exceptions.HTTPError as errh:
252 | print("Http Error:", errh)
253 | sys.exit(1)
254 | except requests.exceptions.ConnectionError as errc:
255 | print("Error Connecting:", errc)
256 | sys.exit(1)
257 | except requests.exceptions.Timeout as errt:
258 | print("Timeout Error:", errt)
259 | sys.exit(1)
260 | except requests.exceptions.RequestException as err:
261 | print("Oops: Something Else", err)
262 | sys.exit(1)
263 |
264 | subunit_couter = 0
265 |
266 | subunits = []
267 | # -> Subunit Header -> Subunit Block
268 | for course_subunit_title, course_subunit_body in zip(
269 | course_unit_page.find_all(attrs=COURSE_SUBUNIT_TITLE_ATTRS),
270 | course_unit_page.find_all(
271 | COURSE_SUBUNIT_BODY["tag"], class_=COURSE_SUBUNIT_BODY["class"]
272 | ),
273 | ):
274 |
275 | logging.debug("course_subunit_title:{}".format(course_subunit_title))
276 | lesson_counter = 0
277 | # -> Lesson Block
278 | lessons = []
279 | for course_lesson_body in course_subunit_body.find_all(
280 | COURSE_LESSON_BODY["tag"],
281 | {
282 | "class": [
283 | COURSE_LESSON_BODY["class_i"],
284 | COURSE_LESSON_BODY["class_ii"],
285 | ]
286 | },
287 | ):
288 | course_lesson_span = course_lesson_body.find_all(
289 | COURSE_LESSON_SPAN["tag"], class_=COURSE_LESSON_SPAN["class"]
290 | )
291 | course_lesson_aria_label = course_lesson_span[0][
292 | COURSE_LESSON_LABEL
293 | ]
294 | logging.debug(
295 | "course_lesson_aria_label:{}".format(course_lesson_aria_label)
296 | )
297 | # -> Lesson Title
298 | # Check whether lesson block is a video
299 | if course_lesson_aria_label == "Video":
300 | lesson_title = course_lesson_body.find(
301 | COURSE_LESSON_TITLE["tag"],
302 | class_=COURSE_LESSON_TITLE["class"],
303 | )
304 |
305 | logging.debug(
306 | "course_lesson_title:{}".format(lesson_title.text)
307 | )
308 | lessons.append(lesson_title.text.strip())
309 | self.lesson_titles.append(lesson_title.text)
310 | self.course_all_slugs.append(
311 | self.output_rel_path
312 | + course_unit_slug
313 | + "/"
314 | + str(subunit_couter)
315 | + "_"
316 | + course_subunit_title.text.replace(" ", "_")
317 | + "/"
318 | + str(lesson_counter)
319 | + "_"
320 | + lesson_title.text.replace(" ", "_")
321 | )
322 |
323 | lesson_counter += 1
324 | unit_lessons_counter += lesson_counter
325 | subunit_couter += 1
326 | subunits.append({
327 | "title": course_subunit_title.text.strip(),
328 | "lessons": lessons
329 | })
330 | self.course_subunits.append({
331 | "title": course_unit_title,
332 | "subunits": subunits
333 | })
334 | self.unit_slugs_counter[course_unit_url] = unit_lessons_counter
335 |
336 | logging.info(len(self.course_all_slugs))
337 | logging.info("Course - All slugs generated")
338 |
339 | def get_course_youtube_ids(self):
340 | """Retrieves youtube id per unit"""
341 | #
342 | # with ProgressBar() as pb:
343 | # for i, unit_url in zip(
344 | # pb(range(len(self.course_unit_urls)), label="Collecting Youtube IDs:"),
345 | # self.course_unit_urls,
346 | # ):
347 | # unit_url = ROOT_URL + unit_url
348 | # yt_dlp_opts = {
349 | # "logger": MyLogger(),
350 | # "retries": 20,
351 | # "ignoreerrors:": True,
352 | # "skip_download": True,
353 | # }
354 | # with yt_dlp.YoutubeDL(yt_dlp_opts) as ydl:
355 | # lessons_counter = 0
356 | # try:
357 | # logging.debug(
358 | # "Collecting youtube ids for unit:{}".format(unit_url)
359 | # )
360 | # info_dict = ydl.extract_info(unit_url, download=False)
361 | # for video in info_dict["entries"]:
362 | # video_id = video.get("id", None)
363 | # self.lesson_youtube_ids.append(video_id)
364 | # lessons_counter += 1
365 | # except DownloadError as e:
366 | # logging.debug(
367 | # "Collecting youtube ids for unit:{}".format(unit_url)
368 | # )
369 | # info_dict = ydl.extract_info(
370 | # unit_url, download=False, process=False
371 | # )
372 | # for video in info_dict["entries"]:
373 | # video_id = video.get("url", None)
374 | # self.lesson_youtube_ids.append(video_id)
375 | # lessons_counter += 1
376 | # except Exception as e:
377 | # print("Youtube-dl: An error occured!", e)
378 | # sys.exit(1)
379 | #
380 | # self.unit_ids_counter[unit_url] = lessons_counter
381 | #
382 | # logging.info(self.lesson_youtube_ids)
383 | # logging.info(len(self.lesson_youtube_ids))
384 | # logging.info("Course - Collected Youtube IDs")
385 |
386 | def download_course_videos(self):
387 | """Downloads Course Videos"""
388 | #
389 | # counter = 0
390 | # number_of_videos = len(self.course_all_slugs)
391 | #
392 | # with ProgressBar() as pb:
393 | # for i, lesson_output_file, lesson_video_id in zip(
394 | # pb(range(len(self.lesson_youtube_ids)), label="Downloading Videos:"),
395 | # self.course_all_slugs,
396 | # self.lesson_youtube_ids,
397 | # ):
398 | # lesson_youtube_url = VIDEO_SITE_URL + lesson_video_id
399 | #
400 | # yt_dlp_opts = {
401 | # "logger": MyLogger(),
402 | # "outtmpl": lesson_output_file,
403 | # "retries": 20,
404 | # }
405 | #
406 | # with yt_dlp.YoutubeDL(yt_dlp_opts) as ydl:
407 | # logging.debug(
408 | # "Downloading video[{}] {} of {}:".format(
409 | # lesson_youtube_url, counter, number_of_videos
410 | # )
411 | # )
412 | # try:
413 | # ydl.download([lesson_youtube_url])
414 | # counter += 1
415 | # except DownloadError:
416 | # error_log = open("error_private_videos.txt", "a")
417 | # error_log.write(
418 | # str(
419 | # lesson_output_file
420 | # + ", "
421 | # + VIDEO_SITE_URL
422 | # + lesson_video_id
423 | # )
424 | # )
425 | # error_log.close()
426 | # except Exception as e:
427 | # print("Youtube-dl: An error occured!", e)
428 | # sys.exit(1)
429 | # logging.info(
430 | # "Course lesson video[{}]downloaded".format(lesson_video_id)
431 | # )
432 | # logging.info("All course videos downloaded")
433 |
434 | def reset_course(self):
435 | self.domain = ""
436 | self.course_url = ""
437 | self.course_title = ""
438 | self.course_page = ""
439 | self.course_unit_titles = []
440 | self.course_unit_slugs = []
441 | self.course_unit_urls = []
442 | self.course_all_slugs = []
443 | self.lesson_titles = []
444 | self.lesson_youtube_ids = []
445 | self.unit_ids_counter = {}
446 | self.unit_slugs_counter = {}
447 | self.selected_course = ""
448 | self.course_subunits = []
449 |
450 | def download_nested_courses(self):
451 | self.reset_course()
452 | if self.nested_courses:
453 | print("\nDownloading nested courses...\n")
454 | for nested_course_url in self.nested_courses:
455 | self.download_course_given(ROOT_URL + nested_course_url)
456 |
457 | def download_course_interactive(self):
458 | """Downloads the chosen course"""
459 | self.domain_prompt()
460 | self.course_prompt()
461 | self.get_course_page()
462 | self.get_course_title()
463 | self.get_course_unit_titles()
464 | self.get_course_unit_slugs()
465 | self.get_course_unit_urls()
466 |
467 | print("\nGenerating Path Slugs...\n")
468 | self.get_course_all_slugs()
469 | self.get_course_youtube_ids()
470 | self.download_course_videos()
471 | self.download_nested_courses()
472 |
473 | def download_course_given(self, course_url: str):
474 | """Downloads the given course"""
475 | self.reset_course()
476 | self.course_url = course_url
477 | self.get_course_page()
478 | self.get_course_title()
479 | self.get_course_unit_titles()
480 | self.get_course_unit_slugs()
481 | self.get_course_unit_urls()
482 |
483 | self.get_course_all_slugs()
484 | return {
485 | "domain": self.domain,
486 | "url": self.course_url,
487 | "title": self.course_title,
488 | # "page": self.course_page,
489 | "unit_titles": self.course_unit_titles,
490 | # "unit_slugs": self.course_unit_slugs,
491 | # "unit_urls": self.course_unit_urls,
492 | # "all_slugs": self.course_all_slugs,
493 | # "lesson_titles": self.lesson_titles,
494 | "subunits": self.course_subunits
495 | }
496 |
497 | print("\nGenerating Path Slugs...\n")
498 | # self.get_course_youtube_ids()
499 | # self.download_course_videos()
500 |
--------------------------------------------------------------------------------
/prompts/khanacademy/khan_dl/main.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/rand-net/khan-dl
2 |
3 | import json
4 | import logging.handlers
5 |
6 | from tqdm import tqdm
7 |
8 | from khan_dl import *
9 | import argparse
10 | import sys
11 | from art import tprint
12 |
13 | __version__ = "1.2.8"
14 |
15 |
16 | def set_log_level(args):
17 | if not args.verbose:
18 | logging.basicConfig(level=logging.ERROR)
19 | elif int(args.verbose) == 1:
20 | logging.basicConfig(level=logging.WARNING)
21 | elif int(args.verbose) == 2:
22 | logging.basicConfig(level=logging.INFO)
23 | elif int(args.verbose) >= 3:
24 | logging.basicConfig(level=logging.DEBUG)
25 |
26 |
27 | def main(argv=None):
28 | argv = sys.argv if argv is None else argv
29 | argparser = argparse.ArgumentParser()
30 | argparser.add_argument(
31 | "-i",
32 | "--interactive",
33 | help="Enter Interactive Course Selection Mode",
34 | dest="interactive_prompt",
35 | action="store_true",
36 | )
37 | argparser.add_argument(
38 | "-c",
39 | "--course_url",
40 | help="Enter Course URL",
41 | )
42 |
43 | argparser.add_argument(
44 | "-a",
45 | "--all",
46 | help="Download all Courses from all Domains",
47 | action="store_true",
48 | )
49 |
50 | argparser.add_argument(
51 | "-v",
52 | "--verbose",
53 | help="Verbose Levels of log. 1 = Warning; 2 = Info; 3 = Debug",
54 | )
55 |
56 | args = argparser.parse_args()
57 |
58 | if args.interactive_prompt:
59 | set_log_level(args)
60 | tprint("KHAN-DL")
61 | khan_down = KhanDL()
62 | khan_down.download_course_interactive()
63 |
64 | elif args.course_url:
65 | set_log_level(args)
66 | tprint("KHAN-DL")
67 | print("Looking up " + args.course_url + "...")
68 | selected_course_url = args.course_url
69 | khan_down = KhanDL()
70 | khan_down.download_course_given(selected_course_url)
71 |
72 | elif args.all:
73 | set_log_level(args)
74 | tprint("KHAN-DL")
75 | khan_down = KhanDL()
76 | all_course_urls = khan_down.get_all_courses()
77 | courses = [khan_down.download_course_given(course_url) for course_url in tqdm(all_course_urls)]
78 | with open("khan_courses.json", "w") as outfile:
79 | outfile.write(json.dumps(courses, indent=4))
80 |
81 |
82 | if __name__ == "__main__":
83 | main()
84 |
--------------------------------------------------------------------------------
/prompts/khanacademy/khan_dl/requirements.txt:
--------------------------------------------------------------------------------
1 | art==5.5
2 | beautifulsoup4==4.11.1
3 | certifi==2021.10.8
4 | charset-normalizer==2.0.12
5 | idna==3.3
6 | lxml==4.8.0
7 | prompt-toolkit==3.0.29
8 | requests==2.27.1
9 | soupsieve==2.3.2.post1
10 | urllib3==1.26.9
11 | wcwidth==0.2.5
12 | yt-dlp==2022.5.18
13 | tqdm>=4.66.2
--------------------------------------------------------------------------------
/prompts/openstax/README.md:
--------------------------------------------------------------------------------
1 | ## Synthetic generations from OpenStax
2 |
3 | To build prompts from OpenStax, we use a dataset of course outlines and introductions, available [here](https://huggingface.co/datasets/HuggingFaceTB/openstax_paragraphs). We generate textbooks for 4 different audiences: young children, middle school students, professionals and researchers, and college students. Each prompt was carefully tailored to its target audience. To build the prompts, run:
4 | ```
5 | python ./build_openstax_prompts.py
6 | ```
7 |
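A hedged sketch of audience-specific prompt assembly: each style's `beginning` contains a `{{X}}` placeholder for the topic and is followed by audience-specific `criteria`. The templates below are abbreviated from those in `build_openstax_prompts.py` (the last `criteria` string is a placeholder), and the ordering of the parts plus the optional `context` argument are assumptions rather than the script's actual layout.

```python
# Abbreviated audience templates; each "beginning" holds a {{X}} placeholder.
STYLES = {
    "young children": {
        "beginning": "Create a fun and simple e-learning module on {{X}}, tailored for 5 to 10 year-old children.\n",
        "criteria": "Keep the tone light, cheerful, and encouraging. Do not use images.",
    },
    "middle school students": {
        "beginning": "Create an engaging and accessible e-learning module on {{X}}, tailored for middle school students without prior knowledge on the topic.\n",
        "criteria": "Use a story-based narrative and avoid technical jargon. Do not use images.",
    },
    "professionals and researchers": {
        "beginning": "Create an extract of a scientific journal article for {{X}}, tailored for professionals and researchers on the topic.\n",
        "criteria": "Write in a rigorous, academic register.",  # placeholder criteria
    },
}


def build_prompt(audience: str, topic: str, context: str = "") -> str:
    """Fill the topic into the audience's opening line, then add context and criteria.

    The ordering (beginning, optional context, criteria) is an assumed layout.
    """
    style = STYLES[audience]
    parts = [style["beginning"].replace("{{X}}", topic)]
    if context:
        parts.append(context.strip() + "\n")
    parts.append(style["criteria"])
    return "\n".join(parts)


if __name__ == "__main__":
    for audience in STYLES:
        print(build_prompt(audience, '"Supply and demand"').splitlines()[0])
```

Keeping the opening line and the criteria as separate fields lets every audience share the same assembly code while only the wording changes.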
--------------------------------------------------------------------------------
/prompts/openstax/build_openstax_prompts.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import random
3 | import argparse
4 | import numpy as np
5 | from datasets import Dataset, load_dataset, concatenate_datasets
6 |
7 |
8 | STYLES = {"young children":
9 | {"beginning": "Create a fun and simple e-learning module on {{X}}, tailored for 5 to 10 year-old children. Opt for a playful and imaginative approach, suitable for very young learners.\n",
10 | "criteria":"""In this module for young children, aim to:
11 |
12 | - Use very simple, everyday words and phrases that a 5-year-old would easily understand, avoiding any complex concepts or technical terms.
13 | - Tell a short, engaging story with colorful cartoon characters. For instance, to illustrate economic trade concepts, use characters like animals or friendly creatures trading snacks or toys. Another example is addition and calculus: use apples to explain, e.g. '2 apples + 3 apples = 5 apples'.
14 | - Keep the tone light, cheerful, and encouraging. Do not use images."""},
15 |
16 | "middle school students":
17 | {"beginning": "Create an engaging and accessible e-learning module on {{X}}, tailored for middle school students without prior knowledge on the topic.\n",
18 | "criteria": """Instead of a traditional textbook approach, use a story-based narrative to explain the concept. Try to:
19 |
20 | - Avoid technical jargon and present the ideas in a straightforward, conversational tone to spark curiosity and relate to the experiences of a younger audience.
21 | - Include interactive elements like thought experiments and real-life scenarios. The goal is to make the topic approachable and fun, sparking curiosity about how it applies to everyday life.
22 | - Do not use introductory phrases such as "welcome to this unit" at the beginning or conclusions at the end. Do not use images."""},
23 |
24 | "professionals and researchers":
25 | {"beginning": "Create an extract of a scientific journal article for {{X}}, tailored for professionals and researchers on the topic.\n",
26 | "criteria": """The style should mirror that of a scholarly publication, not school textbooks, aiming to engage a highly knowledgeable audience with very deep expertise. Try to:
27 |
28 | - Present advanced theories, using technical and academic language.
29 | - Include critical analysis of recent research findings and debates in the field, with a detailed examination of empirical data and statistical methodologies.
30 | - The article should reflect the depth and complexity of content found in top-tier economics journals, intended for a readership deeply entrenched in the field.
31 | - Do not come up with references or add them at the end of the article. If there are mathematical expressions, use correct LaTeX formatting and do not use images."""},
32 |
33 | "college students":
34 | {"beginning": "Write a comprehensive and in-depth textbook on {{X}}, tailored for college students.\n",
35 | "criteria": """Try to be:
36 |
37 | - Rigorous: Ensure very detailed and in-depth coverage of the concepts.
38 | - Engaging: Write with an academic and engaging tone that captivates interest.
39 | - Applied: Use specific and practical examples. For example, if the topic is integration in calculus, include equations and proofs of the concept you're teaching. As another example, if the topic is the history of the United States, include dates, names, and key events.
40 | If there are mathematical expressions, use correct LaTeX formatting. Do not use images and avoid introductory phrases such as "welcome to this unit" at the beginning or conclusions at the end."""}}
41 |
42 |
43 | def get_args():
44 | parser = argparse.ArgumentParser()
45 | parser.add_argument("--repo_id", type=str, default="HuggingFaceTB/openstax_prompts")
46 | return parser.parse_args()
47 |
48 |
49 | def parse_chapter(chapter, level=0, trail=[]):
50 | """Parse each chapter recursively and take into account sections"""
51 | trail_current = trail + [chapter["title"]]
52 | chapter_info = {
53 | "title": chapter["title"],
54 | "level": level,
55 | "trail": trail_current,
56 | "abstract": chapter.get("abstract", ""),
57 | "sections": [],
58 | "sub_chapters": [],
59 | }
60 |
61 | if chapter.get("sections"):
62 | for section in chapter.get("sections"):
63 | chapter_info["sections"].append(
64 | {
65 | "title": section["title"],
66 | "content": section.get("paragraph", ""),
67 | "trail": trail_current,
68 | "abstract": section.get("abstract", ""),
69 | }
70 | )
71 |
72 | # Handle sub-chapters recursively
73 | if chapter.get("chapters"):
74 | for sub_chapter in chapter.get("chapters"):
75 | chapter_info["sub_chapters"].append(
76 | parse_chapter(sub_chapter, level + 1, trail_current)
77 | )
78 |
79 | return chapter_info
80 |
81 |
82 | def parse_book(book):
83 | """Parse and rearrange a book"""
84 | book_info = {"title": book["book_title"], "chapters": []}
85 |
86 | for chapter in book["chapters"]:
87 | if "preface" in chapter["title"].lower():
88 | continue
89 | book_info["chapters"].append(parse_chapter(chapter))
90 |
91 | return book_info
92 |
93 |
94 | def build_prompts(
95 | parsed_book, style="college students", include_reference=True, refrence_size=500
96 | ):
97 | """Build prompts based on the (deepest) sections in each book"""
98 | prompts = []
99 | target_units = []
100 | chapters = [chap["title"] for chap in parsed_book["chapters"]]
101 | chosen_style = STYLES[style]
102 | start_prompt = chosen_style["beginning"]
103 | end_prompt = "\n" + chosen_style["criteria"]
104 | empty_content = 0
105 |
106 | for i, chapter in enumerate(parsed_book["chapters"]):
107 | chapter_prompt = start_prompt.replace("{{X}}", f"'{parsed_book['title']}'")
108 | chapter_prompt += f"We are writing on chapter '{chapter['title']}'. "
109 |
110 | for subchapter in chapter["sub_chapters"]:
111 | # Iterate over sections in subchapters
112 | if subchapter["sections"]:
113 | subchapter_prompt = (
114 | chapter_prompt + f"In particular, section '{subchapter['title']}'. "
115 | )
116 | for i, unit in enumerate(subchapter["sections"]):
117 | current_prompt = subchapter_prompt  # start from the section context for every unit
118 | if i != 0:
119 | units = [s["title"] for s in subchapter["sections"][:i]]
120 | units = list(np.random.choice(units, random.randint(2, 5), replace=False)) if len(units) > 5 else units
121 | prev_units = ", ".join([f"'{name}'" for name in units])
122 | plural = "s" if i > 1 else ""
123 | current_prompt += f"We have already covered the following unit{plural} in this section: {prev_units}. "
124 | if unit["content"] and include_reference:
125 | ref = f" Here's some text for inspiration: {unit['content'][:refrence_size]}".rstrip(".").rstrip() + "."
126 | else:
127 | empty_content += 1
128 | ref = ""
129 | new_prompt = (
130 | current_prompt
131 | + f"Write a new unit titled '{unit['title']}'.{ref}\n"
132 | )
133 | prompts.append(new_prompt + end_prompt)
134 | target_units.append(unit['title'])
135 | else:
136 | # Handle nested subchapters
137 | for k, e in enumerate(subchapter["sub_chapters"]):
138 | if e["sections"]:
139 | subchapter_prompt = (
140 | chapter_prompt
141 | + f"In particular, section '{e['title']}' of '{subchapter['title']}' part. "
142 | )
143 | for i, unit in enumerate(e["sections"]):
144 | current_prompt = subchapter_prompt
145 | if i != 0:
146 | units = [s["title"] for s in e["sections"][:i]]
147 | units = list(np.random.choice(units, random.randint(2, 5), replace=False)) if len(units) > 5 else units
148 | prev_units = ", ".join([f"'{name}'" for name in units])
149 | plural = "s" if i > 1 else ""
150 | current_prompt += f"We have already covered the following unit{plural} in this section: {prev_units}. "
151 | if unit["content"] and include_reference:
152 | size = len(unit["content"])
153 | ref = f" Here's some text for inspiration: {unit['content'][:min(refrence_size, size)]}".rstrip(".").rstrip() + "."
154 | else:
155 | empty_content += 1
156 | ref = ""
157 | new_prompt = (
158 | current_prompt
159 | + f"Write a new unit titled '{unit['title']}'.{ref}\n"
160 | )
161 | target_units.append(unit['title'])
162 | prompts.append(new_prompt + end_prompt)
163 | else:
164 | if "introduction" not in e['title'].lower() and e.get("abstract"):
165 | new_prompt = chapter_prompt
166 | if k != 0:
167 | subchapters = [s["title"] for s in subchapter["sub_chapters"][:k]]
168 | subchapters = list(np.random.choice(subchapters, random.randint(2, 5), replace=False)) if len(subchapters) > 5 else subchapters
169 | prev_subchapters = ", ".join([f"'{name}'" for name in subchapters])
170 | plural = "s" if k > 1 else ""
171 | new_prompt += f"We have already covered the following unit{plural} in this chapter: {prev_subchapters}. "
172 | new_prompt = new_prompt + f"Write a new unit titled '{e['title']}'."
173 | if include_reference:
174 | size = len(e["abstract"])
175 | new_prompt += f" Here's some text for inspiration: {e['abstract'][:min(refrence_size, size)]}"
176 | else:
177 | empty_content += 1
178 | target_units.append(e['title'])
179 | prompts.append(new_prompt.rstrip('.').rstrip() + '.\n' + end_prompt)
180 | else:
181 | continue
182 | return prompts, target_units, empty_content
183 |
184 |
185 | if __name__ == "__main__":
186 | args = get_args()
187 | include_references = [True, False]
188 | refrence_size = 600
189 | ds = load_dataset("HuggingFaceTB/openstax_paragraphs", split="train")
190 | ds_en = ds.filter(lambda x: x["language"] == "en")
191 |
192 | print(f"English books dataset: {ds_en}")
193 | print("🔍 Parsing books...")
194 | parsed_books = [parse_book(e) for e in ds_en]
195 | datasets_list = []
196 | for include_reference in include_references:
197 | ref = "with_ref" if include_reference else "no_ref"
198 | for style in STYLES:
199 | print(f"🧩 Building prompts for {style} {ref}...")
200 | outputs = [
201 | build_prompts(
202 | book,
203 | style=style,
204 | include_reference=include_reference,
205 | refrence_size=refrence_size,
206 | )
207 | for book in parsed_books
208 | ]
209 | prompts = [p[0] for p in outputs]
210 | target_units = [p[1] for p in outputs]
211 | empty_content = [p[2] for p in outputs]
212 | sizes = [len(p) for p in prompts]
213 | print(
214 | f"✅ Done building {sum(sizes)} prompts! ({sum(empty_content)} without reference text)"
215 | )
216 |
217 | print(f"🌟 Examples:")
218 | print(f"- {prompts[random.randint(0, 5)][0]}")
219 | print(f"\n- {prompts[random.randint(0, 5)][-1]}")
220 |
221 | print("Converting to HF dataset and pushing to Hub...")
222 | flattened_prompts = []
223 | for book, book_prompts, book_units in zip(parsed_books, prompts, target_units):
224 | for prompt, unit in zip(book_prompts, book_units):
225 | book_title = book["title"]  # avoids re-parsing the title out of the prompt text
226 | flattened_prompts.append((prompt, unit, book_title))
227 |
228 | df = pd.DataFrame(flattened_prompts, columns=["prompt", "unit", "book title"])
229 | ds = Dataset.from_pandas(df)
230 | audience = "_".join(style.split(" "))
231 | ds = ds.add_column("audience", [f"{audience}_{ref}" for _ in range(len(ds))])
232 | print(ds)
233 | datasets_list.append(ds)
234 |
235 | final_ds = concatenate_datasets(datasets_list)
236 | print(final_ds)
237 | final_ds.push_to_hub(args.repo_id, private=True)
238 |
239 |
240 |
241 |
--------------------------------------------------------------------------------
/prompts/stanford/1_scraper.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "source": [
6 | "Jupyter notebook to scrape Stanford's list of courses. It collects the following:\n",
7 | "- course title\n",
8 | "- course description\n",
9 | "- course numbers/ids"
10 | ],
11 | "metadata": {
12 | "collapsed": false
13 | },
14 | "id": "aeed1babdd1ced9b"
15 | },
16 | {
17 | "cell_type": "code",
18 | "execution_count": 2,
19 | "id": "initial_id",
20 | "metadata": {
21 | "collapsed": true,
22 | "ExecuteTime": {
23 | "end_time": "2023-09-27T10:58:34.215404466Z",
24 | "start_time": "2023-09-27T10:58:34.208962606Z"
25 | }
26 | },
27 | "outputs": [],
28 | "source": [
29 | "from time import sleep\n",
30 | "\n",
31 | "from tqdm import tqdm\n",
32 | "\n",
33 | "MAIN_INDEX_URL = \"https://explorecourses.stanford.edu/search?q=all%20courses\""
34 | ]
35 | },
36 | {
37 | "cell_type": "code",
38 | "execution_count": 3,
39 | "outputs": [],
40 | "source": [
41 | "import requests\n",
42 | "from bs4 import BeautifulSoup\n",
43 | "import re"
44 | ],
45 | "metadata": {
46 | "collapsed": false,
47 | "ExecuteTime": {
48 | "end_time": "2023-09-27T10:58:34.720087515Z",
49 | "start_time": "2023-09-27T10:58:34.640417536Z"
50 | }
51 | },
52 | "id": "b4ca6254d01f8dc8"
53 | },
54 | {
55 | "cell_type": "code",
56 | "execution_count": 99,
57 | "outputs": [],
58 | "source": [
59 | "headers = {'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10_1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/39.0.2171.95 Safari/537.36'}"
60 | ],
61 | "metadata": {
62 | "collapsed": false,
63 | "ExecuteTime": {
64 | "end_time": "2023-09-25T10:03:23.322727830Z",
65 | "start_time": "2023-09-25T10:03:23.319789021Z"
66 | }
67 | },
68 | "id": "cd55f81c3d1d42dd"
69 | },
70 | {
71 | "cell_type": "markdown",
72 | "source": [
73 | "Scrape all courses.\n",
74 | "Set `TOTAL_PAGES` to the number of pages shown in the footer of https://explorecourses.stanford.edu/search?q=all%20courses"
75 | ],
76 | "metadata": {
77 | "collapsed": false
78 | },
79 | "id": "bbce7731b5ff3ad2"
80 | },
81 | {
82 | "cell_type": "code",
83 | "execution_count": 103,
84 | "outputs": [
85 | {
86 | "name": "stderr",
87 | "output_type": "stream",
88 | "text": [
89 | "100%|██████████| 1541/1541 [37:09<00:00, 1.45s/it]\n"
90 | ]
91 | }
92 | ],
93 | "source": [
94 | "TOTAL_PAGES = 1541\n",
95 | "\n",
96 | "def formatt(text):\n",
97 | " text = text.strip(\"\\r\\n\\t\")\n",
98 | " if text.endswith(\"more »\"):\n",
99 | " text = text[:-6]\n",
100 | " return text.strip(\"\\r\\n\\t\")\n",
101 | "\n",
102 | "all_courses = []\n",
103 | "\n",
104 | "# sadly the api seems to be for students and faculty only\n",
105 | "\n",
106 | "for p in tqdm(range(TOTAL_PAGES), total=TOTAL_PAGES):\n",
107 | " r = requests.get(MAIN_INDEX_URL + f\"&page={p}\", headers=headers)\n",
108 | " soup = BeautifulSoup(r.content)\n",
109 | " courses = [{\n",
110 | " \"number\": x.find(\"span\", {\"class\": 'courseNumber'}).text.rstrip(\":\"),\n",
111 | " \"title\": x.find(\"span\", {\"class\": 'courseTitle'}).text,\n",
112 | " \"description\": formatt(x.find(\"div\", {\"class\": 'courseDescription'}).text),\n",
113 | " } for x in soup.find_all(\"div\", {\"class\": \"courseInfo\"})]\n",
114 | " all_courses.extend(courses)\n",
115 | " # don't spam their servers too much\n",
116 | " sleep(0.5)"
117 | ],
118 | "metadata": {
119 | "collapsed": false,
120 | "ExecuteTime": {
121 | "end_time": "2023-09-25T10:41:21.425942204Z",
122 | "start_time": "2023-09-25T10:04:11.761537359Z"
123 | }
124 | },
125 | "id": "4523f23672936f0a"
126 | },
127 | {
128 | "cell_type": "markdown",
129 | "source": [
130 | "Deduplicate courses.\n",
131 | "Courses listed multiple times with different ids have the other ids inside brackets"
132 | ],
133 | "metadata": {
134 | "collapsed": false
135 | },
136 | "id": "b140124d216e49b3"
137 | },
138 | {
139 | "cell_type": "code",
140 | "execution_count": 127,
141 | "outputs": [],
142 | "source": [
143 | "course_ids = set()\n",
144 | "unique_courses = []\n",
145 | "for course in all_courses:\n",
146 | " # check if we already found a duplicate of this course\n",
147 | " if course[\"number\"] in course_ids:\n",
148 | " continue\n",
149 | " ids = [course[\"number\"]]\n",
150 | " res = re.search(r\"\\((.*?)\\)\", course[\"title\"])\n",
151 | " if res:\n",
152 | " ids.extend(res.group(1).split(\", \"))\n",
153 | " course[\"title\"] = course[\"title\"][:course[\"title\"].rindex(\"(\") - 1] # strip the course ids from the title \"(...\"\n",
154 | " course_ids.update(ids)\n",
155 | " unique_courses.append({\n",
156 | " **course,\n",
157 | " \"number\": \", \".join(ids)\n",
158 | " })"
159 | ],
160 | "metadata": {
161 | "collapsed": false,
162 | "ExecuteTime": {
163 | "end_time": "2023-09-25T11:13:39.243448406Z",
164 | "start_time": "2023-09-25T11:13:39.189180954Z"
165 | }
166 | },
167 | "id": "dc18a648c4fa5d50"
168 | },
169 | {
170 | "cell_type": "code",
171 | "execution_count": 140,
172 | "outputs": [],
173 | "source": [
174 | "import pandas as pd\n",
175 | "df = pd.DataFrame(unique_courses)\n",
176 | "df.to_csv(\"stanford_courses_unique.csv\")"
177 | ],
178 | "metadata": {
179 | "collapsed": false,
180 | "ExecuteTime": {
181 | "end_time": "2023-09-25T11:46:00.367181772Z",
182 | "start_time": "2023-09-25T11:46:00.283586989Z"
183 | }
184 | },
185 | "id": "e8e4338438a8e5ae"
186 | },
187 | {
188 | "cell_type": "markdown",
189 | "source": [
190 | "Clean descriptions:\n",
191 | "- remove urls\n",
192 | "- remove course ids\n",
193 | "- \"Continuation of ...\" descriptions are dropped later by `detect_generic`"
194 | ],
195 | "metadata": {
196 | "collapsed": false
197 | },
198 | "id": "78788009ded5f401"
199 | },
200 | {
201 | "cell_type": "code",
202 | "execution_count": 141,
203 | "outputs": [],
204 | "source": [
205 | "cleaned_courses = []\n",
206 | "for course in unique_courses:\n",
207 | " desc = course[\"description\"]\n",
208 | " # urls\n",
209 | " desc = re.sub(r'http[s]?://\\S+', '', desc)\n",
210 | " # course names\n",
211 | " desc = re.sub(r'[A-Z]+ \\d+([A-Z]+)?', '', desc)\n",
212 | " cleaned_courses.append({\n",
213 | " **course,\n",
214 | " \"description\": desc\n",
215 | " })\n",
216 | " "
217 | ],
218 | "metadata": {
219 | "collapsed": false,
220 | "ExecuteTime": {
221 | "end_time": "2023-09-25T12:20:44.838325706Z",
222 | "start_time": "2023-09-25T12:20:44.761095868Z"
223 | }
224 | },
225 | "id": "a59836f513898f95"
226 | },
227 | {
228 | "cell_type": "code",
229 | "execution_count": 4,
230 | "outputs": [],
231 | "source": [
232 | "import pandas as pd"
233 | ],
234 | "metadata": {
235 | "collapsed": false,
236 | "ExecuteTime": {
237 | "end_time": "2023-09-27T10:58:39.102293877Z",
238 | "start_time": "2023-09-27T10:58:38.900613044Z"
239 | }
240 | },
241 | "id": "c2524526f5e1bdf4"
242 | },
243 | {
244 | "cell_type": "code",
245 | "execution_count": 143,
246 | "outputs": [],
247 | "source": [
248 | "df = pd.DataFrame(cleaned_courses)"
249 | ],
250 | "metadata": {
251 | "collapsed": false,
252 | "ExecuteTime": {
253 | "end_time": "2023-09-25T12:20:45.440192733Z",
254 | "start_time": "2023-09-25T12:20:45.438743636Z"
255 | }
256 | },
257 | "id": "ed912eb58e38574c"
258 | },
259 | {
260 | "cell_type": "code",
261 | "execution_count": 144,
262 | "outputs": [],
263 | "source": [
264 | "df.to_csv(\"stanford_courses_cleaned.csv\")"
265 | ],
266 | "metadata": {
267 | "collapsed": false,
268 | "ExecuteTime": {
269 | "end_time": "2023-09-25T12:20:45.777534972Z",
270 | "start_time": "2023-09-25T12:20:45.702459885Z"
271 | }
272 | },
273 | "id": "4d21646ef7ff9593"
274 | },
275 | {
276 | "cell_type": "code",
277 | "execution_count": 5,
278 | "outputs": [],
279 | "source": [
280 | "df = pd.read_csv(\"stanford_courses_cleaned.csv\", dtype=str, na_values='', keep_default_na=False)"
281 | ],
282 | "metadata": {
283 | "collapsed": false,
284 | "ExecuteTime": {
285 | "end_time": "2023-09-27T10:58:44.461354916Z",
286 | "start_time": "2023-09-27T10:58:44.414283703Z"
287 | }
288 | },
289 | "id": "303ce35ad306db4b"
290 | },
291 | {
292 | "cell_type": "markdown",
293 | "source": [
294 | "Some preprocessing to remove generic descriptions"
295 | ],
296 | "metadata": {
297 | "collapsed": false
298 | },
299 | "id": "8ae1b01a983d663f"
300 | },
301 | {
302 | "cell_type": "code",
303 | "execution_count": 6,
304 | "outputs": [],
305 | "source": [
306 | "import string\n",
307 | "\n",
308 | "def detect_generic(text):\n",
309 | " text = text.lower().translate(str.maketrans('', '', string.punctuation))\n",
310 | " if text in (\"tba\", \"tbd\", \"description tbd\"):\n",
311 | " return False  # placeholders are kept: long-titled courses without a description get \"TBD\" below\n",
312 | " for x in (\"prerequisite\", \"continuation of\", \"graduation\", \"prior arrangement\", \"consent of instructor\", \"doctoral practicum\", \"may be repeated\", \"required supervised\", \"program consent required\", \"supervised experience\", \"students must obtain\", \"graduate\", \"research\", \"tutorial in\", \"independent study\", \"for credit\", \"for advanced\"):\n",
313 | " text = text.replace(x, \"\")\n",
314 | " return len(text) < 20"
315 | ],
316 | "metadata": {
317 | "collapsed": false,
318 | "ExecuteTime": {
319 | "end_time": "2023-09-27T10:58:45.230915973Z",
320 | "start_time": "2023-09-27T10:58:45.223450162Z"
321 | }
322 | },
323 | "id": "ced2e95857c5d12b"
324 | },
325 | {
326 | "cell_type": "code",
327 | "execution_count": 7,
328 | "outputs": [],
329 | "source": [
330 | "non_generic_courses = []\n",
331 | "\n",
332 | "for a, b in df.iterrows():\n",
333 | " # no description\n",
334 | " if not isinstance(b[\"description\"], str):\n",
335 | " if len(b[\"title\"]) < 25: # no description + short title = unusable\n",
336 | " continue\n",
337 | " b[\"description\"] = \"TBD\"\n",
338 | " if detect_generic(b[\"description\"]):\n",
339 | " continue\n",
340 | " non_generic_courses.append(b)"
341 | ],
342 | "metadata": {
343 | "collapsed": false,
344 | "ExecuteTime": {
345 | "end_time": "2023-09-27T10:58:46.428860754Z",
346 | "start_time": "2023-09-27T10:58:45.871251538Z"
347 | }
348 | },
349 | "id": "8dbb9403c43e07a9"
350 | },
351 | {
352 | "cell_type": "code",
353 | "execution_count": 10,
354 | "outputs": [],
355 | "source": [
356 | "pd.DataFrame(non_generic_courses).to_csv(\"stanford_courses_cleaned_non_generic.csv\", index=False)"
357 | ],
358 | "metadata": {
359 | "collapsed": false,
360 | "ExecuteTime": {
361 | "end_time": "2023-09-27T11:08:23.018882075Z",
362 | "start_time": "2023-09-27T11:08:20.205486903Z"
363 | }
364 | },
365 | "id": "5a321cbac67d727f"
366 | },
367 | {
368 | "cell_type": "code",
369 | "execution_count": null,
370 | "outputs": [],
371 | "source": [],
372 | "metadata": {
373 | "collapsed": false
374 | },
375 | "id": "b122079355a393a0"
376 | }
377 | ],
378 | "metadata": {
379 | "kernelspec": {
380 | "display_name": "Python 3",
381 | "language": "python",
382 | "name": "python3"
383 | },
384 | "language_info": {
385 | "codemirror_mode": {
386 | "name": "ipython",
387 | "version": 2
388 | },
389 | "file_extension": ".py",
390 | "mimetype": "text/x-python",
391 | "name": "python",
392 | "nbconvert_exporter": "python",
393 | "pygments_lexer": "ipython2",
394 | "version": "2.7.6"
395 | }
396 | },
397 | "nbformat": 4,
398 | "nbformat_minor": 5
399 | }
400 |
--------------------------------------------------------------------------------
/prompts/stanford/2_generate_course_outlines.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "initial_id",
7 | "metadata": {
8 | "collapsed": true
9 | },
10 | "outputs": [],
11 | "source": [
12 | "\n",
13 | "import pandas as pd\n",
14 | "\n",
15 | "df = pd.read_csv(\"stanford_courses_cleaned_non_generic.csv\", dtype=str)"
16 | ]
17 | },
18 | {
19 | "cell_type": "markdown",
20 | "source": [
21 | "1-shot generation of course outlines from the title and description"
22 | ],
23 | "metadata": {
24 | "collapsed": false
25 | },
26 | "id": "255265fa0caac0da"
27 | },
28 | {
29 | "cell_type": "code",
30 | "execution_count": null,
31 | "outputs": [],
32 | "source": [
33 | "from string import Template\n",
34 | "\n",
35 | "OUTLINE_TEMPLATE = Template(\"\"\"Write a course outline for a textbook on \\\"The Global Positioning System: Where on Earth are We, and What Time is It?\\\" covering the following topics: \\\"Why people want to know where they are: answers include cross-Pacific trips of Polynesians, missile guidance, and distraught callers. How people determine where they are: navigation technology from dead-reckoning, sextants, and satellite navigation (GPS). Hands-on experience. How GPS works; when it does not work; possibilities for improving performance.\\\".\n",
36 | "Model: 1. Introduction\n",
37 | "- What is the Global Positioning System?\n",
38 | "- Importance of GPS\n",
39 | "- Overview of the course\n",
40 | "\n",
41 | "2. Navigation technology\n",
42 | "- Dead-reckoning\n",
43 | "- Sextants\n",
44 | "- Satellite navigation\n",
45 | "- Comparison of technologies\n",
46 | "- Hands-on experience with navigation technology\n",
47 | "\n",
48 | "3. GPS technology\n",
49 | "- How GPS works\n",
50 | " - Satellites\n",
51 | " - Ground receivers\n",
52 | " - Triangulation\n",
53 | "- When GPS does not work\n",
54 | " - Blockage\n",
55 | " - Multipath\n",
56 | "- Possibilities for improving performance\n",
57 | "\n",
58 | "4. Applications of GPS\n",
59 | "- Cross-Pacific trips of Polynesians\n",
60 | "- Missile guidance\n",
61 | "- Distraught callers\n",
62 | "- Other applications of GPS\n",
63 | "\n",
64 | "User: Write a course outline for a textbook on \\\"${COURSE_TITLE}\\\" covering the following topics: \\\"${COURSE_DESCRIPTION}\\\". Do not include assignments, exams or prerequisites.\n",
65 | "Model: \"\"\")"
66 | ],
67 | "metadata": {
68 | "collapsed": false
69 | },
70 | "id": "52c53c272dffa9d2"
71 | },
72 | {
73 | "cell_type": "code",
74 | "execution_count": null,
75 | "outputs": [],
76 | "source": [
77 | "courses_to_generate = []\n",
78 | "for a, b in df.iterrows():\n",
79 | " prompt = OUTLINE_TEMPLATE.substitute({\"COURSE_TITLE\": b[\"title\"], \"COURSE_DESCRIPTION\": b[\"description\"]})\n",
80 | " courses_to_generate.append({\n",
81 | " \"course_title\": b[\"title\"],\n",
82 | " \"course_description\": b[\"description\"],\n",
83 | " \"prompt\": prompt,\n",
84 | " })"
85 | ],
86 | "metadata": {
87 | "collapsed": false
88 | },
89 | "id": "de59641276702b38"
90 | },
91 | {
92 | "cell_type": "code",
93 | "execution_count": null,
94 | "outputs": [],
95 | "source": [
96 | "generations = [...] # code to generate using the prompts here"
97 | ],
98 | "metadata": {
99 | "collapsed": false
100 | },
101 | "id": "29965888236f2bdd"
102 | },
103 | {
104 | "cell_type": "code",
105 | "execution_count": null,
106 | "outputs": [],
107 | "source": [
108 | "for course, generation in zip(courses_to_generate, generations):\n",
109 | " course[\"outline\"] = generation\n",
110 | "\n",
111 | "pd.DataFrame(courses_to_generate).to_csv(\"outlines_full.csv\")"
112 | ],
113 | "metadata": {
114 | "collapsed": false
115 | },
116 | "id": "55c5350ab00b5d54"
117 | },
118 | {
119 | "cell_type": "markdown",
120 | "source": [
121 | "(Very large) 2-shot prompt to clean up the generated outlines: the model flags sections that contain private faculty information, practical course logistics, or assessment details"
122 | ],
123 | "metadata": {
124 | "collapsed": false
125 | },
126 | "id": "c0d4dd91d694584f"
127 | },
128 | {
129 | "cell_type": "code",
130 | "execution_count": null,
131 | "outputs": [],
132 | "source": [
133 | "\n",
134 | "OUTLINE_FILTER_TEMPLATE = Template(\"\"\"The following is a course outline for a course on \\\"Anesthesia Operating Room Clerkship\\\". This outline needs to be anonymized and adapted to an online audience:\n",
135 | "1.1 Introduction: Overview of the Anesthesia Operating Room Clerkship\n",
136 | "1.2 Introduction: Objectives of the clerkship\n",
137 | "1.3 Introduction: Prerequisites for the clerkship\n",
138 | "2.1 Clinical settings: Sequoia Hospital in Redwood City\n",
139 | "2.2 Clinical settings: Outpatient surgery centers throughout the community\n",
140 | "2.3 Clinical settings: Exposure to general and regional anesthetic techniques\n",
141 | "2.4 Clinical settings: Adult and pediatric patients\n",
142 | "3.1 Personalized discussion: Applied physiology\n",
143 | "3.2 Personalized discussion: Pharmacology\n",
144 | "3.3 Personalized discussion: Pathophysiology of the surgical patient\n",
145 | "3.4 Personalized discussion: Daily basis\n",
146 | "3.5 Personalized discussion: Final paper to be submitted by the students\n",
147 | "4.1 Transportation: Students need to arrange transportation to the various workplaces\n",
148 | "5.1 Prerequisites: A major clerkship in medicine or surgery is strongly recommended\n",
149 | "6.1 Periods available: 1-12, full-time for 2 weeks\n",
150 | "6.2 Periods available: 1 student per period\n",
151 | "7.1 Clerkship director and coordinator: Kurt Fink, M.D.\n",
152 | "7.2 Clerkship director and coordinator: Yun Tao, 650-724-1706, yuntao@stanford.edu, Stanford Hospital\n",
153 | "8.1 Reporting instructions: Contact Dr. Kurt Fink one week prior\n",
154 | "8.2 Reporting instructions: Time: TBA\n",
155 | "8.3 Reporting instructions: Call code: 0\n",
156 | "9.1 Other faculty: Palo Alto Medical Clinic Anesthesiologist\n",
157 | "10.1 Location: Palo Alto Medical Foundation.\n",
158 | "\n",
159 | "Which of the sections of the outline contain: \n",
160 | "- private faculty members information (names or contact information)\n",
161 | "- prerequisites, requirements, application processes, schedules or other practical course information not related to the course content\n",
162 | "- assignments, final papers, exams, presentations or other student evaluation information\n",
163 | "Falcon:\n",
164 | "- private faculty members information (names or contact information): 7.1, 7.2, 8.1, 9.1\n",
165 | "- prerequisites, requirements, application processes, schedules or other practical course information not related to the course content: 1.3, 4.1, 5.1, 6.1, 6.2, 8.1, 8.2, 8.3, 10.1\n",
166 | "- assignments, final papers, exams, presentations or other student evaluation information: 3.5\n",
167 | "User: The following is a course outline for a course on \"Numerical Methods for Compressible Flows\". This outline needs to be anonymized and adapted to an online audience:\n",
168 | "1.1 Introduction: Overview of the course\n",
169 | "1.2 Introduction: Importance of numerical methods for compressible flows\n",
170 | "1.3 Introduction: Prerequisites for the course\n",
171 | "2.1 Mathematical models for compressible flows: Hierarchy of mathematical models\n",
172 | "2.2 Mathematical models for compressible flows: Ideal potential flow\n",
173 | "2.3 Mathematical models for compressible flows: Transonic potential flow\n",
174 | "3.1 Numerical methods for compressible flows: Finite difference methods\n",
175 | "3.2 Numerical methods for compressible flows: Finite volume methods\n",
176 | "3.3 Numerical methods for compressible flows: Finite element methods\n",
177 | "4.1 Representative model problems: Shocks\n",
178 | "4.2 Representative model problems: Expansions\n",
179 | "5.1 Treatment of boundary conditions: Dirichlet boundary conditions\n",
180 | "5.2 Treatment of boundary conditions: Neumann boundary conditions\n",
181 | "6.1 Applications of numerical methods for compressible flows: Aerospace engineering\n",
182 | "6.2 Applications of numerical methods for compressible flows: Other applications of numerical methods for compressible flows\n",
183 | "\n",
184 | "Which of the sections of the outline contain: \n",
185 | "- private faculty members information (names or contact information)\n",
186 | "- prerequisites, requirements, application processes, schedules or other practical course information not related to the course content\n",
187 | "- assignments, final papers, exams, presentations or other student evaluation information\n",
188 | "Falcon: \n",
189 | "- private faculty members information (names or contact information): None\n",
190 | "- prerequisites, requirements, application processes, schedules or other practical course information not related to the course content: 1.3\n",
191 | "- assignments, final papers, exams, presentations or other student evaluation information: None\n",
192 | "User: The following is a course outline for a course on \\\"${COURSE_TITLE}\\\". This outline needs to be anonymized and adapted to an online audience:\n",
193 | "${SECTIONS_LIST}\n",
194 | "\n",
195 | "Which of the sections of the outline contain: \n",
196 | "- private faculty members information (names or contact information)\n",
197 | "- prerequisites, requirements, application processes, schedules or other practical course information not related to the course content\n",
198 | "- assignments, final papers, exams, presentations or other student evaluation information\n",
199 | "Falcon: \"\"\")"
200 | ],
201 | "metadata": {
202 | "collapsed": false
203 | },
204 | "id": "53702cd87a677315"
205 | },
206 | {
207 | "cell_type": "markdown",
208 | "source": [
209 | "Reformat each outline into a numbered list of units and build the outline-filtering prompts"
210 | ],
211 | "metadata": {
212 | "collapsed": false
213 | },
214 | "id": "ca753f9d85a95139"
215 | },
216 | {
217 | "cell_type": "code",
218 | "execution_count": null,
219 | "outputs": [],
220 | "source": [
221 | "import re\n",
222 | "\n",
223 | "FIND_SECTIONS_REGEX = re.compile(r\"\\d\\. .*(?:\\n\\s*- .*)+\")\n",
224 | "FIND_TITLES_REGEX = re.compile(r\"\\d\\. (.*)\")\n",
225 | "FIND_UNIT_TITLES_REGEX = re.compile(r\"\\n\\s*- (.*)\")\n",
226 | "\n",
227 | "def extract_sections(outline):\n",
228 | " sections = FIND_SECTIONS_REGEX.findall(outline)\n",
229 | " return [\n",
230 | " {\n",
231 | " \"section_nr\": si + 1,\n",
232 | " \"title\": FIND_TITLES_REGEX.search(section).group(1),\n",
233 | " \"unit_titles\": FIND_UNIT_TITLES_REGEX.findall(section),\n",
234 | " } for si, section in enumerate(sections)\n",
235 | " ]\n",
236 | "\n",
237 | "\n",
238 | "df = pd.read_csv(\"outlines_full.csv\", dtype=str)\n",
239 | "for idx, row in df.iterrows():\n",
240 | "    sections = extract_sections(row[\"outline\"])\n",
241 | "    sections_list = '\\n'.join(\n",
242 | "        f\"{si + 1}.{ui + 1} {section['title']}: {unit_title}\"\n",
243 | "        for si, section in enumerate(sections) for ui, unit_title in enumerate(section[\"unit_titles\"]))\n",
244 | "    prompt = OUTLINE_FILTER_TEMPLATE.substitute({\"COURSE_TITLE\": row[\"course_title\"], \"SECTIONS_LIST\": sections_list})\n",
245 | "    df.loc[idx, 'filter_outline_prompt'] = prompt\n",
246 | "    df.loc[idx, 'filter_outline_result'] = None  # TODO: generate the filter result for `prompt` with the LLM"
247 | ],
248 | "metadata": {
249 | "collapsed": false
250 | },
251 | "id": "1cfa8766032e5ee2"
252 | },
253 | {
254 | "cell_type": "code",
255 | "execution_count": null,
256 | "outputs": [],
257 | "source": [
258 | "df.to_csv(\"outlines_full_filtered.csv\", index=False)"
259 | ],
260 | "metadata": {
261 | "collapsed": false
262 | },
263 | "id": "ea6a4080d3619d40"
264 | }
265 | ],
266 | "metadata": {
267 | "kernelspec": {
268 | "display_name": "Python 3",
269 | "language": "python",
270 | "name": "python3"
271 | },
272 | "language_info": {
273 | "codemirror_mode": {
274 | "name": "ipython",
275 | "version": 2
276 | },
277 | "file_extension": ".py",
278 | "mimetype": "text/x-python",
279 | "name": "python",
280 | "nbconvert_exporter": "python",
281 | "pygments_lexer": "ipython2",
282 | "version": "2.7.6"
283 | }
284 | },
285 | "nbformat": 4,
286 | "nbformat_minor": 5
287 | }
288 |
--------------------------------------------------------------------------------
/prompts/stanford/README.md:
--------------------------------------------------------------------------------
1 | # Synthetic textbooks from Stanford course outlines
2 |
3 | You can find the code for scraping Stanford courses in `1_scraper.ipynb`, and the code for generating course outlines in `2_generate_course_outlines.ipynb`.
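A minimal sketch of the outline flattening done in `2_generate_course_outlines.ipynb` before filtering: each generated outline is turned into numbered `<section>.<unit> <section title>: <unit title>` lines, the format the outline-filtering prompt expects. It reuses the notebook's regular expressions and assumes outlines are written as `N. Section title` lines followed by `- unit` bullets; the helper name `outline_to_units` and the toy outline are hypothetical.

```python
import re

# Same patterns as the notebook: a numbered section line followed by "- unit" bullets.
FIND_SECTIONS_REGEX = re.compile(r"\d\. .*(?:\n\s*- .*)+")
FIND_TITLES_REGEX = re.compile(r"\d\. (.*)")
FIND_UNIT_TITLES_REGEX = re.compile(r"\n\s*- (.*)")


def outline_to_units(outline: str) -> str:
    """Flatten an outline into '<section>.<unit> <section title>: <unit title>' lines."""
    lines = []
    for si, section in enumerate(FIND_SECTIONS_REGEX.findall(outline)):
        title = FIND_TITLES_REGEX.search(section).group(1)
        for ui, unit in enumerate(FIND_UNIT_TITLES_REGEX.findall(section)):
            lines.append(f"{si + 1}.{ui + 1} {title}: {unit}")
    return "\n".join(lines)


outline = """1. Introduction
   - Overview of compressible flows
   - Course logistics
2. Numerical methods
   - Finite difference methods
   - Finite volume methods"""
print(outline_to_units(outline))
```

Units such as `1.2 Introduction: Course logistics` are the kind of practical, non-content information the filtering prompt asks the model to flag.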
4 |
5 | [TODO]: add code for the updated Cosmopedia prompts
--------------------------------------------------------------------------------
/prompts/stories/README.md:
--------------------------------------------------------------------------------
1 | # Stories from UltraChat and OpenHermes
2 |
3 | We build several types of stories: educational stories for young children, stories involving morals and principles, stories involving problem solving, and posts like those found on forums and Reddit. The prompts are based on seed samples from [UltraChat](https://huggingface.co/datasets/stingning/ultrachat) and [OpenHermes 2.5](https://huggingface.co/datasets/teknium/OpenHermes-2.5).
4 |
5 | We only use the "Questions about the world" [subset](https://huggingface.co/datasets/HuggingFaceTB/ultrachat_questions_about_world) of UltraChat. For OpenHermes, we keep only English (or unlabeled) instructions and drop the sources (`camelai`, `glaive-code-assist`) and categories (`rp`, `gtkm`, `coding`, `wordgame`, `riddle`) that wouldn't be suitable for stories; the filtered dataset is available [here](https://huggingface.co/datasets/HuggingFaceTB/openhermes_filtered).
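The filtering in `filter_openhermes.py` keeps English (or unlabeled) samples, drops a few sources and categories, and keeps only the first human turn plus the assistant reply as the story seed. A minimal, dependency-free sketch of that logic on plain dictionaries; the helper names and toy records are hypothetical, and unlike the script (which asserts that a `gpt` turn follows) this sketch simply skips malformed pairs.

```python
DROP_SOURCES = {"camelai", "glaive-code-assist"}
DROP_CATEGORIES = {"rp", "gtkm", "coding", "wordgame", "riddle"}


def keep_sample(x: dict) -> bool:
    """Keep English (or unlabeled) samples whose source and category are not excluded."""
    if x.get("language") not in (None, "English"):
        return False
    if x.get("category") and x["category"].lower() in DROP_CATEGORIES:
        return False
    if x.get("source") and x["source"].lower() in DROP_SOURCES:
        return False
    return True


def first_exchange(conversations: list) -> str:
    """Return the first human turn followed by its assistant reply, or '' if there is none."""
    for turn, reply in zip(conversations, conversations[1:]):
        if turn["from"] == "human" and reply["from"] == "gpt":
            return turn["value"] + "\n" + reply["value"]
    return ""


samples = [
    {"language": "English", "category": "Coding", "source": None,
     "conversations": [{"from": "human", "value": "Fix my bug"}, {"from": "gpt", "value": "Sure"}]},
    {"language": None, "category": "general", "source": "airoboros2.2",
     "conversations": [{"from": "system", "value": "Be helpful"},
                       {"from": "human", "value": "Why is the sky blue?"},
                       {"from": "gpt", "value": "Rayleigh scattering."}]},
]
seeds = [first_exchange(s["conversations"]) for s in samples if keep_sample(s)]
print(seeds)  # ['Why is the sky blue?\nRayleigh scattering.']
```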
6 |
7 | To run the filtering:
8 | ```bash
9 | python filter_openhermes.py
10 | ```
11 |
12 | To build the prompts:
13 | ```bash
14 | python build_openhermes_stories_prompts.py --run_all_styles
15 | python build_ultrachat_stories_prompts.py --run_all_styles
16 | ```
17 |
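Each builder script truncates the seed text to its first `EXTRACT_SIZE` (1,000) characters, substitutes it into one template per style, and writes the result to a `prompt_<style>` column. A minimal sketch of that substitution; the placeholder name `<INSERT_EXTRACT>` and the shortened template are assumptions for illustration, not the exact strings used by the scripts.

```python
EXTRACT_SIZE = 1000

# Shortened stand-in for one entry of a STYLES dict; <INSERT_EXTRACT> is an assumed placeholder token.
STYLES = {
    "problem_solving_story": (
        "Write a story that explores a situation slightly related to this text snippet:\n"
        '"<INSERT_EXTRACT>"\n\n'
        'Do not start with classic sentences like "Once upon a time", be creative.'
    ),
}


def build_prompt(snippet: str, style: str) -> dict:
    """Truncate the seed snippet to EXTRACT_SIZE characters and insert it into the template."""
    snippet = snippet.strip()[:EXTRACT_SIZE]
    return {f"prompt_{style}": STYLES[style].replace("<INSERT_EXTRACT>", snippet)}


row = build_prompt("  How do bees communicate?  " + "x" * 2000, "problem_solving_story")
prompt = row["prompt_problem_solving_story"]
print(prompt[:100])
```

Python slicing already handles strings shorter than the limit, so `snippet[:EXTRACT_SIZE]` is enough on its own.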
--------------------------------------------------------------------------------
/prompts/stories/build_openhermes_stories_prompts.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from datasets import load_dataset
3 |
4 |
5 | STYLES = {"young_children_story":
6 | """Write an educational story (3-5 paragraphs) targeted at young children using simple words. The story should be inspired from this text snippet:
7 | “<INSERT_EXTRACT>”
8 |
9 | The story doesn’t have to be addressing everything in the snippet, it is there just for inspiration.
10 | The story should have the following features:
11 | - Science integration: embed basic science concepts within the story, explaining them through the characters' adventures and discoveries. For example, if the story includes a scene where characters are looking at the sky, you could have them wonder why it's blue and explain the physics behind in grade school level.
12 | - Dialogue: include at least one dialogue and insightful conversation.
13 | - Unexpected twist: conclude with a twist that doesn't resolve as hoped, but leaves a clear lesson about life and science.
14 | Do not start with classic sentences like "Once upon a time", be creative.""",
15 |
16 | "problem_solving_story":
17 | """Write a story that explores a situation slightly related to this text snippet:
18 | “<INSERT_EXTRACT>”
19 |
20 | The story should unfold through the characters interactions, decisions, and the consequences of their actions. Aim to weave in common sense lessons and social cues. The narrative should cater to a diverse age group, including at least one dialogue and presenting both positive and negative outcomes.
21 | Do not start with classic sentences like "Once upon a time", be creative.""",
22 |
23 | "reddit_post":
24 | """Write a real-life story shared by someone in a reddit forum. The story should be somehow related to this text snippet:
25 | “<INSERT_EXTRACT>”
26 |
27 | The story should include:
28 | - Niche interests or humor: dive into specific hobbies, interests, or humorous situations
29 | - An unexpected plot twist or engaging conflict: introduce a relatable yet challenging situation or dilemma that the author faced.
30 | - Reflection and insight: end with a resolution that offers a new understanding, a sense of community, or a personal revelation, much like the conclusions drawn in forum discussions.
31 | Start the story right away. Do not start with sentences like "Once upon a time" as this is a reddit post and not a novel, you should also avoid starting with classic sentences like "A few years ago" or "A few years back", be creative."""}
32 |
33 | EXTRACT_SIZE = 1000
34 |
35 |
36 | def get_args():
37 |     parser = argparse.ArgumentParser()
38 |     parser.add_argument("--repo_id", type=str, default="HuggingFaceTB/prompts_stories_openhermes")
39 |     parser.add_argument("--generation_style", type=str, default="problem_solving_story")
40 |     parser.add_argument("--run_all_styles", action="store_true")
41 |     return parser.parse_args()
42 |
43 |
44 | def build_prompt(x, style="problem_solving_story"):
45 |     """Build the prompt based on the generation type"""
46 |     snippet = x["prompt"].strip()
47 |     snippet = snippet[:EXTRACT_SIZE]
48 |     prompt = STYLES[style].replace("<INSERT_EXTRACT>", snippet)
49 |     return {f"prompt_{style}": prompt}
50 |
51 |
52 | if __name__ == "__main__":
53 |     args = get_args()
54 |
55 |     print("Loading OpenHermes data...")
56 |     ds = load_dataset("HuggingFaceTB/openhermes_filtered", split="train", num_proc=36)
57 |     if args.run_all_styles:
58 |         suffix = ""
59 |         for style in STYLES.keys():
60 |             print(f"📖 Building {style} prompts...")
61 |             ds = ds.map(build_prompt, num_proc=48, fn_kwargs={"style": style})
62 |     else:
63 |         suffix = f"_{args.generation_style}"
64 |         print(f"📖 Building {args.generation_style} prompts...")
65 |         ds = ds.map(build_prompt, num_proc=48, fn_kwargs={"style": args.generation_style})
66 |     print(ds)
67 |     # show one example for each prompt column that was built
68 |     for col in [c for c in ds.column_names if c.startswith("prompt_")]:
69 |         print(ds[0][col])
70 |         print("-" * 100)
71 |     ds.push_to_hub(f"{args.repo_id}{suffix}", private=True)
72 |     print(f"✅ Data available at {args.repo_id}{suffix}")
73 |
--------------------------------------------------------------------------------
/prompts/stories/build_ultrachat_stories_prompts.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import random
3 | from datasets import load_dataset
4 |
5 |
6 | STYLES = {"young_children_story":
7 | """Write an educational story (3-5 paragraphs) targeted at young children using simple words. The story should be inspired from this text snippet:
8 | “<INSERT_EXTRACT>”
9 |
10 | The story doesn’t have to be addressing everything in the snippet, it is there just for inspiration.
11 | The story should have the following features:
12 | - Science integration: embed basic science concepts within the story, explaining them through the characters' adventures and discoveries. For example, if the story includes a scene where characters are looking at the sky, you could have them wonder why it's blue and explain the physics behind in grade-school level.
13 | - Dialogue: include at least one dialogue and insightful conversation.
14 | - Unexpected twist: conclude with a twist that doesn't resolve as hoped, but leaves a clear lesson about life and science.
15 | Do not start with classic sentences like "Once upon a time", be creative.""",
16 |
17 | "morality_story":
18 | """Write a compelling story related to the following text snippet:
19 | “<INSERT_EXTRACT>”
20 |
21 | The story doesn’t need to mention everything in the snippet, use it just for inspiration and be creative!
22 | The story should incorporate the following elements:
23 | - Dialogue: the story must feature at least one meaningful dialogue that reveals character depth, advances the plot, or unravels a crucial piece of the mystery
24 | - Interesting themes: explore themes resonant with a mature audience, such as moral ambiguity, existential queries, personal transformation, or the consequences of past actions.
25 | Do not start with classic sentences like "Once upon a time", "The sun hung low in the sky" or "In the dimly lit", be creative.""",
26 |
27 | "problem_solving_story":
28 | """Write a story that explores a situation slightly related to this text snippet:
29 | “<INSERT_EXTRACT>”
30 |
31 | The story should unfold through the characters interactions, decisions, and the consequences of their actions. Aim to weave in common sense lessons and social cues, emphasizing the importance of problem-solving. The narrative should cater to a diverse age group, including at least one dialogue and presenting both positive and negative outcomes.
32 | Do not start with classic sentences like "Once upon a time", be creative.""",
33 |
34 | "forums_story":
35 | """Write a story in the style of real-life situations that people share in forums. The story should be somehow related to this text snippet:
36 | “<INSERT_EXTRACT>”
37 |
38 | The story needs to include a compelling and unexpected plot twist. Your narrative should resonate with the authenticity and personal touch found in forum discussions. Include relatable events and emotional depth.
39 | Do not start with classic sentences like "Once upon a time", "A few years back" or "A few months ago", be creative.""",
40 |
41 | "reddit_post":
42 | """Write a real-life story shared by someone in a reddit forum. The story should be somehow related to this text snippet:
43 | “<INSERT_EXTRACT>”
44 |
45 | The story should include:
46 | - Niche interests or humor: dive into specific hobbies, interests, or humorous situations
47 | - An unexpected plot twist or engaging conflict: introduce a relatable yet challenging situation or dilemma that the author faced.
48 | - Reflection and insight: end with a resolution that offers a new understanding, a sense of community, or a personal revelation, much like the conclusions drawn in forum discussions.
49 | Start the story right away. Do not start with sentences like "Once upon a time" as this is a reddit post and not a novel, you should also avoid starting with classic sentences like "A few years ago" or "A few years back", be creative."""}
50 |
51 | EXTRACT_SIZE = 1000
52 |
53 |
54 | def get_args():
55 |     parser = argparse.ArgumentParser()
56 |     parser.add_argument("--repo_id", type=str, default="HuggingFaceTB/ultrachat_stories_prompts")
57 |     parser.add_argument("--generation_style", type=str, default="problem_solving_story")
58 |     parser.add_argument("--run_all_styles", action="store_true")
59 |     return parser.parse_args()
60 |
61 |
62 | def build_prompt(x, style="forums_story"):
63 |     """Build the prompt based on the generation type"""
64 |     snippet = x["first_turn"].strip()
65 |     snippet = snippet[:EXTRACT_SIZE]
66 |     prompt = STYLES[style].replace("<INSERT_EXTRACT>", snippet)
67 |     return {f"prompt_{style}": prompt}
68 |
69 |
70 | if __name__ == "__main__":
71 |     args = get_args()
72 |
73 |     print("Loading UltraChat data...")
74 |     ds = load_dataset("HuggingFaceTB/ultrachat_questions_about_world", split="train")
75 |     if args.run_all_styles:
76 |         suffix = ""
77 |         for style in STYLES.keys():
78 |             print(f"📖 Building {style} prompts...")
79 |             ds = ds.map(build_prompt, num_proc=48, fn_kwargs={"style": style})
80 |     else:
81 |         suffix = f"_{args.generation_style}"
82 |         print(f"📖 Building {args.generation_style} prompts...")
83 |         ds = ds.map(build_prompt, num_proc=48, fn_kwargs={"style": args.generation_style})
84 |     print(ds)
85 |     # show one example for each prompt column that was built
86 |     for col in [c for c in ds.column_names if c.startswith("prompt_")]:
87 |         print(ds[0][col])
88 |         print("-" * 100)
89 |     ds.push_to_hub(f"{args.repo_id}{suffix}", private=True)
90 |     print(f"✅ Data available at {args.repo_id}{suffix}!")
91 |
--------------------------------------------------------------------------------
/prompts/stories/filter_openhermes.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset
2 |
3 | ds = load_dataset("teknium/OpenHermes-2.5", split="train", num_proc=36)
4 | drop_sources = ["camelai", "glaive-code-assist"]
5 | drop_categories = ["rp", "gtkm", "coding", "wordgame", "riddle"]
6 |
7 | def filter_files(x):
8 |     if x["category"] and x["category"].lower() in drop_categories:
9 |         return False
10 |     if x["source"] and x["source"].lower() in drop_sources:
11 |         return False
12 |     return True
13 |
14 | def get_prompt(x):
15 |     """Concatenate the first human turn with the assistant reply that follows it."""
16 |     conversations = x["conversations"]
17 |     prompt = ""
18 |     for i in range(len(conversations) - 1):
19 |         if conversations[i]["from"] == "human":
20 |             assert conversations[i + 1]["from"] == "gpt", f"role is {conversations[i + 1]['from']} not 'gpt'!"
21 |             prompt = conversations[i]["value"] + "\n" + conversations[i + 1]["value"]
22 |             break
23 |     return {"prompt": prompt}
24 |
25 | print("Start...")
26 | print(ds)
27 | print("Language filter...")
28 | ds = ds.filter(lambda x: x["language"] in [None, "English"], num_proc=12)
29 | print(ds)
30 | print("Category & source filter...")
31 | ds_f = ds.filter(filter_files, num_proc=36)
32 | print(ds_f)
33 | ds_f = ds_f.map(get_prompt, num_proc=36)
34 | ds_f = ds_f.remove_columns([col for col in ds_f.column_names if col not in ["prompt", "source", "category"]])
35 | ds_f.push_to_hub("HuggingFaceTB/openhermes_filtered")
--------------------------------------------------------------------------------
/prompts/web_samples/README.md:
--------------------------------------------------------------------------------
1 | # Synthetic data from Web samples
2 |
3 | We built several types of synthetic content from seed web samples: textbooks (in narrative or academic tone), blogposts and WikiHow articles.
4 |
5 | To select the web samples, we first clustered 100k samples from a web dataset such as [RefinedWeb](https://huggingface.co/datasets/tiiuae/falcon-refinedweb). This resulted in 145 clusters. You can inspect the clusters in this [demo](https://huggingface.co/spaces/HuggingFaceTB/inspect_clusters_free_topics). We then inferred the clusters of 15M other web samples and used each sample, together with its cluster topic, in the prompts.
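The README does not say how the 15M additional samples were assigned to the existing clusters. One simple approach, shown here as an assumption rather than the text-clustering implementation, is a k-nearest-neighbour vote: embed the new sample, find the most similar already-clustered samples by cosine similarity, and take their majority cluster. The helper names and the toy 2-D "embeddings" are hypothetical.

```python
import math
from collections import Counter


def cosine(a, b):
    """Cosine similarity between two equal-length, non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def infer_cluster(embedding, labeled, k=3):
    """Return the majority cluster id among the k labeled points most similar to `embedding`."""
    neighbours = sorted(labeled, key=lambda item: cosine(embedding, item[0]), reverse=True)[:k]
    return Counter(label for _, label in neighbours).most_common(1)[0][0]


# Toy embeddings of already-clustered samples: (vector, cluster_id)
labeled = [
    ([1.0, 0.1], 0), ([0.9, 0.2], 0), ([1.0, 0.0], 0),
    ([0.1, 1.0], 1), ([0.0, 0.9], 1), ([0.2, 1.0], 1),
]
print(infer_cluster([0.8, 0.3], labeled))  # 0
print(infer_cluster([0.3, 0.7], labeled))  # 1
```

At 15M samples an exact sort over all labeled points would be too slow; an approximate nearest-neighbour index would be the practical choice.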
6 |
7 | The clustering code can be found in the [text-clustering](https://github.com/huggingface/text-clustering?tab=readme-ov-file#cosmopedia-experiments-clustering-of-web-samples-and-topic-labeling) repository. We then excluded 38 clusters that we deemed not educational, using the educational scores generated during clustering together with a manual inspection of each cluster: we noticed that a medium score did not always reflect the actual quality of a topic.
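`filter_and_classify_clusters.py` reads each cluster's topic and educational score from its summary string (`<topic>. Educational score: <score>`) and drops the manually excluded topics. A minimal sketch that also adds a score threshold; the `min_score` cut-off of 2, the integer scores and the helper names are assumptions, not values or functions from the script.

```python
def parse_summary(summary: str) -> tuple:
    """Split a cluster summary such as 'Music. Educational score: 3' into (topic, score)."""
    topic = summary.split(". Educational")[0].strip()
    score = int(summary.split(" Educational score: ")[1].strip())
    return topic, score


def keep_cluster(summary: str, excluded: set, min_score: int = 2) -> bool:
    """Keep a cluster unless it was excluded by hand or scores below the (hypothetical) threshold."""
    topic, score = parse_summary(summary)
    return topic not in excluded and score >= min_score


excluded = {"Online Gambling and Casinos", "Cookies and Privacy Policies"}
summaries = [
    "Astronomy and Astrophysics. Educational score: 5",
    "Cookies and Privacy Policies. Educational score: 1",
    "Online Gambling and Casinos. Educational score: 3",
    "Hair Care. Educational score: 1",
]
print([s for s in summaries if keep_cluster(s, excluded)])
```

The manual exclusion set is what catches medium-scoring clusters like the gambling one, which a threshold alone would keep.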
8 |
9 | We also tried to infer which generation style would best suit each topic: e.g. Mathematics is suitable for textbooks, Beauty & Lifestyle might be suitable for blog posts, and DIY for WikiHow articles. However, we didn't strictly follow this classification, as the prompted LLM seemed to address each topic from different and interesting angles when used with different styles.
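In `filter_and_classify_clusters.py`, this classification is a `{style: [topics]}` dictionary that is inverted into a topic lookup, and topics missing from it fall back to blog posts. A trimmed-down sketch using a few of the script's topic names; the constant names and the helper `generation_style` are hypothetical.

```python
CLASSIFICATIONS = {
    "textbook": ["Computer Science", "Astronomy and Astrophysics"],
    "blogpost": ["Travel", "Cooking and Baking"],
    "wikihow": ["Home Improvement and Maintenance", "Tennis"],
}

# Invert {style: [topics]} into {topic: style} for constant-time lookups.
TOPIC_TO_STYLE = {topic: style for style, topics in CLASSIFICATIONS.items() for topic in topics}


def generation_style(topic: str, default: str = "blogpost") -> str:
    """Return the preferred generation style for a topic, falling back to `default`."""
    return TOPIC_TO_STYLE.get(topic, default)


for topic in ["Computer Science", "Tennis", "Some Unlisted Topic"]:
    print(topic, "->", generation_style(topic))
```

Using `dict.get` with a default avoids the bare `try`/`except` pattern for the fallback case.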
10 |
11 | Script for classification and filtering (this depends on the clusters found in your dataset):
12 | ```bash
13 | python filter_and_classify_clusters.py
14 | ```
15 |
16 | Script for building the web prompts:
17 | ```bash
18 | python build_web_prompts.py
19 | ```
20 |
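`build_web_prompts.py` truncates each web extract to `EXTRACT_SIZE` characters and, for roughly half of the samples, appends the cluster topic (`, within the context of "<topic>"`) to the instruction. A minimal sketch of that step; the placeholder names `<ADD_TOPIC>` and `<INSERT_EXTRACT>`, the shortened template and the seeded `random.Random` (the script uses the unseeded module-level `random`) are assumptions for illustration.

```python
import random

EXTRACT_SIZE = 1000
# Shortened, hypothetical template with assumed placeholder tokens.
TEMPLATE = (
    'Here is an extract from a webpage: "<INSERT_EXTRACT>".\n\n'
    "Write an informative and insightful blog post that expands upon the extract above<ADD_TOPIC>."
)


def build_web_prompt(extract: str, topic: str, rng: random.Random) -> str:
    """Fill the template, mentioning the cluster topic with probability 0.5."""
    add_topic = f', within the context of "{topic}"' if rng.random() < 0.5 else ""
    return TEMPLATE.replace("<ADD_TOPIC>", add_topic).replace("<INSERT_EXTRACT>", extract[:EXTRACT_SIZE])


rng = random.Random(0)  # seeded so the topic/no-topic choice is reproducible
prompts = [build_web_prompt("Bees use a waggle dance...", "Entomology and Apiculture", rng) for _ in range(1000)]
with_topic = sum('within the context of "Entomology and Apiculture"' in p for p in prompts)
print(with_topic)
```

The printed count should be roughly half of the 1,000 prompts; mixing topic-aware and topic-free prompts adds variety to the generations.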
--------------------------------------------------------------------------------
/prompts/web_samples/build_web_prompts.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import random
3 | from datasets import load_dataset
4 |
5 |
6 | STYLES = {"wikihow":
7 | """Here is an extract from a webpage: "<INSERT_EXTRACT>".
8 |
9 | Write a long and very detailed tutorial that could be part of WikiHow whose title is related to the extract above<ADD_TOPIC>. Include in depth explanations for each step and how it helps achieve the desired outcome, including key tips and guidelines.
10 | Ensure clarity and practicality, allowing readers to easily follow and apply the instructions. Do not use images.""",
11 |
12 | "textbook_narrative":
13 | """Here is an extract from a webpage: "<INSERT_EXTRACT>".
14 |
15 | Write an extensive and detailed course unit suitable for a textbook, related to the given extract<ADD_TOPIC>. Do not just list concepts, but develop each one in detail before moving to the next, as we prioritize depth of understanding and comprehensive exploration of the subject matter over breadth. Focus on:
16 |
17 | - Rigor: Ensure in-depth coverage of the concepts.
18 | - Engagement: Use a narrative style akin to Michael Lewis, making it captivating and thought-provoking.
19 | - Relevance: Connect the topic with current trends, real-life examples, or recent studies. Do not use images.
20 | Do not include a title or an introduction, simply write the content without headlines and introductory phrases. Do not use images.""",
21 |
22 | "textbook_academic":
23 | """Here is an extract from a webpage: "<INSERT_EXTRACT>".
24 |
25 | Write an extensive and detailed course unit suitable for a textbook targeted at college students, related to the given extract<ADD_TOPIC>. Do not just list concepts, but develop each one in detail before moving to the next, as we prioritize depth of understanding and comprehensive exploration of the subject matter over breadth. Focus on:
26 |
27 | - Rigor: Ensure in-depth coverage of the concepts/sections.
28 | - Engagement: Write with an academic, professional and engaging tone that captivates interest.
29 | - Application: Incorporate specific, practical examples, such as proofs in calculus or critical dates and figures in history.
30 | Do not include a title or an introduction, simply write the content without headlines and introductory phrases. Do not use images.""",
31 |
32 | "blogpost":
33 | """Here is an extract from a webpage: "<INSERT_EXTRACT>".
34 |
35 | Write an informative and insightful blog post that expands upon the extract above<ADD_TOPIC>. Your post should delve into the nuances of the topic, offering fresh perspectives and deeper analysis. Aim to:
36 |
37 | - Inform: Provide valuable, well-researched information that educates the reader.
38 | - Engage: Write in a conversational tone that connects with the audience, making complex ideas accessible.
39 | - Illustrate: Use examples, anecdotes, or personal experiences to bring the topic to life.
40 | Do not give a title and do not start with sentences like "Have you ever..." or "Hello dear readers..", simply write the content without these introductory phrases."""
41 | }
42 |
43 | EXTRACT_SIZE = 1000
44 |
45 |
46 | def get_args():
47 |     parser = argparse.ArgumentParser()
48 |     parser.add_argument("--repo_id", type=str, default="HuggingFaceTB/web_prompts")
49 |     parser.add_argument("--data_type", type=str, default="textbook")
50 |     parser.add_argument("--generation_style", type=str, default="textbook_academic")
51 |     parser.add_argument("--run_all_styles", action="store_true")
52 |     return parser.parse_args()
53 |
54 |
55 | def build_prompt(x, style="textbook_academic"):
56 |     """Build the prompt based on the generation type"""
57 |     # web extract and its cluster topic (mentioned in roughly half of the prompts)
58 |     web_sample = x["examples"]
59 |     web_sample = web_sample[:EXTRACT_SIZE]
60 |     topic = x["category"]
61 |     add_topic = f', within the context of "{topic}"' if random.random() < 0.5 else ""
62 |     # requested generation style
63 |     prompt = STYLES[style].replace("<ADD_TOPIC>", add_topic).replace("<INSERT_EXTRACT>", web_sample)
64 |     return {f"prompt_{style}": prompt}
65 |
66 |
67 | if __name__ == "__main__":
68 |     # load data=data_type and generate content in style=style
69 |     args = get_args()
70 |
71 |     print(f"Loading data fw2_as_{args.data_type}...")
72 |     ds = load_dataset(f"HuggingFaceTB/fw2_as_{args.data_type}", split="train", num_proc=48)
73 |     if args.run_all_styles:
74 |         suffix = ""
75 |         for style in STYLES.keys():
76 |             print(f"📖 Building {style} prompts...")
77 |             ds = ds.map(build_prompt, num_proc=48, fn_kwargs={"style": style})
78 |     else:
79 |         suffix = f"_{args.generation_style}"
80 |         print(f"📖 Building {args.generation_style} prompts...")
81 |         ds = ds.map(build_prompt, num_proc=48, fn_kwargs={"style": args.generation_style})
82 |     print(ds)
83 |     # show a few examples for each prompt column that was built
84 |     for col in [c for c in ds.column_names if c.startswith("prompt_")]:
85 |         for i in range(min(3, len(ds))):
86 |             print(ds[i][col])
87 |             print("-" * 100)
88 |     ds.push_to_hub(f"{args.repo_id}_{args.data_type}{suffix}", private=True)
89 |     print(f"✅ Data available at {args.repo_id}_{args.data_type}{suffix}!")
--------------------------------------------------------------------------------
/prompts/web_samples/filter_and_classify_clusters.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import pandas as pd
3 | from collections import Counter
4 | from datasets import load_dataset, Dataset
5 |
6 |
7 | # re-classifying each topic into the most adequate class: textbook, blogpost, or wikihow
8 | classifications = {
9 | "textbook": [
10 | 'International Relations and Politics',
11 | 'Product Marketing and Design',
12 | 'Digital Imaging and Photography',
13 | 'Computer Science',
14 | 'Economics and Finance',
15 | 'Pharmaceutical manufacturing and technology',
16 | 'Real Estate & Investment',
17 | 'Business and Entrepreneurship',
18 | 'Astronomy and Astrophysics',
19 | 'Finance and Investment',
20 | 'Computer Antivirus Software and Security',
21 | 'Healthcare and Operations Management',
22 | 'Technology and Computer Science',
23 | 'Computer Programming and Web Development',
24 | 'Taxation and Finance',
25 | 'Human Resources / Organizational Management',
26 | 'Computer Hardware and Graphics Cards',
27 | 'Marketing and Business Strategies',
28 | 'Digital Marketing and Business',
29 | 'Audio Equipment and Home Theater Systems',
30 | 'HIV Treatment and Care',
31 | 'Legal Studies and Public Policy',
32 | 'Legal Studies / Law',
33 | 'Jewelry Design and Manufacturing',
34 | 'Biochemistry and Molecular Biology',
35 | 'Insurance',
36 | 'Energy and Environmental Policy',
37 | 'Data Privacy and Protection',
38 | 'International Relations and Conflict',
39 | 'Entomology and Apiculture',
40 | 'Loans and Mortgages',
41 | 'Public Transit and Transportation',
42 | 'International Relations and Current Events',
43 | 'Politics and Government',
44 | 'Political Science',
45 | 'Genetics and Mental Health',
46 | 'Public Administration and Policy',
47 | 'Technology and Consumer Electronics',
48 | 'Computer Security & Privacy',
49 | 'Online Platforms & Web Technologies',
50 | 'Human Resources and Education',
51 | 'Sports and Education',
52 | 'Lighting Design and Technology',
53 | 'Medicine',
54 | 'Cryptocurrency and Blockchain Technology',
55 | 'Mental Health Counseling',
56 | 'Geography and Weather',
57 | 'Leadership and Education',
58 | 'Infant Feeding and Child Development',
59 | 'Molecular Biology and Genetics',
60 | 'Energy and Natural Resources',
61 | 'Mental Health and Therapy',
62 | 'Business and Management',
63 | 'Legal Services and Issues',
64 | 'Christian Theology and Spirituality',
65 | 'Personal Finance and Investments',
66 | 'Psychology',
67 | 'Healthcare & Medical Services',
68 | 'Watchmaking and Horology',
69 | 'Online Chat Platforms and Data Privacy',
70 | 'Waste Management and Recycling'
71 | ],
72 | "blogpost": [
73 | 'Health and Lifestyle',
74 | 'Physical Fitness and Health',
75 | 'Music',
76 | 'Fiction and Fantasy Writing',
77 | 'Literature and Creative Writing',
78 | 'Arts and Crafts',
79 | 'Education',
80 | 'Education and Youth Development',
81 | 'Writing and Storytelling',
82 | 'Hair Care and Styling',
83 | 'Automotive Parts and Accessories',
84 | 'Astrology',
85 | 'Culinary Arts and Beverages',
86 | 'Events and Community Happenings',
87 | 'Cooking and Baking',
88 | 'Online Dating & Relationships',
89 | 'Career Development and Job Opportunities',
90 | 'Cosmetic Surgery and Body Modifications',
91 | 'Skincare and Beauty Products',
92 | 'Addiction and Mental Illness',
93 | 'Visual Arts and Art Appreciation',
94 | 'Pets and Pet Care',
95 | 'Personal Development and Empowerment',
96 | 'Video Games',
97 | 'Hair Care',
98 | 'Nutrition and Health',
99 | 'Fashion & Apparel',
100 | 'Travel',
101 | 'Performing Arts',
102 | 'Cannabis and CBD Products',
103 | 'Wine & Winemaking',
104 | 'Cooking and Recipes'
105 | ],
106 | "wikihow": [
107 | 'Dentistry',
108 | 'Football/Soccer',
109 | 'Cleaning and Maintenance',
110 | 'American Football',
111 | 'Baseball',
112 | 'Recreational Fishing',
113 | 'Public Safety and Emergency Response',
114 | 'Ice Hockey',
115 | 'Professional Basketball/NBA',
116 | 'Home Improvement and Maintenance',
117 | 'Tennis',
118 | 'Professional Wrestling and Sports Entertainment',
119 | 'Cricket',
120 | 'Gun Control and Violence',
121 | 'Fire Incidents',
122 | 'Electric Vehicles and Battery Technology',
123 | 'Christianity and Theology'
124 | ]
125 | }
126 | remove = ['Online Gambling and Casinos',
127 | 'Reality Television and Celebrity Gossip',
128 | 'Explicit Adult Content',
129 | 'Fashion & Accessories',
130 | 'Solicitation and Prostitution',
131 | 'Adult Entertainment and Webcam Sites',
132 | 'Clothing & Fashion',
133 | 'Firearms and Accessories',
134 | 'Marketing and Sales Promotions',
135 | 'Cookies and Privacy Policies',
136 | 'Fragrances and Personal Care Products',
137 | 'Obituaries and Personal Profiles',
138 | 'Motor Sports',
139 | 'Death',
140 | 'Male Enhancement Products and Supplements',
141 | 'Political Science',
142 | 'Heating',
143 | 'Anabolic Steroids and',
144 | 'Combat Sports',
145 | 'Events and Community Happenings',
146 | 'Footwear and Fashion',
147 | 'Furniture Design and Sales',
148 | 'Entertainment & Media',
149 | 'Real Estate & Property Management',
150 | 'Weight Loss and Body Contouring',
151 | 'Moving Services and Logistics',
152 | 'Transportation and City Planning',
153 | 'Home Decoration and Furniture',
154 | 'Events and Conferences',
155 | 'Cosmetics and Beauty Products',
156 | 'Mattresses and Sleep',
157 | 'Soap & Skincare Products',
158 | 'E-commerce and Online Shopping',
159 | 'Optical Equipment and Accessories',
160 | ]
161 |
162 | new_dict = {topic: generation_type for generation_type, topics in classifications.items() for topic in topics}
163 |
164 |
165 | def get_args():
166 |     parser = argparse.ArgumentParser()
167 |     parser.add_argument("--clusters_dataset", type=str, default="HuggingFaceTB/web_clusters")
168 |     parser.add_argument("--user", type=str, default="HuggingFaceTB")
169 |     return parser.parse_args()
170 |
171 |
172 | def extract_category(example):
173 |     summary = example["summary"]
174 |     category = summary.split(". Educational")[0].strip()
175 |     score = summary.split(" Educational score: ")[1].strip()
176 |     return {"category": category, "educational_score": score}
177 |
178 |
179 | def add_generation_type(x):
180 |     topic = x["category"]
181 |     generation_type = new_dict.get(topic)
182 |     if generation_type is None:
183 |         # topics missing from `classifications` default to blog posts
184 |         print(f"{topic} not in keep list, defaulting to blogpost")
185 |         generation_type = "blogpost"
186 |     return {"generation_type": generation_type}
187 |
188 | args = get_args()
189 | print("Loading web samples (after the clustering)...")
190 | ds_1 = load_dataset(args.clusters_dataset, split="train")
191 | print(ds_1)
192 |
193 | print("Converting to dataframe...")
194 | full_df = ds_1.to_pandas().explode("examples")
195 |
196 | full_df.sort_values(by=['cluster_id'], inplace=True)
197 |
198 | print("Full df info...")
199 | print(full_df.head())
200 | print(full_df.info())
201 |
202 | print("Convert to HF dataset...")
203 | final_ds = Dataset.from_pandas(full_df)
204 | final_ds = final_ds.map(extract_category)
205 | print("HF dataset:")
206 | print(final_ds)
207 |
208 | print("Filter out bad topics...")
209 | ds_keep = final_ds.filter(lambda x: x["category"] not in remove, num_proc=64)
210 | print(f"Size after dropping low quality clusters: {len(ds_keep)}={len(ds_keep)*100/len(final_ds):.2f}% of the original dataset")
211 |
212 | print("Add generation type...")
213 | ds_keep = ds_keep.map(add_generation_type, num_proc=24)
214 | print(Counter(ds_keep["generation_type"]))
215 |
216 | print("Retrieve textbooks...")
217 | textbooks = ds_keep.filter(lambda x: x["generation_type"] == "textbook")
218 | print(textbooks)
219 | print("Retrieve wikihow...")
220 | wikihow = ds_keep.filter(lambda x: x["generation_type"] == "wikihow")
221 | print(wikihow)
222 | print("Retrieve blogposts...")
223 | blogpost = ds_keep.filter(lambda x: x["generation_type"] == "blogpost")
224 | print(blogpost)
225 |
226 | print("Pushing to hub ...")
227 | textbooks.push_to_hub(f"{args.user}/fw2_as_textbook", private=True)
228 | wikihow.push_to_hub(f"{args.user}/fw2_as_wikihow", private=True)
229 | blogpost.push_to_hub(f"{args.user}/fw2_as_blogpost", private=True)
230 | print("Done!")
--------------------------------------------------------------------------------
/prompts/wikihow/README.md:
--------------------------------------------------------------------------------
1 | # Synthetic WikiHow articles from scraped WikiHow titles
2 |
3 | You can find the list of WikiHow titles we scraped in `wikihowcom-20231012-titles.txt`.
4 | An updated list of WikiHow titles can be extracted using https://github.com/mediawiki-client-tools/mediawiki-dump-generator
5 |
6 | [TODO] Add code for the updated Cosmopedia prompts
--------------------------------------------------------------------------------