├── zyda ├── __init__.py ├── utils │ ├── __init__.py │ ├── text.py │ ├── common.py │ └── filtering.py ├── lsh_minhash │ ├── __init__.py │ ├── compute_optimal_params.py │ ├── compute_minhash.py │ └── build_lsh_index.py ├── preprocessing_and_filtering │ ├── __init__.py │ ├── profanity_word_list.json │ ├── sexual_word_list.json │ ├── cursed_substrings.json │ ├── zh_pornsignals.json │ └── preprocess_and_filter.py └── connected_components │ ├── generate_indices_to_remove.py │ └── generate_connected_components.py ├── .gitignore ├── zyda_reproduction ├── 6_generating_final_dataset │ ├── run_generate_lsh_0.4_jsonls.sh │ ├── convert_jsonls_to_parquet.py │ ├── run_convert_to_parquet_lsh_0.4.sh │ └── generate_final_jsonls.py ├── 3_minhashing │ ├── minhash_refinedweb.sh │ ├── minhash_slimpajama.sh │ ├── minhash_pile_c4_peS2o_arxiv.sh │ └── minhash_starcoder.sh ├── 2_preprocessing_and_filtering │ ├── preprocess_slimpajama.sh │ ├── preprocess_refinedweb.sh │ ├── preprocess_pile_c4_peS2o_arxiv.sh │ └── preprocess_starcoder.sh ├── 1_downloading │ ├── download_refinedweb.py │ ├── process_repo_starcoder.py │ ├── download_arxiv_pile_peS2o_c4_refinedweb.py │ └── process_repo_slimpajama.py ├── 5_clustering │ └── run_cc_lsh_0.4_dupes.sh └── 4_lsh_indexing │ └── run_lsh_dupes_0.4_all.sh ├── setup.py ├── README.md └── LICENSE /zyda/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /zyda/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /zyda/lsh_minhash/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /zyda/preprocessing_and_filtering/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .DS_Store 3 | *.so 4 | build 5 | .coverage_* 6 | *.egg-info 7 | *~ 8 | logs 9 | .vscode 10 | dist/ 11 | *.nfs* 12 | -------------------------------------------------------------------------------- /zyda/preprocessing_and_filtering/profanity_word_list.json: -------------------------------------------------------------------------------- 1 | {"words": [" fuck ", " shit ", " damn ", " cunt ", " douchebag ", " crap ", " nigger ", " slut ", " turd ", " asshole ", " bugger "]} -------------------------------------------------------------------------------- /zyda/preprocessing_and_filtering/sexual_word_list.json: -------------------------------------------------------------------------------- 1 | {"words": [" penis ", " dick ", " cock ", " vagina ", " pussy ", " tits ", " boobs ", " cunt ", " anal ", " blowjob ", " dildo ", " jerk off ", " pussies ", " masturbating ", " jizz ", " masturbate "]} -------------------------------------------------------------------------------- /zyda_reproduction/6_generating_final_dataset/run_generate_lsh_0.4_jsonls.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python $REPO_BASE/zyda_reproduction/6_generating_final_dataset/generate_final_jsonls.py \ 4 | --input-indices $DATA_BASE/lsh_0.4/dupes/output/dupes.pickle \ 5 | 
--out-folder $DATA_BASE/zyda_0.4-final/jsonl \ 6 | --jsonl-partitions 48 7 | -------------------------------------------------------------------------------- /zyda_reproduction/3_minhashing/minhash_refinedweb.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python $REPO_BASE/zyda/lsh_minhash/compute_minhash.py \ 4 | --load-path $DATA_BASE/processed/refinedweb \ 5 | --save-path $DATA_BASE/minhash/refinedweb \ 6 | --num-proc $NUM_PROC \ 7 | --width 13 \ 8 | --num-perm 128 \ 9 | --key transformed_text 10 | -------------------------------------------------------------------------------- /zyda_reproduction/3_minhashing/minhash_slimpajama.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python $REPO_BASE/zyda/lsh_minhash/compute_minhash.py \ 4 | --load-path $DATA_BASE/processed/slimpajama \ 5 | --save-path $DATA_BASE/minhash/slimpajama \ 6 | --num-proc $NUM_PROC \ 7 | --width 13 \ 8 | --num-perm 128 \ 9 | --key transformed_text 10 | -------------------------------------------------------------------------------- /zyda_reproduction/2_preprocessing_and_filtering/preprocess_slimpajama.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python $REPO_BASE/zyda/preprocessing_and_filtering/preprocess_and_filter.py \ 4 | --hf-path $DATA_BASE/raw/slimpajama \ 5 | --load-from-disk \ 6 | --num-proc $NUM_PROC \ 7 | --key text \ 8 | --name slimpajama \ 9 | --save-path $DATA_BASE/processed/slimpajama 10 | -------------------------------------------------------------------------------- /zyda_reproduction/2_preprocessing_and_filtering/preprocess_refinedweb.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | NAME=refinedweb 4 | python $REPO_BASE/zyda/preprocessing_and_filtering/preprocess_and_filter.py \ 5 | --hf-path $DATA_BASE/raw/$NAME \ 6 | --load-from-disk \ 7 | --num-proc $NUM_PROC \ 8 | --key text \ 9 | --name $NAME \ 10 | --save-path $DATA_BASE/processed/$NAME 11 | -------------------------------------------------------------------------------- /zyda_reproduction/1_downloading/download_refinedweb.py: -------------------------------------------------------------------------------- 1 | import os 2 | import datasets 3 | 4 | DATA_BASE = os.environ.get("DATA_BASE") 5 | NUM_PROC = int(os.environ.get("NUM_PROC", 1)) 6 | 7 | data = datasets.load_dataset( 8 | "tiiuae/falcon-refinedweb", 9 | split="train", 10 | download_config=datasets.DownloadConfig( 11 | num_proc=8, 12 | resume_download=True, 13 | ), 14 | num_proc=NUM_PROC, 15 | ) 16 | print(data) 17 | data.save_to_disk(os.path.join(DATA_BASE, "raw/refinedweb")) 18 | -------------------------------------------------------------------------------- /zyda/preprocessing_and_filtering/cursed_substrings.json: -------------------------------------------------------------------------------- 1 | {"words": [" \u2116", "\ufffd\ufffd\ufffd", "\\|\\s*$", " nr\\.$", "aute irure dolor ", " sunt in culpa qui ", "orem ipsum ", " quis nostrud ", " adipisicing ", " dolore eu ", " cupidatat ", "autem vel eum", "wisi enim ad", " sex ", " porn ", "\u9ec4\u8272\u7535\u5f71", "mp3", "ownload", "Vol\\.", " Ep\\.", "Episode", " \u0433\\.\\s*$", " \u043a\u0433\\.\\s*$", " \u0448\u0442\\.", "Develop", "Facebook", " crusher ", " xxx ", " ... ... ... ... ... ... ... ... ...", " .... .... .... .... .... .... .... .... 
....", " [^ ] [^ ] [^ ] [^ ] [^ ] [^ ] [^ ] [^ ] [^ ]", ", ..,,? ..,,? ..,,? ..,,?"]} -------------------------------------------------------------------------------- /zyda_reproduction/5_clustering/run_cc_lsh_0.4_dupes.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | INPUT_DIR=$DATA_BASE/lsh_0.4/dupes 4 | CC_PICKLE=$DATA_BASE/lsh_0.4/dupes/output/cc.pickle 5 | DUPES_PICKLE=$DATA_BASE/lsh_0.4/dupes/output/dupes.pickle 6 | 7 | WORKERS=2 8 | 9 | python $REPO_BASE/zyda/connected_components/generate_connected_components.py \ 10 | --input-dir $INPUT_DIR \ 11 | --out-file $CC_PICKLE \ 12 | --workers $WORKERS 13 | 14 | python $REPO_BASE/zyda/connected_components/generate_indices_to_remove.py \ 15 | --input-file $CC_PICKLE \ 16 | --out-file $DUPES_PICKLE \ 17 | --ranking \ 18 | starcoder-languages \ 19 | starcoder-github-issues-filtered-structured \ 20 | starcoder-jupyter-structured-clean-dedup \ 21 | starcoder-git-commits-cleaned \ 22 | refinedweb \ 23 | peS2o \ 24 | arxiv \ 25 | c4-en \ 26 | pile-uncopyrighted \ 27 | slimpajama 28 | -------------------------------------------------------------------------------- /zyda_reproduction/3_minhashing/minhash_pile_c4_peS2o_arxiv.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python $REPO_BASE/zyda/lsh_minhash/compute_minhash.py \ 4 | --load-path $DATA_BASE/processed/pile-uncopyrighted \ 5 | --save-path $DATA_BASE/minhash/pile-uncopyrighted \ 6 | --num-proc $NUM_PROC \ 7 | --width 13 \ 8 | --num-perm 128 \ 9 | --key transformed_text 10 | 11 | python $REPO_BASE/zyda/lsh_minhash/compute_minhash.py \ 12 | --load-path $DATA_BASE/processed/c4-en \ 13 | --save-path $DATA_BASE/minhash/c4-en \ 14 | --num-proc $NUM_PROC \ 15 | --width 13 \ 16 | --num-perm 128 \ 17 | --key transformed_text 18 | 19 | python $REPO_BASE/zyda/lsh_minhash/compute_minhash.py \ 20 | --load-path $DATA_BASE/processed/peS2o \ 21 | --save-path $DATA_BASE/minhash/peS2o \ 22 | --num-proc $NUM_PROC \ 23 | --width 13 \ 24 | --num-perm 128 \ 25 | --key transformed_text 26 | 27 | python $REPO_BASE/zyda/lsh_minhash/compute_minhash.py \ 28 | --load-path $DATA_BASE/processed/arxiv \ 29 | --save-path $DATA_BASE/minhash/arxiv \ 30 | --num-proc $NUM_PROC \ 31 | --width 13 \ 32 | --num-perm 128 \ 33 | --key transformed_text 34 | -------------------------------------------------------------------------------- /zyda_reproduction/4_lsh_indexing/run_lsh_dupes_0.4_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | LSH_OUT_DIR=$DATA_BASE/lsh_0.4 3 | BANDS=32 4 | RANGE=4 5 | 6 | # Running on one machine will be slow. 7 | # Use --nodes and --node-rank for ditributing a job among multiple compute nodes. 8 | # We split our job across 8 nodes with 2TB of RAM. Peak RAM usge was a bit more than 1TB. 
9 | 10 | python $REPO_BASE/zyda/lsh_minhash/build_lsh_index.py \ 11 | --load-path \ 12 | $DATA_BASE/minhash/pile-uncopyrighted \ 13 | $DATA_BASE/minhash/c4-en \ 14 | $DATA_BASE/minhash/peS2o \ 15 | $DATA_BASE/minhash/arxiv \ 16 | $DATA_BASE/minhash/refinedweb \ 17 | $DATA_BASE/minhash/slimpajama \ 18 | $DATA_BASE/minhash/starcoder-languages \ 19 | $DATA_BASE/minhash/starcoder-git-commits-cleaned \ 20 | $DATA_BASE/minhash/starcoder-github-issues-filtered-structured \ 21 | $DATA_BASE/minhash/starcoder-jupyter-structured-clean-dedup \ 22 | --dupes-out $LSH_OUT_DIR/dupes/all_pairs.txt \ 23 | --lsh-out $LSH_OUT_DIR/lsh_index.pickle \ 24 | --range $RANGE \ 25 | --bands $BANDS \ 26 | --bands-parallel 2 \ 27 | --reader-processes 12 28 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | NAME = "zyda" 4 | AUTHOR = "Zyphra Technologies" 5 | VERSION = "0.0.1" 6 | DESCRIPTION = "Processing of LLM datasets at a large scale" 7 | 8 | with open("README.md", "r", encoding="utf-8") as f: 9 | long_description = f.read() 10 | 11 | # Setting up 12 | setuptools.setup( 13 | name=NAME, 14 | version=VERSION, 15 | author=AUTHOR, 16 | description=DESCRIPTION, 17 | long_description=long_description, 18 | long_description_content_type="text/markdown", 19 | url="https://github.com/Zyphra/Zyda_processing", 20 | packages=setuptools.find_packages( 21 | exclude=( 22 | "dist", 23 | "zyda.egg-info", 24 | ) 25 | ), 26 | install_requires=[ 27 | "datasketch==1.5.9", 28 | "networkit==10.1", 29 | "nltk==3.8.1", 30 | "numpy==1.24.3", 31 | "regex==2023.6.3", 32 | "scipy==1.10.1", 33 | "tqdm==4.65.0", 34 | "ftfy==6.1.1", 35 | "more-itertools==9.1.0", 36 | "Levenshtein==0.25.1", 37 | "zstandard==0.22.0", 38 | "transformers==4.41.2", 39 | "datasets==2.18.0", 40 | ], 41 | python_requires=">=3.10", 42 | ) 43 | -------------------------------------------------------------------------------- /zyda_reproduction/2_preprocessing_and_filtering/preprocess_pile_c4_peS2o_arxiv.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | NAME=pile-uncopyrighted 4 | python $REPO_BASE/zyda/preprocessing_and_filtering/preprocess_and_filter.py \ 5 | --hf-path $DATA_BASE/raw/$NAME \ 6 | --load-from-disk \ 7 | --num-proc $NUM_PROC \ 8 | --key text \ 9 | --name $NAME \ 10 | --save-path $DATA_BASE/processed/$NAME 11 | 12 | NAME=c4-en 13 | python $REPO_BASE/zyda/preprocessing_and_filtering/preprocess_and_filter.py \ 14 | --hf-path $DATA_BASE/raw/$NAME \ 15 | --load-from-disk \ 16 | --num-proc $NUM_PROC \ 17 | --key text \ 18 | --name $NAME \ 19 | --save-path $DATA_BASE/processed/$NAME 20 | 21 | NAME=peS2o 22 | python $REPO_BASE/zyda/preprocessing_and_filtering/preprocess_and_filter.py \ 23 | --hf-path $DATA_BASE/raw/$NAME \ 24 | --load-from-disk \ 25 | --num-proc $NUM_PROC \ 26 | --key text \ 27 | --name $NAME \ 28 | --save-path $DATA_BASE/processed/$NAME 29 | 30 | NAME=arxiv 31 | python $REPO_BASE/zyda/preprocessing_and_filtering/preprocess_and_filter.py \ 32 | --hf-path $DATA_BASE/raw/$NAME \ 33 | --load-from-disk \ 34 | --num-proc $NUM_PROC \ 35 | --key text \ 36 | --name $NAME \ 37 | --save-path $DATA_BASE/processed/$NAME 38 | -------------------------------------------------------------------------------- /zyda_reproduction/3_minhashing/minhash_starcoder.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 
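# The StarCoder components are minhashed on their raw `content` column (see --key below),
# unlike the other datasets, which use the transformed_text column produced at the preprocessing stage.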
2 | 3 | python $REPO_BASE/zyda/lsh_minhash/compute_minhash.py \ 4 | --load-path $DATA_BASE/processed/starcoder-languages \ 5 | --save-path $DATA_BASE/minhash/starcoder-languages \ 6 | --num-proc $NUM_PROC \ 7 | --width 13 \ 8 | --num-perm 128 \ 9 | --key content 10 | 11 | python $REPO_BASE/zyda/lsh_minhash/compute_minhash.py \ 12 | --load-path $DATA_BASE/processed/starcoder-github-issues-filtered-structured \ 13 | --save-path $DATA_BASE/minhash/starcoder-github-issues-filtered-structured \ 14 | --num-proc $NUM_PROC \ 15 | --width 13 \ 16 | --num-perm 128 \ 17 | --key content 18 | 19 | python $REPO_BASE/zyda/lsh_minhash/compute_minhash.py \ 20 | --load-path $DATA_BASE/processed/starcoder-jupyter-structured-clean-dedup \ 21 | --save-path $DATA_BASE/minhash/starcoder-jupyter-structured-clean-dedup \ 22 | --num-proc $NUM_PROC \ 23 | --width 13 \ 24 | --num-perm 128 \ 25 | --key content 26 | 27 | python $REPO_BASE/zyda/lsh_minhash/compute_minhash.py \ 28 | --load-path $DATA_BASE/processed/starcoder-git-commits-cleaned \ 29 | --save-path $DATA_BASE/minhash/starcoder-git-commits-cleaned \ 30 | --num-proc $NUM_PROC \ 31 | --width 13 \ 32 | --num-perm 128 \ 33 | --key content 34 | -------------------------------------------------------------------------------- /zyda_reproduction/2_preprocessing_and_filtering/preprocess_starcoder.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | NAME=starcoder-languages 4 | python $REPO_BASE/zyda/preprocessing_and_filtering/preprocess_and_filter.py \ 5 | --hf-path $DATA_BASE/raw/$NAME \ 6 | --load-from-disk \ 7 | --num-proc $NUM_PROC \ 8 | --key content \ 9 | --keep-key \ 10 | --name $NAME \ 11 | --save-path $DATA_BASE/processed/$NAME 12 | 13 | NAME=starcoder-github-issues-filtered-structured 14 | python $REPO_BASE/zyda/preprocessing_and_filtering/preprocess_and_filter.py \ 15 | --hf-path $DATA_BASE/raw/$NAME \ 16 | --load-from-disk \ 17 | --num-proc $NUM_PROC \ 18 | --key content \ 19 | --keep-key \ 20 | --name $NAME \ 21 | --save-path $DATA_BASE/processed/$NAME 22 | 23 | NAME=starcoder-jupyter-structured-clean-dedup 24 | python $REPO_BASE/zyda/preprocessing_and_filtering/preprocess_and_filter.py \ 25 | --hf-path $DATA_BASE/raw/$NAME \ 26 | --load-from-disk \ 27 | --num-proc $NUM_PROC \ 28 | --key content \ 29 | --keep-key \ 30 | --name $NAME \ 31 | --save-path $DATA_BASE/processed/$NAME 32 | 33 | NAME=starcoder-git-commits-cleaned 34 | python $REPO_BASE/zyda/preprocessing_and_filtering/preprocess_and_filter.py \ 35 | --hf-path $DATA_BASE/raw/$NAME \ 36 | --load-from-disk \ 37 | --num-proc $NUM_PROC \ 38 | --key content \ 39 | --keep-key \ 40 | --name $NAME \ 41 | --save-path $DATA_BASE/processed/$NAME 42 | -------------------------------------------------------------------------------- /zyda/utils/text.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Zyphra Technologies. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import string 16 | import ftfy 17 | import re 18 | import nltk 19 | nltk.download('punkt') 20 | 21 | # Inspired by: https://github.com/Cerebras/modelzoo/blob/0bb30b6e681e792f3ba1804835d3f966a7ec9611/src/cerebras/modelzoo/data_preparation/nlp/slimpajama/dedup/to_hash.py#L32 22 | def get_normalized_words(s: str): 23 | # normalize string 24 | s = ftfy.fix_text(s, normalization="NFC") 25 | # lower cased 26 | s = s.lower() 27 | # remove punctuation 28 | s = s.translate(str.maketrans("", "", string.punctuation)) 29 | # remove consecutive spaces, newlines, tabs in the middle and in the beginning / end 30 | s = re.sub(r"\s+", " ", s.strip()) 31 | # return words 32 | return nltk.word_tokenize(s) 33 | 34 | 35 | def get_features(s: str, width: int): 36 | return map(lambda x: " ".join(x), nltk.ngrams(get_normalized_words(s), width)) 37 | -------------------------------------------------------------------------------- /zyda/utils/common.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Zyphra Technologies. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
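# Shared helpers: ComboDataset loads every shard saved under a directory and concatenates them
# into a single HuggingFace dataset; ensure_directory_exists creates the parent folder of an output file path.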
14 | 15 | import os 16 | import datasets 17 | import tqdm 18 | 19 | import logging 20 | logging.basicConfig(format='%(asctime)s: %(message)s', level=logging.INFO) 21 | 22 | 23 | class ComboDataset: 24 | def __init__(self, path: str): 25 | self.path = path 26 | 27 | shards_dirs = sorted(os.listdir(path)) 28 | logging.info(f"\nLoading {len(shards_dirs)} shards from {path}") 29 | self.shards_dirs = shards_dirs 30 | 31 | shards = [] 32 | for shard_dir in tqdm.tqdm(shards_dirs): 33 | load_path = os.path.join(path, shard_dir) 34 | shards.append(datasets.load_from_disk(load_path)) 35 | self.shards = shards 36 | 37 | ds = datasets.concatenate_datasets(shards) 38 | self.ds = ds 39 | logging.info(ds) 40 | 41 | 42 | def ensure_directory_exists(filename: str): 43 | os.makedirs(os.path.dirname(filename), exist_ok = True) 44 | -------------------------------------------------------------------------------- /zyda_reproduction/1_downloading/process_repo_starcoder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from collections import defaultdict 4 | import datasets 5 | 6 | NUM_PROC = os.environ.get("NUM_PROC", 1) 7 | REPO_PATH = os.environ.get("STARCODER_REPO_PATH") 8 | DATA_BASE = os.environ.get("DATA_BASE") 9 | 10 | all_dirs = sorted([item for item in os.listdir(REPO_PATH) if os.path.isdir(os.path.join(REPO_PATH, item)) and ".git" not in item]) 11 | print(all_dirs) 12 | 13 | has_different_features = set( 14 | [ 15 | 'jupyter-scripts-dedup-filtered', 16 | 'jupyter-structured-clean-dedup', 17 | 'github-issues-filtered-structured', 18 | 'git-commits-cleaned', 19 | ] 20 | ) 21 | 22 | type2data = defaultdict(list) 23 | for key in all_dirs: 24 | print(f"\nLoading {key}") 25 | t0 = time.time() 26 | data_raw = datasets.load_dataset( 27 | os.path.join(REPO_PATH, key), 28 | num_proc=NUM_PROC, 29 | split='train' 30 | ) 31 | data_with_dir = data_raw.map(lambda row: {"dir": key}, num_proc=NUM_PROC) 32 | 33 | new_features = data_with_dir.features.copy() 34 | for feature in new_features: 35 | new_features[feature] = datasets.Value(dtype="string") 36 | data_casted = data_with_dir.cast(new_features, num_proc=NUM_PROC) 37 | 38 | if key in has_different_features: 39 | type2data[key].append(data_casted) 40 | else: 41 | type2data['languages'].append(data_casted) 42 | 43 | print() 44 | for key, datas in type2data.items(): 45 | print(f"{key}: {sum([len(data) for data in datas])}") 46 | 47 | for key, datas in type2data.items(): 48 | print(f"\nSaving {key}") 49 | conc_dataset = datasets.concatenate_datasets(datas) 50 | conc_dataset.save_to_disk(os.path.join(DATA_BASE, f"raw/starcoder-{key}")) 51 | -------------------------------------------------------------------------------- /zyda_reproduction/1_downloading/download_arxiv_pile_peS2o_c4_refinedweb.py: -------------------------------------------------------------------------------- 1 | import os 2 | import datasets 3 | 4 | DATA_BASE = os.environ.get("DATA_BASE") 5 | NUM_PROC = int(os.environ.get("NUM_PROC", 1)) 6 | 7 | data = datasets.load_dataset( 8 | "ArtifactAI/arxiv_s2orc_parsed", 9 | split="train", 10 | download_config=datasets.DownloadConfig( 11 | num_proc=8, 12 | resume_download=True, 13 | ), 14 | num_proc=NUM_PROC, 15 | ) 16 | print(data) 17 | data.save_to_disk(os.path.join(DATA_BASE, "raw/arxiv")) 18 | 19 | data = datasets.load_dataset( 20 | "monology/pile-uncopyrighted", 21 | split="train", 22 | download_config=datasets.DownloadConfig( 23 | num_proc=8, 24 | resume_download=True, 25 | ), 26 | 
num_proc=NUM_PROC, 27 | ) 28 | print(data) 29 | data.save_to_disk(os.path.join(DATA_BASE, "raw/pile-uncopyrighted")) 30 | 31 | data = datasets.load_dataset( 32 | "allenai/peS2o", 33 | split="train", 34 | trust_remote_code=True, 35 | download_config=datasets.DownloadConfig( 36 | num_proc=8, 37 | resume_download=True, 38 | ), 39 | num_proc=NUM_PROC, 40 | ) 41 | print(data) 42 | data.save_to_disk(os.path.join(DATA_BASE, "raw/peS2o")) 43 | 44 | data = datasets.load_dataset( 45 | "allenai/c4", "en", 46 | split="train", 47 | download_config=datasets.DownloadConfig( 48 | num_proc=8, 49 | resume_download=True, 50 | ), 51 | num_proc=NUM_PROC, 52 | ) 53 | print(data) 54 | data.save_to_disk(os.path.join(DATA_BASE, "raw/c4-en")) 55 | 56 | data = datasets.load_dataset( 57 | "tiiuae/falcon-refinedweb", 58 | split="train", 59 | download_config=datasets.DownloadConfig( 60 | num_proc=8, 61 | resume_download=True, 62 | ), 63 | num_proc=NUM_PROC, 64 | ) 65 | print(data) 66 | data.save_to_disk(os.path.join(DATA_BASE, "raw/refinedweb")) 67 | -------------------------------------------------------------------------------- /zyda/lsh_minhash/compute_optimal_params.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Zyphra Technologies. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
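# Reports the optimal (bands, range) split of a minhash signature for a given Jaccard similarity
# threshold, together with the resulting false-negative and false-positive probabilities,
# using datasketch's internal helper functions.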
14 | 15 | import argparse 16 | 17 | from datasketch.lsh import _false_negative_probability, _false_positive_probability, _optimal_param 18 | 19 | if __name__ == '__main__': 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--num-perm', type=int, default=128, help='Number of permutations in minhash') 22 | parser.add_argument('--threshold', type=float, default=0.4, help='Jaccard similarity threshold to declare a pair duplicate') 23 | parser.add_argument('--fn-weight', type=float, default=0.5, help='Weight of false negatives when determining optimal parameters') 24 | parser.add_argument('--fp-weight', type=float, default=0.5, help='Weight of false positives when determining optimal parameters') 25 | args = parser.parse_args() 26 | 27 | b, r = _optimal_param( 28 | threshold=args.threshold, 29 | num_perm=args.num_perm, 30 | false_negative_weight=args.fn_weight, 31 | false_positive_weight=args.fp_weight, 32 | ) 33 | 34 | print(f"\nOptimal LSH index parameters: b = {b}, r = {r}") 35 | print(f"FN probability = {_false_negative_probability(threshold=args.threshold, b=b, r=r) * 100:.2f}%") 36 | print(f"FP probability = {_false_positive_probability(threshold=args.threshold, b=b, r=r) * 100:.2f}%") 37 | -------------------------------------------------------------------------------- /zyda_reproduction/1_downloading/process_repo_slimpajama.py: -------------------------------------------------------------------------------- 1 | import os 2 | import datasets 3 | from tqdm import tqdm 4 | import zstandard as zstd 5 | import io 6 | from zyda.utils.common import ensure_directory_exists 7 | 8 | # Generating Slimpajama jsonl file directly from the repo, since loading Slimpajama from HF repository has a bug 9 | # that results in doubling the dataset. 10 | 11 | REPO_PATH = os.environ.get("SLIMPAJAMA_REPO_PATH") 12 | DATA_BASE = os.environ.get("DATA_BASE") 13 | OUTPUT_JSONL = os.path.join(DATA_BASE, "SlimPajama.jsonl") 14 | 15 | def list_files_in_directory(directory): 16 | file_list = [] 17 | for root, directories, files in os.walk(directory): 18 | for filename in files: 19 | file_list.append(os.path.join(root, filename)) 20 | return file_list 21 | 22 | 23 | def combine_jsonl_files(input_folder, output_file): 24 | 25 | jsonl_files = sorted([f for f in list_files_in_directory(input_folder) if f.endswith(".jsonl")]) 26 | zst_files = sorted([f for f in list_files_in_directory(input_folder) if f.endswith(".jsonl.zst")]) 27 | 28 | with tqdm(total=len(jsonl_files) + len(zst_files), desc='Combining JSONL files') as pbar: 29 | with open(output_file, 'w') as out_f: 30 | for filename in jsonl_files: 31 | input_file = os.path.join(input_folder, filename) 32 | with open(input_file, 'r') as f: 33 | for line in f: 34 | out_f.write(line) 35 | pbar.update(1) 36 | 37 | for filename in zst_files: 38 | input_file = os.path.join(input_folder, filename) 39 | with open(input_file, 'rb') as f: 40 | decompressor = zstd.ZstdDecompressor() 41 | text_stream = io.TextIOWrapper(decompressor.stream_reader(f), encoding='utf-8') 42 | for line in text_stream: 43 | out_f.write(line) 44 | pbar.update(1) 45 | 46 | 47 | ensure_directory_exists(OUTPUT_JSONL) 48 | combine_jsonl_files( 49 | input_folder=os.path.join(REPO_PATH, "train"), 50 | output_file=OUTPUT_JSONL, 51 | ) 52 | 53 | data = datasets.Dataset.from_json(OUTPUT_JSONL, split='train') 54 | 55 | print(data) 56 | 57 | print('Saving the dataset...') 58 | data.save_to_disk(os.path.join(DATA_BASE, "raw/slimpajama")) 59 | 
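# Example invocation (placeholder paths), assuming the SlimPajama HuggingFace repository has been cloned locally:
#   SLIMPAJAMA_REPO_PATH=/path/to/SlimPajama-627B DATA_BASE=/path/to/data \
#   python zyda_reproduction/1_downloading/process_repo_slimpajama.py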
-------------------------------------------------------------------------------- /zyda/preprocessing_and_filtering/zh_pornsignals.json: -------------------------------------------------------------------------------- 1 | {"words": ["caoporn", "caoprom", "caopron", "caoporen", "caoponrn", "caoponav", "caopom", "caoorn", "99re", "dy888", "caopro", "hezyo", "re99", "4438x", "zooskool", "xfplay", "7tav", "xxoo", "xoxo", "52av", "freexx", "91chinese", "anquye", "cao97", "538porm", "87fuli", "91pron", "91porn", "26uuu", "4438x", "182tv", "kk4444", "777me", "ae86", "91av", "720lu", "yy6080", "6080yy", "qqchub", "paa97", "aiai777", "yy4480", "videossexo", "91free", "\u4e00\u7ea7\u7279\u9ec4\u5927\u7247", "\u5077\u62cd\u4e45\u4e45\u56fd\u4ea7\u89c6\u9891", "\u65e5\u672c\u6bdb\u7247\u514d\u8d39\u89c6\u9891\u89c2\u770b", "\u4e45\u4e45\u514d\u8d39\u70ed\u5728\u7ebf\u7cbe\u54c1", "\u9ad8\u6e05\u6bdb\u7247\u5728\u7ebf\u770b", "\u65e5\u672c\u6bdb\u7247\u9ad8\u6e05\u514d\u8d39\u89c6\u9891", "\u4e00\u7ea7\u9ec4\u8272\u5f55\u50cf\u5f71\u7247", "\u4e9a\u6d32\u7537\u4eba\u5929\u5802", "\u4e45\u4e45\u7cbe\u54c1\u89c6\u9891\u5728\u7ebf\u770b", "\u81ea\u62cd\u533a\u5077\u62cd\u4e9a\u6d32\u89c6\u9891", "\u4e9a\u6d32\u4eba\u6210\u89c6\u9891\u5728\u7ebf\u64ad\u653e", "\u8272\u59d1\u5a18\u7efc\u5408\u7ad9", "\u4e01\u9999\u4e94\u6708\u556a\u556a", "\u5728\u7ebf\u89c6\u9891\u6210\u4eba\u793e\u533a", "\u4e9a\u6d32\u4eba\u6210\u89c6\u9891\u5728\u7ebf\u64ad\u653e", "\u4e45\u4e45\u56fd\u4ea7\u81ea\u5077\u62cd", "\u4e00\u672c\u9053", "\u5927\u9999\u8549\u65e0\u7801", "\u9999\u6e2f\u7ecf\u5178\u4e09\u7ea7", "\u4e9a\u6d32\u6210\u5728\u4eba\u7ebf\u514d\u8d39\u89c6\u9891", "\u5929\u5929\u8272\u7efc\u5408\u7f51", "\u5927\u9999\u8549\u4f0a\u4eba\u4e45\u8349", "\u6b27\u7f8e\u4e00\u7ea7\u9ad8\u6e05\u7247", "\u5929\u5929\u9c81\u591c\u591c\u556a\u89c6\u9891\u5728\u7ebf", "\u514d\u8d39\u9ec4\u7247\u89c6\u9891\u5728\u7ebf\u89c2\u770b", "\u52a0\u6bd4\u52d2\u4e45\u4e45\u7efc\u5408", "\u4e45\u8349\u70ed\u4e45\u8349\u5728\u7ebf\u89c6\u9891", "\u97e9\u56fd\u4e09\u7ea7\u7247\u5927\u5168\u5728\u7ebf\u89c2\u770b", "\u9752\u9752\u8349\u5728\u7ebf\u89c6\u9891", "\u7f8e\u56fd\u4e00\u7ea7\u6bdb\u7247", "\u4e45\u8349\u5728\u7ebf\u798f\u5229\u8d44\u6e90", "\u556a\u556a\u556a\u89c6\u9891\u5728\u7ebf\u89c2\u770b\u514d\u8d39", "\u6210\u4eba\u798f\u5229\u89c6\u9891\u5728\u7ebf\u89c2\u770b", "\u5a77\u5a77\u6211\u53bb\u4e5f", "\u8001\u53f8\u673a\u5728\u7ebf\u56fd\u4ea7", "\u4e45\u4e45\u6210\u4eba\u89c6\u9891", "\u624b\u673a\u770b\u7247\u798f\u5229\u6c38\u4e45\u56fd\u4ea7", "\u9ad8\u6e05\u56fd\u4ea7\u5077\u62cd\u5728\u7ebf", "\u5927\u9999\u8549\u5728\u7ebf\u5f71\u9662", "\u65e5\u672c\u9ad8\u6e05\u514d\u8d39\u4e00\u672c\u89c6\u9891", "\u7537\u4eba\u7684\u5929\u5802\u4e1c\u4eac\u70ed", "\u5f71\u97f3\u5148\u950b\u7537\u4eba\u8d44\u6e90", "\u4e94\u6708\u5a77\u5a77\u5f00\u5fc3\u4e2d\u6587\u5b57\u5e55", "\u4e9a\u6d32\u9999\u8549\u89c6\u9891\u5728\u7ebf\u64ad\u653e", "\u5929\u5929\u556a\u4e45\u4e45\u7231\u89c6\u9891\u7cbe\u54c1", "\u8d85\u78b0\u4e45\u4e45\u4eba\u4eba\u6478\u4eba\u4eba\u641e"]} -------------------------------------------------------------------------------- /zyda_reproduction/6_generating_final_dataset/convert_jsonls_to_parquet.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import more_itertools 4 | import multiprocessing as mp 5 | import datasets 6 | from zyda.utils.common import ensure_directory_exists 7 | 8 | import logging 9 | logging.basicConfig(format='%(asctime)s: 
%(message)s', level=logging.INFO) 10 | 11 | 12 | def save_partitions_to_parquet(data, out_folder, total, indices): 13 | for idx in indices: 14 | shard = data.shard(num_shards=total, index=idx, contiguous=True) 15 | parquet_name = f"train-{idx:05d}-of-{total:05d}.parquet" 16 | save_path = os.path.join(out_folder, parquet_name) 17 | ensure_directory_exists(save_path) 18 | logging.info(f"Saving shard {idx}") 19 | shard.to_parquet(save_path) 20 | 21 | 22 | def convert_to_parquet(args): 23 | files = os.listdir(args.input_folder) 24 | sorted_files = sorted(files, key = lambda x: x.split("_")[1].split(".")[0]) 25 | logging.info(f"Found {len(sorted_files)} files: {sorted_files}") 26 | 27 | files_paths = [os.path.join(args.input_folder, x) for x in sorted_files] 28 | ds = datasets.Dataset.from_json(files_paths, num_proc=args.num_proc) 29 | ds.cleanup_cache_files() 30 | 31 | logging.info("Removing index column") 32 | ds = ds.remove_columns("source_index") 33 | 34 | logging.info("Converting columns") 35 | ds = ds.map( 36 | lambda row: { 37 | "source_other": str(row["source_other"]), 38 | "filtering_features": str(row["filtering_features"]), 39 | }, 40 | num_proc=48, 41 | ) 42 | logging.info("Shuffling the dataset") 43 | ds = ds.shuffle() 44 | 45 | processes = [] 46 | partitions = args.partitions 47 | out_folder = args.output_folder 48 | logging.info(f"Saving {partitions} parquets to {out_folder}") 49 | indices = [list(x) for x in more_itertools.divide(args.num_proc, range(partitions))] 50 | for process_id in range(args.num_proc): 51 | p = mp.Process(target=save_partitions_to_parquet, args=(ds, out_folder, partitions, indices[process_id])) 52 | processes.append(p) 53 | p.start() 54 | for p in processes: 55 | p.join() 56 | logging.info("Done!") 57 | 58 | 59 | if __name__ == "__main__": 60 | parser = argparse.ArgumentParser() 61 | parser.add_argument("--input-folder", type=str, required=True, help="Input folder with jsonl files.") 62 | parser.add_argument("--output-folder", type=str, required=True, help="Input folder for saving parquet files.") 63 | parser.add_argument('--num-proc', type=int, default=1, help="Number of processes for saving.") 64 | parser.add_argument("--partitions", type=int, default=1, help="Number of parquet partitions. Partitions will be distributed among saving processes.") 65 | 66 | args = parser.parse_args() 67 | 68 | convert_to_parquet(args) 69 | -------------------------------------------------------------------------------- /zyda/connected_components/generate_indices_to_remove.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Zyphra Technologies. 2 | # Based on SlimPajama codebase: https://github.com/Cerebras/modelzoo/blob/main/src/cerebras/modelzoo/data_preparation/nlp/slimpajama/dedup/generate_duplicates_dict.py 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
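# For each connected component of the duplicate graph, keeps a single document (the highest-ranked
# one when --ranking is given) and records the (shard, shard_index, global_index) of every other
# member so those duplicates can be removed later.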
15 | 16 | import argparse 17 | import pickle 18 | from collections import defaultdict 19 | 20 | import tqdm 21 | 22 | import logging 23 | logging.basicConfig(format='%(asctime)s: %(message)s', level=logging.INFO) 24 | 25 | 26 | def generate_duplicates(args): 27 | print() 28 | 29 | if args.ranking: 30 | ranking = {} 31 | for i, dataset in enumerate(args.ranking): 32 | ranking[dataset] = i 33 | logging.info(f"Ranking of datasets: {ranking}") 34 | 35 | # load pickled components and other artifacts 36 | logging.info(f'Loading {args.input_file}') 37 | with open(args.input_file, "rb") as fin: 38 | components, n_components, reverse_mapper = pickle.load(fin) 39 | 40 | logging.info("Processing connected components...") 41 | duplicates = defaultdict(set) 42 | n_duplicate_docs = 0 43 | for component in tqdm.tqdm(components, unit="components", unit_scale=True): 44 | if args.ranking: 45 | component.sort(key=lambda x: ranking[reverse_mapper[x].split("@")[0]]) 46 | for j in range(1, len(component)): 47 | doc = reverse_mapper[component[j]] 48 | file_name, shard, shard_index, global_index = doc.split("@") 49 | duplicates[file_name].add((int(shard), int(shard_index), int(global_index))) 50 | n_duplicate_docs += 1 51 | 52 | logging.info(f"Total number of duplicate documents that will be removed: {n_duplicate_docs}") 53 | logging.info("Duplicates to remove per dataset:") 54 | for ds_name, dupes in duplicates.items(): 55 | logging.info(f" {ds_name}: {len(dupes)}") 56 | 57 | logging.info(f"Saving to {args.out_file}...") 58 | with open(args.out_file, "wb") as fout: 59 | pickle.dump(duplicates, fout, protocol=5) 60 | logging.info("Done!") 61 | 62 | 63 | if __name__ == "__main__": 64 | parser = argparse.ArgumentParser() 65 | parser.add_argument("--input-file", type=str, required=True, help="Input pickle file with connected components") 66 | parser.add_argument("--out-file", type=str, required=True, help="Output pickle file to save indices of duplicates to remove") 67 | parser.add_argument("--ranking", nargs="+", type=str, help="Ranking of datasets for choosing a single document from a component") 68 | 69 | args = parser.parse_args() 70 | generate_duplicates(args) 71 | -------------------------------------------------------------------------------- /zyda/lsh_minhash/compute_minhash.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Zyphra Technologies. 2 | # Based on SlimPajama codebase: https://github.com/Cerebras/modelzoo/blob/main/src/cerebras/modelzoo/data_preparation/nlp/slimpajama/dedup/to_hash.py 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
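# Computes a MinHash signature for every document (by default 128 permutations over 13-gram word
# features of the chosen text column), shard by shard, keeping only the columns needed for LSH
# indexing alongside the seed and hash values.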
15 | 16 | import os 17 | import argparse 18 | import datasets 19 | from datasketch import MinHash 20 | from collections import defaultdict 21 | from zyda.utils.text import get_features 22 | 23 | import nltk 24 | nltk.download('punkt') 25 | 26 | import logging 27 | logging.basicConfig(format='%(asctime)s: %(message)s', level=logging.INFO) 28 | 29 | COLUMNS_TO_SAVE = ["dataset_name", "shard", "shard_index", "global_index"] 30 | 31 | 32 | def to_minhash( 33 | batch, 34 | key: str = "transformed_text", 35 | width: int = 13, 36 | num_perm: int = 128, 37 | ): 38 | output = defaultdict(list) 39 | for text in batch[key]: 40 | m = MinHash(num_perm=num_perm) 41 | m.update_batch(map(lambda x: x.encode('utf8'), get_features(text, width))) 42 | output["seed"].append(m.seed) 43 | output["hashvalues"].append(m.hashvalues) 44 | return output 45 | 46 | 47 | if __name__ == '__main__': 48 | parser = argparse.ArgumentParser() 49 | parser.add_argument('--load-path', type=str, required=True, help='Path to folder with preprocessed shards') 50 | parser.add_argument('--save-path', type=str, required=True, help='Path to folder to where we save ') 51 | parser.add_argument('--num-proc', type=int, default=1, help='Number of processes') 52 | parser.add_argument('--key', type=str, default='transformed_text', help='Key to use for minhashing') 53 | parser.add_argument('--width', type=int, default=13, help='n-grams size for minhashes') 54 | parser.add_argument('--num-perm', type=int, default=128, help='Number of permutation for computing minhashes') 55 | parser.add_argument('--from-scratch', action='store_true', help='If specified, will forcefully do every shard regardless of previous progress') 56 | args = parser.parse_args() 57 | 58 | shards_dirs = sorted(os.listdir(args.load_path)) 59 | logging.info(f"Found {len(shards_dirs)} shards") 60 | for i, shard_dir in enumerate(shards_dirs): 61 | load_path = os.path.join(args.load_path, shard_dir) 62 | save_path = os.path.join(args.save_path, shard_dir) 63 | print() 64 | logging.info(f"Processing shard {i + 1} / {len(shards_dirs)} from {load_path}") 65 | if os.path.exists(save_path) and not args.from_scratch: 66 | logging.info(f"Already processed!") 67 | continue 68 | 69 | shard = datasets.load_from_disk(load_path) 70 | logging.info(f"Cache cleaned: {shard.cleanup_cache_files()}") 71 | 72 | shard_minhash = shard.map( 73 | lambda batch: to_minhash(batch, key=args.key, width=args.width, num_perm=args.num_perm), 74 | batched=True, 75 | num_proc=args.num_proc, 76 | remove_columns=[col for col in shard.column_names if col not in COLUMNS_TO_SAVE] 77 | ) 78 | logging.info(f"Saving minhash to: {save_path}") 79 | shard_minhash.save_to_disk(save_path, max_shard_size="8GB") 80 | logging.info(f"Cache cleaned: {shard.cleanup_cache_files()}") 81 | -------------------------------------------------------------------------------- /zyda/utils/filtering.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Zyphra Technologies. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import Dict, Optional 16 | 17 | FILTERING_FEATURES = ["mean_word_length", "fraction_non_alphanumeric", "fraction_numerical", "pii_count", "pattern_counts", "word_list_counts", "substrings_counts"] 18 | 19 | MIN_LENGTH = 100 20 | MIN_MEAN_WORD_LENGTH = 3 21 | MAX_MEAN_WORD_LENGTH = 12 22 | MAX_FRACTION_NON_ALPHANUMERIC = 0.3 23 | MAX_FRACTION_NUMERICAL = 0.2 24 | MAX_CURSED_FRACTION = 0.01 25 | MAX_NUM_REPEATED_SUBSTRINGS = 50 26 | PATTERNS_WITH_MAX_COUNTS = { 27 | "xml": 10, 28 | " bool: 62 | """ 63 | Given a row, decided whether to remove or keep it. 64 | Returns True is the row to be kept, False otherwise. 65 | """ 66 | 67 | if len(row[key]) < min_length: 68 | return False 69 | 70 | if row["mean_word_length"] < min_mean_word_length: 71 | return False 72 | 73 | if row["mean_word_length"] > max_mean_word_length: 74 | return False 75 | 76 | if row["fraction_non_alphanumeric"] > max_fraction_non_alphanumeric: 77 | return False 78 | 79 | if row["fraction_numerical"] > max_fraction_numerical: 80 | return False 81 | 82 | if row["substrings_counts"] > max_repeated_substrings: 83 | return False 84 | 85 | for pattern, max_count in patterns_with_max_counts.items(): 86 | if row["pattern_counts"][pattern] > max_count: 87 | return False 88 | 89 | for pattern, max_fraction in patterns_with_max_fractions.items(): 90 | if row["pattern_counts"][pattern] / len(row[key]) > max_fraction: 91 | return False 92 | 93 | for word_list, max_count in word_lists_with_max_counts.items(): 94 | if row["word_list_counts"][word_list] > max_count: 95 | return False 96 | 97 | for word_list, max_fraction in word_lists_with_max_fractions.items(): 98 | if row["word_list_counts"][word_list] / len(row[key]) > max_fraction: 99 | return False 100 | 101 | if dupe_inds is not None and row["global_index"] in dupe_inds: 102 | return False 103 | 104 | return True 105 | -------------------------------------------------------------------------------- /zyda_reproduction/6_generating_final_dataset/run_convert_to_parquet_lsh_0.4.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | IN=$DATA_BASE/zyda_0.4-final/jsonl 3 | OUT=$DATA_BASE/zyda_0.4-final/parquet 4 | 5 | DATASET=refinedweb 6 | if test -d $IN/$DATASET; then 7 | echo Processing $DATASET 8 | mkdir -p $OUT/$DATASET 9 | python $REPO_BASE/zyda_reproduction/6_generating_final_dataset/convert_jsonls_to_parquet.py \ 10 | --input-folder $IN/$DATASET \ 11 | --output-folder $OUT/$DATASET \ 12 | --num-proc $NUM_PROC \ 13 | --partitions 8192 14 | fi 15 | 16 | DATASET=arxiv 17 | if test -d $IN/$DATASET; then 18 | echo Processing $DATASET 19 | mkdir -p $OUT/$DATASET 20 | python $REPO_BASE/zyda_reproduction/6_generating_final_dataset/convert_jsonls_to_parquet.py \ 21 | --input-folder $IN/$DATASET \ 22 | --output-folder $OUT/$DATASET \ 23 | --num-proc $NUM_PROC \ 24 | --partitions 64 25 | fi 26 | 27 | DATASET=peS2o 28 | if test -d $IN/$DATASET; then 29 | echo Processing $DATASET 30 | mkdir -p $OUT/$DATASET 31 | python $REPO_BASE/zyda_reproduction/6_generating_final_dataset/convert_jsonls_to_parquet.py \ 32 | --input-folder $IN/$DATASET \ 33 | --output-folder $OUT/$DATASET \ 34 | --num-proc $NUM_PROC \ 35 | --partitions 512 36 | fi 37 | 38 | DATASET=c4-en 39 | if test -d $IN/$DATASET; then 40 | echo Processing $DATASET 41 | mkdir -p $OUT/$DATASET 42 | python 
$REPO_BASE/zyda_reproduction/6_generating_final_dataset/convert_jsonls_to_parquet.py \ 43 | --input-folder $IN/$DATASET \ 44 | --output-folder $OUT/$DATASET \ 45 | --num-proc $NUM_PROC \ 46 | --partitions 2048 47 | fi 48 | 49 | DATASET=pile-uncopyrighted 50 | if test -d $IN/$DATASET; then 51 | echo Processing $DATASET 52 | mkdir -p $OUT/$DATASET 53 | python $REPO_BASE/zyda_reproduction/6_generating_final_dataset/convert_jsonls_to_parquet.py \ 54 | --input-folder $IN/$DATASET \ 55 | --output-folder $OUT/$DATASET \ 56 | --num-proc $NUM_PROC \ 57 | --partitions 1024 58 | fi 59 | 60 | DATASET=slimpajama 61 | if test -d $IN/$DATASET; then 62 | echo Processing $DATASET 63 | mkdir -p $OUT/$DATASET 64 | python $REPO_BASE/zyda_reproduction/6_generating_final_dataset/convert_jsonls_to_parquet.py \ 65 | --input-folder $IN/$DATASET \ 66 | --output-folder $OUT/$DATASET \ 67 | --num-proc $NUM_PROC \ 68 | --partitions 2048 69 | fi 70 | 71 | DATASET=starcoder-languages 72 | if test -d $IN/$DATASET; then 73 | echo Processing $DATASET 74 | mkdir -p $OUT/$DATASET 75 | python $REPO_BASE/zyda_reproduction/6_generating_final_dataset/convert_jsonls_to_parquet.py \ 76 | --input-folder $IN/$DATASET \ 77 | --output-folder $OUT/$DATASET \ 78 | --num-proc $NUM_PROC \ 79 | --partitions 2048 80 | fi 81 | 82 | DATASET=starcoder-git-commits-cleaned 83 | if test -d $IN/$DATASET; then 84 | echo Processing $DATASET 85 | mkdir -p $OUT/$DATASET 86 | python $REPO_BASE/zyda_reproduction/6_generating_final_dataset/convert_jsonls_to_parquet.py \ 87 | --input-folder $IN/$DATASET \ 88 | --output-folder $OUT/$DATASET \ 89 | --num-proc $NUM_PROC \ 90 | --partitions 128 91 | fi 92 | 93 | DATASET=starcoder-jupyter-structured-clean-dedup 94 | if test -d $IN/$DATASET; then 95 | echo Processing $DATASET 96 | mkdir -p $OUT/$DATASET 97 | python $REPO_BASE/zyda_reproduction/6_generating_final_dataset/convert_jsonls_to_parquet.py \ 98 | --input-folder $IN/$DATASET \ 99 | --output-folder $OUT/$DATASET \ 100 | --num-proc $NUM_PROC \ 101 | --partitions 32 102 | fi 103 | 104 | DATASET=starcoder-github-issues-filtered-structured 105 | if test -d $IN/$DATASET; then 106 | echo Processing $DATASET 107 | mkdir -p $OUT/$DATASET 108 | python $REPO_BASE/zyda_reproduction/6_generating_final_dataset/convert_jsonls_to_parquet.py \ 109 | --input-folder $IN/$DATASET \ 110 | --output-folder $OUT/$DATASET \ 111 | --num-proc $NUM_PROC \ 112 | --partitions 256 113 | fi 114 | -------------------------------------------------------------------------------- /zyda_reproduction/6_generating_final_dataset/generate_final_jsonls.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Zyphra Technologies. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
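# Loads the processed datasets from stage 2, drops the documents listed in the duplicates pickle
# produced at the clustering stage, and writes the surviving rows out as partitioned jsonl files
# (one writer process per partition).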
14 | 15 | import argparse 16 | import pickle 17 | import os 18 | import tqdm 19 | import random 20 | import gc 21 | import json 22 | import multiprocessing as mp 23 | 24 | from zyda.utils.common import ComboDataset, ensure_directory_exists 25 | from zyda.utils.filtering import FILTERING_FEATURES 26 | 27 | import logging 28 | logging.basicConfig(format='%(asctime)s: %(message)s', level=logging.INFO) 29 | 30 | KEY = "transformed_text" 31 | NEW_KEY = "text" 32 | DISCARD_KEYS = ["shard", "shard_index"] 33 | DATASETS_NAMES = [ 34 | "pile-uncopyrighted", 35 | "c4-en", 36 | "peS2o", 37 | "arxiv", 38 | "refinedweb", 39 | "slimpajama", 40 | "starcoder-languages", 41 | "starcoder-github-issues-filtered-structured", 42 | "starcoder-jupyter-structured-clean-dedup", 43 | "starcoder-git-commits-cleaned", 44 | ] 45 | DATA_BASE = os.environ.get("DATA_BASE") 46 | PATHS = {name: os.path.join(DATA_BASE, f"processed/{name}") for name in DATASETS_NAMES} 47 | DATASETS = {name: ComboDataset(path) for name, path in PATHS.items() if os.path.exists(path)} 48 | 49 | 50 | def save_shard_to_jsonl(data, out_folder, text_key, ds_name, total, idx): 51 | shard = data.shard(num_shards=total, index=idx, contiguous=True) 52 | jsonl_name = f"{ds_name}_{idx}.jsonl" 53 | save_path = os.path.join(out_folder, jsonl_name) 54 | ensure_directory_exists(save_path) 55 | with open(save_path, 'w') as f: 56 | for row in tqdm.tqdm(shard, desc=jsonl_name, unit="docs", unit_scale=True, position=idx): 57 | row_dict = { 58 | "filtering_features": {}, 59 | "source_other": {}, 60 | } 61 | for key, val in row.items(): 62 | if key in DISCARD_KEYS: 63 | continue 64 | elif key in FILTERING_FEATURES: 65 | row_dict["filtering_features"][key] = str(val) 66 | elif key == text_key: 67 | row_dict[NEW_KEY] = str(val) 68 | elif key == "dataset_name": 69 | row_dict["source"] = str(val) 70 | elif key == "global_index": 71 | row_dict["source_index"] = str(val) 72 | elif key != KEY: # need this check for starcoder 73 | row_dict["source_other"][key] = str(val) 74 | f.write(json.dumps(row_dict) + "\n") 75 | 76 | 77 | def generate_datasets(args): 78 | 79 | logging.info(f"Loading duplicates indices from {args.input_indices}...") 80 | with open(args.input_indices, "rb") as f: 81 | duplicates = pickle.load(f) 82 | gc.collect() 83 | 84 | for ds_name, dataset in DATASETS.items(): 85 | 86 | if ds_name not in duplicates: 87 | logging.info(f"Dataset {ds_name} is not in duplicates dict, so setting dupes to empty set") 88 | dupes = set() 89 | else: 90 | dupes = duplicates[ds_name] 91 | 92 | logging.info(f"Processing {ds_name}") 93 | logging.info(f"Number of duplicates to remove: {len(dupes)}") 94 | 95 | dupe_inds = set() 96 | for dupe in tqdm.tqdm(dupes, desc="Processing dupe indices", unit="inds", unit_scale=True): 97 | _, _, global_index = dupe 98 | dupe_inds.add(global_index) 99 | num_dupes_removed = len(dupe_inds) 100 | 101 | if len(dupe_inds) > 0: 102 | select_inds = [] 103 | for i in tqdm.tqdm(range(len(dataset.ds)), desc="Generating indices to select", unit="inds", unit_scale=True): 104 | if args.check_indices and random.random() < 1e-4: 105 | assert i == dataset.ds[i]["global_index"] 106 | if i not in dupe_inds: 107 | select_inds.append(i) 108 | del dupe_inds 109 | gc.collect() 110 | 111 | logging.info(f"Selecting {len(select_inds)} rows") 112 | filtered_ds = dataset.ds.select(select_inds) 113 | del select_inds 114 | gc.collect() 115 | else: 116 | filtered_ds = dataset.ds 117 | 118 | logging.info(f"Removed {len(dataset.ds) - len(filtered_ds)} rows overall") 119 | 
logging.info(f"Length of the final dataset: {len(filtered_ds)} rows ({100 * len(filtered_ds) / len(dataset.ds):.2f}% of original {len(dataset.ds)} rows)") 120 | assert len(dataset.ds) - len(filtered_ds) >= num_dupes_removed 121 | 122 | text_key = KEY 123 | if "starcoder" in ds_name: 124 | text_key = "content" 125 | logging.info(f"Starcoder detected: using {text_key} as a key") 126 | 127 | out_folder = os.path.join(args.out_folder, ds_name) 128 | partitions = args.jsonl_partitions 129 | processes = [] 130 | logging.info(f"Saving {partitions} jsonls to {out_folder}") 131 | for process_id in range(partitions): 132 | p = mp.Process(target=save_shard_to_jsonl, args=(filtered_ds, out_folder, text_key, ds_name, partitions, process_id)) 133 | processes.append(p) 134 | p.start() 135 | for p in processes: 136 | p.join() 137 | print("\n" * partitions) 138 | 139 | logging.info(f"Done!") 140 | 141 | 142 | if __name__ == "__main__": 143 | parser = argparse.ArgumentParser() 144 | parser.add_argument("--input-indices", type=str, required=True, help="Input pickle file with connected components.") 145 | parser.add_argument("--out-folder", type=str, required=True, help="Output folder for saving jsonl files.") 146 | parser.add_argument("--jsonl-partitions", type=int, default=48, help="Number of jsonl partitions. Note: one process will be used per partition.") 147 | 148 | args = parser.parse_args() 149 | 150 | generate_datasets(args) 151 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Zyda 2 | This repository contains the source code of the `zyda` package - a powerful Python-based package for filtering and deduplicating text datasets for subsequent use in LLM pretraining. The core dataset objects are based on HuggingFace's `datasets` package, making it straightforward to work with datasets hosted on HuggingFace. In addition, the default HuggingFace format is based on `pyarrow`, which allows for fast indexing and easy parallel processing of the datasets. 3 | 4 | This package was used for producing [Zyda dataset](https://huggingface.co/datasets/Zyphra/Zyda). See accompanying [technical report](https://arxiv.org/abs/2406.01981) for details. 5 | 6 | We release step-by-step instructions on how to reproduce Zyda using this package. 7 | 8 | ## Installation 9 | ``` 10 | git clone https://github.com/Zyphra/Zyda_processing.git 11 | cd Zyda_processing 12 | pip install -e . 13 | ``` 14 | 15 | ## Repository structure 16 | The `zyda` folder contains the source code for the package. It consists of the following subfolders: 17 | - `preprocessing_and_filtering` - code for preprocessing and filtering the datasets 18 | - `lsh_minhash` - code for computing minhash signatures and building the LSH index 19 | - `connected_components` - code for finding connected components in the graph of duplicates and for identifying which document to keep in every component 20 | - `utils` - commonly reused code 21 | 22 | The `zyda_reproduction` folder contains scripts for reproducing the Zyda dataset arranged in a step-by-step manner. 23 | 24 | ## How to reproduce Zyda dataset 25 | All the scripts necessary for Zyda reproduction are in `zyda_reproduction` folder. Below are step-by-step instructions on how to run them. 
26 | Before running the scripts, please set the following environment variables: 27 | - `DATA_BASE` - base location in your filesystem to save results of processing steps 28 | - `REPO_BASE` - location of the `Zyda_processing` repository 29 | - `NUM_PROC` - number of parallel processes to use at various stages of processing 30 | - [optional] `HF_HOME` - location of the HuggingFace cache 31 | 32 | ### 1. Downloading component datasets 33 | Scripts for downloading the component datasets of Zyda are in `zyda_reproduction/1_downloading`. 34 | 35 | All the scripts save the downloaded datasets into separate folders in HuggingFace format under `$DATA_BASE/raw/`. 36 | 37 | Downloading most components from HuggingFace is straightforward: e.g. run the `download_arxiv_pile_peS2o_c4_refinedweb.py` and `download_refinedweb.py` scripts. 38 | 39 | However, we had to apply special handling to SlimPajama and StarCoder: 40 | 1. Clone their HuggingFace repositories locally 41 | 2. Set the `SLIMPAJAMA_REPO_PATH` and `STARCODER_REPO_PATH` environment variables with paths to the local SlimPajama and StarCoder repositories, respectively 42 | 3. Run the scripts `process_repo_slimpajama.py` and `process_repo_starcoder.py` to generate raw versions of the datasets in HuggingFace format. 43 | 44 | ### 2. Preprocessing and filtering 45 | Scripts for preprocessing and filtering are in `zyda_reproduction/2_preprocessing_and_filtering`. 46 | 47 | Run all the bash scripts in this folder. 48 | 49 | This stage performs the following operations: 50 | 1. Generation of filtering features 51 | 2. Transformation of the text 52 | 3. Filtering of the documents (default filtering parameters can be found in `zyda/utils/filtering.py`) 53 | 4. Splitting of the resultant datasets into shards, which are saved in `$DATA_BASE/processed//shard_` folders in HuggingFace format 54 | 55 | ### 3. Computing minhashes 56 | Scripts for computing minhash signatures are in `zyda_reproduction/3_minhashing`. 57 | 58 | Run all the bash scripts in this folder. 59 | 60 | This stage performs the following operations: 61 | 1. Normalizes the text of each document and splits it into words 62 | 2. Generates 13-grams based on the words 63 | 3. Computes minhash signatures of size 128 64 | 4. Saves the results in `$DATA_BASE/minhash/` folders in HuggingFace format (only the columns necessary for indexing are saved along with the minhashes) 65 | 66 | ### 4. Building LSH index 67 | The script for building the LSH index is at `zyda_reproduction/4_lsh_indexing/run_lsh_dupes_0.4_all.sh`. 68 | 69 | For Zyda we used a 40% Jaccard similarity threshold when building our LSH index. The optimal split of minhash signatures can be computed using `zyda/lsh_minhash/compute_optimal_params.py`, which for our threshold and signature size gave us 32 bands with a range of 4. 70 | 71 | This is the most time-consuming and memory-intensive stage. We split it into a parallel job distributed among 8 nodes of our HPC cluster, each with 92 physical cores and 2TB of RAM. It took approximately 2 days with a peak RAM consumption of 1.5TB. 72 | 73 | We stripped our distributed configuration out of the script `run_lsh_dupes_0.4_all.sh`, so it assumes it will be run on one node. To limit RAM consumption we allow only 2 minhash bands to be processed in parallel by specifying the `--bands-parallel 2` flag. On one compute node, the bands are split into 16 groups of size 2, and these groups are processed sequentially. 74 | 75 | The resultant LSH index is saved in `$DATA_BASE/lsh_0.4/lsh_index-.pickle` files. 
76 | 
77 | ### 5. Clustering duplicates using connected components and generating indices of documents to remove
78 | The script for clustering duplicates using connected components and generating the indices of documents to remove is at `zyda_reproduction/5_clustering/run_cc_lsh_0.4_dupes.sh`.
79 | 
80 | This stage clusters the identified duplicate documents by finding connected components in a graph, where the nodes are documents and the edges are duplicate pairs. Graph processing is implemented in `zyda/connected_components/generate_connected_components.py`.
81 | 1. It first processes all the duplicate-pair text files (produced while building the indices of the individual bands) and generates a single set that is saved to `$DATA_BASE/lsh_0.4/dupes/output/cc-set-final.txt`
82 | 2. It then uses the `networkit` package to build a graph and find connected components. It saves the graph to `$DATA_BASE/lsh_0.4/dupes/output/cc-graph.graph`, the document-to-node mapper to `$DATA_BASE/lsh_0.4/dupes/output/cc-mapper.pickle`, and the connected components together with the node-to-document reverse mapper to `$DATA_BASE/lsh_0.4/dupes/output/cc.pickle`.
83 | 
84 | Finally, we generate the indices of duplicate documents to remove by sorting the documents in each cluster according to a ranking and keeping only the highest-ranked one. This is implemented in `zyda/connected_components/generate_indices_to_remove.py`. The resulting dict mapping dataset names to indices to remove is saved in `$DATA_BASE/lsh_0.4/dupes/output/dupes.pickle`. We decided to use the following ranking:
85 | 1. starcoder components
86 | 2. refinedweb
87 | 3. peS2o
88 | 4. arxiv
89 | 5. c4-en
90 | 6. pile-uncopyrighted
91 | 7. slimpajama
92 | 
93 | This stage took roughly half a day to run.
94 | 
95 | 
96 | ### 6. Generating final dataset
97 | Scripts for generating the final dataset are in `zyda_reproduction/6_generating_final_dataset`.
98 | 
99 | The bash script `zyda_reproduction/6_generating_final_dataset/run_generate_lsh_0.4_jsonls.sh` generates the final jsonl files of our dataset:
100 | 1. Loads the processed and filtered local HuggingFace datasets from stage 2
101 | 2. Removes documents using the indices from the previous stage
102 | 3. Saves the resulting datasets as jsonl partitions to `$DATA_BASE/zyda_0.4-final/jsonl`
103 | 
104 | If you want parquet files, you can convert the jsonls by running `zyda_reproduction/6_generating_final_dataset/run_convert_to_parquet_lsh_0.4.sh`, which saves them to `$DATA_BASE/zyda_0.4-final/parquet`. The files generated by this script were uploaded to Zyda's HuggingFace dataset repository.
105 | 
106 | ## Citation
107 | To cite our work, please use:
108 | 
109 | ```
110 | @misc{tokpanov2024zyda,
111 | title={Zyda: A 1.3T Dataset for Open Language Modeling},
112 | author={Yury Tokpanov and Beren Millidge and Paolo Glorioso and Jonathan Pilault and Adam Ibrahim and James Whittington and Quentin Anthony},
113 | year={2024},
114 | eprint={2406.01981},
115 | archivePrefix={arXiv},
116 | primaryClass={cs.CL}
117 | }
118 | ```
119 | 
120 | ## Acknowledgements
121 | We would like to acknowledge the SlimPajama team for publicly releasing their codebase with detailed instructions and explanations: [huggingface link](https://huggingface.co/datasets/cerebras/SlimPajama-627B), [github link](https://github.com/Cerebras/modelzoo/tree/Release_2.2.1/src/cerebras/modelzoo/data_preparation/nlp/slimpajama).
We used their code as a starting point for LSH minhash deduplication. We made significant changes to optimize parallel performance and enable distributed deduplication jobs on our HPC cluster. 122 | 123 | 124 | ## License 125 | [Apache License 2.0](./LICENSE) 126 | -------------------------------------------------------------------------------- /zyda/connected_components/generate_connected_components.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Zyphra Technologies. 2 | # Based on SlimPajama codebase: https://github.com/Cerebras/modelzoo/blob/main/src/cerebras/modelzoo/data_preparation/nlp/slimpajama/dedup/generate_connected_components.py 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from typing import Tuple, List, Dict 17 | from glob import glob 18 | import argparse 19 | import pickle 20 | import os 21 | import networkit as nk 22 | import tqdm 23 | import subprocess 24 | import multiprocessing as mp 25 | import gc 26 | 27 | import logging 28 | logging.basicConfig(format='%(asctime)s: %(message)s', level=logging.INFO) 29 | 30 | from zyda.utils.common import ensure_directory_exists 31 | 32 | 33 | def construct_graph(set_of_duplicate_pairs: set) -> Tuple[nk.Graph, Dict[str, int]]: 34 | G = nk.Graph() 35 | mapper = {} 36 | for pair in tqdm.tqdm(set_of_duplicate_pairs, desc="Bulding graph", unit="dupes", unit_scale=True): 37 | node1_name, node2_name = pair 38 | if node1_name not in mapper: 39 | mapper[node1_name] = G.addNode() 40 | if node2_name not in mapper: 41 | mapper[node2_name] = G.addNode() 42 | G.addEdge(mapper[node1_name], mapper[node2_name]) 43 | return G, mapper 44 | 45 | 46 | def find_connected_components(G: nk.Graph): 47 | cc = nk.components.ConnectedComponents(G) 48 | cc.run() 49 | return cc.getComponents(), cc.numberOfComponents() 50 | 51 | 52 | def count_file_lines(fname: str) -> int: 53 | p = subprocess.Popen(['wc', '-l', fname], stdout=subprocess.PIPE, 54 | stderr=subprocess.PIPE) 55 | result, err = p.communicate() 56 | if p.returncode != 0: 57 | raise IOError(err) 58 | return int(result.strip().split()[0]) 59 | 60 | 61 | def process_files(args: Tuple[int, str, bool, List[str]]) -> str: 62 | set_of_duplicate_pairs = set() 63 | pid = args[0] 64 | save_path = args[1] 65 | from_scratch = args[2] 66 | files = args[3] 67 | 68 | # save_path = save_path.replace(".pickle", f"-set-{pid}.pickle") 69 | save_path = save_path.replace(".pickle", f"-set-{pid}.txt") 70 | if os.path.exists(save_path) and not from_scratch: 71 | logging.info(f"Found {save_path}") 72 | return save_path, None 73 | 74 | for file in files: 75 | with open(file, "r") as f: 76 | for line in tqdm.tqdm(f, desc=file, unit="dupes", unit_scale=True, position=pid): 77 | pair = tuple(line.strip().split(" :: ")) 78 | if pair[0] != pair[1]: 79 | set_of_duplicate_pairs.add(pair) 80 | 81 | desc = f"Saving set to {save_path}" 82 | ensure_directory_exists(save_path) 83 | with open(save_path, "w") as f: 84 | for pair in 
tqdm.tqdm(set_of_duplicate_pairs, desc=desc, unit="dupes", unit_scale=True, position=pid): 85 | f.write(f"{pair[0]} :: {pair[1]}\n") 86 | return save_path, len(set_of_duplicate_pairs) 87 | 88 | 89 | def get_set_of_duplicate_pairs(set_save_path: str, from_scratch: bool) -> set: 90 | set_of_duplicate_pairs = set() 91 | if not from_scratch and os.path.exists(set_save_path): 92 | # Set of duplicates exists, load it 93 | logging.info(f"Counting lines in {set_save_path}") 94 | total_lines = count_file_lines(set_save_path) 95 | logging.info(f"Constructing set of duplicates from {set_save_path}") 96 | with open(set_save_path, "r") as f: 97 | for line in tqdm.tqdm(f, total=total_lines, unit="docs", unit_scale=True): 98 | pair = tuple(line.strip().split(" :: ")) 99 | if pair[0] != pair[1]: 100 | set_of_duplicate_pairs.add(pair) 101 | else: 102 | # Need to generate a set of duplicates 103 | all_files = sorted(glob(f"{args.input_dir}/*.txt")) 104 | workers_files = [[] for _ in range(args.workers)] 105 | for i, file in enumerate(all_files): 106 | workers_files[i % args.workers].append(file) 107 | 108 | workers_args = [] 109 | for i, files in enumerate(workers_files): 110 | workers_args.append((i, args.out_file, args.from_scratch, files)) 111 | 112 | with mp.Pool(processes=args.workers) as p: 113 | sets_files = p.map(process_files, workers_args) 114 | 115 | logging.info("Processing sets from workers") 116 | logging.info("Constructing final set") 117 | for file, total_pairs in sets_files: 118 | with open(file, "r") as f: 119 | for line in tqdm.tqdm(f, total=total_pairs, desc=file, unit="dupes", unit_scale=True): 120 | pair = tuple(line.strip().split(" :: ")) 121 | if pair[0] != pair[1]: 122 | set_of_duplicate_pairs.add(pair) 123 | 124 | desc = f"Saving final set to {set_save_path}..." 
125 | ensure_directory_exists(set_save_path) 126 | with open(set_save_path, "w") as f: 127 | for pair in tqdm.tqdm(set_of_duplicate_pairs, desc=desc, unit="dupes", unit_scale=True): 128 | f.write(f"{pair[0]} :: {pair[1]}\n") 129 | 130 | return set_of_duplicate_pairs 131 | 132 | 133 | def generate_connected_components_mp(args): 134 | print() 135 | nk.setNumberOfThreads(args.nk_threads) 136 | 137 | graph_save_path = args.out_file.replace(".pickle", f"-graph.graph") 138 | mapper_save_path = args.out_file.replace(".pickle", f"-mapper.pickle") 139 | set_save_path = args.out_file.replace(".pickle", f"-set-final.txt") 140 | 141 | if not args.from_scratch and os.path.exists(graph_save_path) and os.path.exists(mapper_save_path): 142 | # Graph and mapper exist, so load them 143 | logging.info(f"Loading a graph from {graph_save_path}") 144 | G = nk.graphio.readGraph(graph_save_path, nk.Format.METIS) 145 | 146 | logging.info(f"Loading a mapper from {mapper_save_path}") 147 | with open(mapper_save_path, "rb") as f: 148 | mapper = pickle.load(f) 149 | 150 | else: 151 | logging.info("Processing text files with duplicates...") 152 | set_of_duplicate_pairs = get_set_of_duplicate_pairs(set_save_path, args.from_scratch) 153 | logging.info(f"Length of the set of duplicates: {len(set_of_duplicate_pairs)}") 154 | 155 | # Generate a graph using id's as nodes and a pair of ids as an edge 156 | logging.info("Building graph...") 157 | G, mapper = construct_graph(set_of_duplicate_pairs) 158 | del set_of_duplicate_pairs 159 | gc.collect() 160 | 161 | logging.info(f"Saving graph to {graph_save_path}") 162 | nk.graphio.writeGraph(G, graph_save_path, nk.Format.METIS) 163 | 164 | logging.info(f"Saving mapper to {mapper_save_path}") 165 | ensure_directory_exists(mapper_save_path) 166 | with open(mapper_save_path, "wb") as f: 167 | pickle.dump(mapper, f, protocol=5) 168 | 169 | logging.info("Finding connected components...") 170 | components, n_components = find_connected_components(G) 171 | del G 172 | gc.collect() 173 | logging.info(f"Number of connected components: {n_components}") 174 | 175 | logging.info("Building reverse mapper...") 176 | reverse_mapper = {value: key for key, value in tqdm.tqdm(mapper.items(), unit="docs", unit_scale=True)} 177 | del mapper 178 | gc.collect() 179 | 180 | # dump pickled cc on disk and load if needed 181 | logging.info(f"Saving connected components to {args.out_file}...") 182 | ensure_directory_exists(args.out_file) 183 | with open(args.out_file, "wb") as fout: 184 | pickle.dump((components, n_components, reverse_mapper), fout, protocol=5) 185 | logging.info("Done!") 186 | 187 | 188 | if __name__ == "__main__": 189 | parser = argparse.ArgumentParser() 190 | parser.add_argument("--input-dir", type=str, required=True, help="Input directory containing text files with duplicate pairs") 191 | parser.add_argument( 192 | "--out-file", type=str, required=True, 193 | help="Output pickle file to save connected components. 
Prefix will be used for saving intermediate things as well (sets of duplicates, graph)" 194 | ) 195 | parser.add_argument("--workers", type=int, default=1, help="Number of workers for processing text files with duplicate pairs") 196 | parser.add_argument("--nk-threads", type=int, default=96, help="Number of threads for graph processing") 197 | parser.add_argument("--from-scratch", action="store_true", help="Start from scratch ignoring any intermediate files") 198 | args = parser.parse_args() 199 | generate_connected_components_mp(args) 200 | -------------------------------------------------------------------------------- /zyda/lsh_minhash/build_lsh_index.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Zyphra Technologies. 2 | # Based on SlimPajama codebase: https://github.com/Cerebras/modelzoo/blob/main/src/cerebras/modelzoo/data_preparation/nlp/slimpajama/dedup/generate_duplicate_pairs.py 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import argparse 17 | import pickle 18 | import queue 19 | import time 20 | import os 21 | import more_itertools 22 | from typing import List 23 | from tqdm import tqdm 24 | from collections import defaultdict 25 | from multiprocessing import Process, Queue 26 | from datasketch.lean_minhash import LeanMinHash 27 | from zyda.utils.common import ensure_directory_exists 28 | 29 | import datasets 30 | 31 | import logging 32 | logging.basicConfig(format='%(asctime)s: %(message)s', level=logging.INFO) 33 | 34 | 35 | def _H(hs): 36 | return bytes(hs.byteswap().data) 37 | 38 | 39 | def get_hashes_band( 40 | shards: List[datasets.Dataset], 41 | doc_queue: Queue, 42 | i: int, 43 | r: int, 44 | log_interval: int = 0, 45 | ): 46 | for shard in shards: 47 | for j, item in enumerate(shard): 48 | if log_interval and j % log_interval == 0: 49 | logging.debug(f"Band {i}: read {j} records") 50 | key = f"{item['dataset_name']}@{item['shard']}@{item['shard_index']}@{item['global_index']}" 51 | minhash = LeanMinHash(seed=item['seed'], hashvalues=item['hashvalues']) 52 | H = _H(minhash.hashvalues[i * r : (i + 1) * r]) 53 | doc_queue.put((key, H)) 54 | 55 | 56 | def lsh_process( 57 | dupes_out: str, 58 | lsh_in: str, 59 | lsh_out: str, 60 | doc_queue: Queue, 61 | queue_idx: int, 62 | band_idx: int, 63 | log_interval: int = 1_000_000, 64 | n_docs: int = 0, 65 | check_only: bool = False, 66 | ): 67 | lsh_dict = defaultdict(str) 68 | if lsh_in: 69 | lsh_in_band = lsh_in.replace(".pickle", f"-{band_idx}.pickle") 70 | if os.path.exists(lsh_in_band): 71 | with open(lsh_in_band, "rb") as f: 72 | logging.info(f"Band {band_idx}: loading LSH index from {lsh_in_band}") 73 | lsh_dict = pickle.load(f) 74 | logging.info(f"Band {band_idx}: loaded LSH index from {lsh_in_band}") 75 | elif check_only: 76 | raise FileExistsError() 77 | else: 78 | logging.info(f"Band {band_idx}: did not find existing LSH index at {lsh_in_band}, so creating a new one") 79 | 80 | ensure_directory_exists(dupes_out) 
81 | with open(dupes_out.replace(".txt", f"-{band_idx}.txt"), "w") as f: 82 | i = 0 83 | start_time = time.time() 84 | t0 = start_time 85 | if n_docs: 86 | pbar = tqdm(desc=f"Band {band_idx}", total=n_docs, unit_scale=True, position=queue_idx, dynamic_ncols=True) 87 | while True: 88 | try: 89 | key, H = doc_queue.get(timeout=30) 90 | cand = lsh_dict.get(H, "None") 91 | if cand != "None": 92 | f.write(f'{key} :: {cand}\n') 93 | elif not check_only: 94 | lsh_dict[H] = key 95 | i += 1 96 | if n_docs: 97 | pbar.update(1) 98 | elif i % log_interval == 0: 99 | speed = log_interval / (time.time() - t0) 100 | t0 = time.time() 101 | logging.info( 102 | f"Band {band_idx}: Processed {i / 1_000_000:.1f}M in {time.time() - start_time:.1f}s; " 103 | f"{speed / 1_000:.1f}kdocs/sec. Index size: {len(lsh_dict) / i * 100:.2f}%. " 104 | f"Doc queue size: {doc_queue.qsize()}" 105 | ) 106 | except queue.Empty: 107 | break 108 | if n_docs: 109 | pbar.close() 110 | 111 | if not check_only: 112 | lsh_out_band = lsh_out.replace(".pickle", f"-{band_idx}.pickle") 113 | ensure_directory_exists(lsh_out_band) 114 | with open(lsh_out_band, "wb") as f: 115 | logging.info(f"Band {band_idx}: saving LSH index to {lsh_out_band}") 116 | pickle.dump(lsh_dict, f, protocol=5) 117 | logging.info(f"Band {band_idx}: saved LSH index to {lsh_out_band}") 118 | logging.info(f"Band {band_idx}: Total number of documents: {i}") 119 | 120 | 121 | def generate_pairs(args): 122 | print() 123 | 124 | bands_inds = range(args.bands) 125 | bands_splits = [list(x) for x in more_itertools.divide(args.num_nodes, bands_inds)] 126 | if args.node_rank > -1: 127 | bands_splits = [bands_splits[args.node_rank]] 128 | if args.bands_parallel > 0: 129 | bands_splits_flattened = [band for bands in bands_splits for band in bands] 130 | bands_splits = [list(x) for x in more_itertools.chunked(bands_splits_flattened, args.bands_parallel)] 131 | logging.info(f"Bands splits: {bands_splits}") 132 | 133 | num_queues = max([len(x) for x in bands_splits]) 134 | doc_queues = [Queue(1_000_000) for _ in range(num_queues)] 135 | 136 | reader_shards = [[] for _ in range(args.reader_processes)] 137 | total_length = 0 138 | for arg_load_path in args.load_path: 139 | mh_dirs = sorted(os.listdir(arg_load_path)) 140 | logging.info(f'Loading {len(mh_dirs)} minhash shards from {arg_load_path}') 141 | mh_shards = [] 142 | for mh_dir in tqdm(mh_dirs): 143 | load_path = os.path.join(arg_load_path, mh_dir) 144 | mh_shards.append(datasets.load_from_disk(load_path)) 145 | logging.info('Concatenating into a single dataset') 146 | mh_ds = datasets.concatenate_datasets(mh_shards) 147 | total_length += len(mh_ds) 148 | logging.info(f'Splitting into {args.reader_processes} shards') 149 | for i in range(args.reader_processes): 150 | shard = mh_ds.shard(num_shards=args.reader_processes, index=i, contiguous=True) 151 | reader_shards[i].append(shard) 152 | 153 | t0 = time.time() 154 | for bands_split in bands_splits: 155 | logging.info('-' * 120) 156 | logging.info(f"Processing bands: {bands_split}") 157 | logging.info('-' * 120) 158 | processes = [] 159 | for q_i, band_i in enumerate(bands_split): 160 | for process_id in range(args.reader_processes): 161 | p = Process( 162 | target=get_hashes_band, 163 | args=(reader_shards[process_id], doc_queues[q_i], band_i, args.range, args.log_interval), 164 | ) 165 | processes.append(p) 166 | p.start() 167 | 168 | p = Process( 169 | target=lsh_process, 170 | args=(args.dupes_out, args.lsh_in, args.lsh_out, doc_queues[q_i], q_i, band_i, args.log_interval, 
total_length, args.check_only),
171 | )
172 | processes.append(p)
173 | p.start()
174 | 
175 | for p in processes:
176 | p.join()
177 | 
178 | logging.info('-' * 120)
179 | logging.info(f'Done processing LSH index in {time.time() - t0:.1f}s.')
180 | logging.info('-' * 120)
181 | 
182 | 
183 | if __name__ == "__main__":
184 | parser = argparse.ArgumentParser()
185 | parser.add_argument("--load-path", nargs="+", type=str, required=True, help="Path to a folder with shards from minhashing step")
186 | parser.add_argument("--check-only", action="store_true", help="Only check existing LSH index")
187 | parser.add_argument("--dupes-out", type=str, required=True, help="Output text file with duplicates")
188 | parser.add_argument("--lsh-in", type=str, help="Pickle file with LSH index to load")
189 | parser.add_argument("--lsh-out", type=str, required=True, help="Output pickle file with LSH index")
190 | parser.add_argument("--range", type=int, required=True, help="Range of LSH index")
191 | parser.add_argument("--bands", type=int, required=True, help="Number of bands of LSH index")
192 | parser.add_argument("--num-nodes", type=int, default=1, help="Number of nodes for distributed processing")
193 | parser.add_argument("--node-rank", type=int, default=-1, help="Rank of the node")
194 | parser.add_argument("--bands-parallel", type=int, default=-1, help="Number of bands to be processed in parallel")
195 | parser.add_argument("--reader-processes", type=int, default=1, help="Number of reader processes per band to populate document queues")
196 | parser.add_argument("--log-interval", type=int, default=100_000, help="Interval of logging/updating progress bar")
197 | args = parser.parse_args()
198 | 
199 | generate_pairs(args)
200 | 
--------------------------------------------------------------------------------
/zyda/preprocessing_and_filtering/preprocess_and_filter.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Zyphra Technologies.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 15 | from collections import defaultdict 16 | from typing import Dict, List 17 | import os 18 | import argparse 19 | import re 20 | import json 21 | import datasets 22 | import transformers 23 | from zyda.utils.text import get_normalized_words 24 | from zyda.utils.filtering import filter 25 | 26 | import nltk 27 | nltk.download('punkt') 28 | 29 | import logging 30 | logging.basicConfig(format='%(asctime)s: %(message)s', level=logging.INFO) 31 | 32 | REPO_BASE = os.environ.get("REPO_BASE", "") 33 | 34 | TOKENIZERS = { 35 | "neox": transformers.AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b"), 36 | } 37 | 38 | def read_json_file(fname): 39 | with open(fname, "r") as f: 40 | result_dict = json.loads(f.read()) 41 | return result_dict["words"] 42 | 43 | # Taken from https://arxiv.org/pdf/2309.04662 44 | WORD_LISTS = { 45 | "profanity_word_list.json": read_json_file(os.path.join(REPO_BASE, "zyda/preprocessing_and_filtering/profanity_word_list.json")), 46 | "sexual_word_list.json": read_json_file(os.path.join(REPO_BASE, "zyda/preprocessing_and_filtering/sexual_word_list.json")), 47 | "zh_pornsignals.json": read_json_file(os.path.join(REPO_BASE, "zyda/preprocessing_and_filtering/zh_pornsignals.json")), 48 | "cursed_substrings.json": read_json_file(os.path.join(REPO_BASE, "zyda/preprocessing_and_filtering/cursed_substrings.json")), 49 | } 50 | 51 | PATTERNS = ["xml", "", "\":", "www."] 52 | 53 | REPEATING_CHARACTER_THRESHOLD = 10 54 | CHARS_FOR_TRANSFORM = { 55 | "-": REPEATING_CHARACTER_THRESHOLD, 56 | " ": 40, 57 | "_": REPEATING_CHARACTER_THRESHOLD, 58 | "/": REPEATING_CHARACTER_THRESHOLD, 59 | r"\\": REPEATING_CHARACTER_THRESHOLD, 60 | "\n": REPEATING_CHARACTER_THRESHOLD, 61 | "\t": 20, 62 | "\r": REPEATING_CHARACTER_THRESHOLD, 63 | r"\.": REPEATING_CHARACTER_THRESHOLD, 64 | ",": REPEATING_CHARACTER_THRESHOLD, 65 | ":": REPEATING_CHARACTER_THRESHOLD, 66 | r"\?": REPEATING_CHARACTER_THRESHOLD, 67 | "\xa0": REPEATING_CHARACTER_THRESHOLD, 68 | } 69 | 70 | 71 | REGEX_MEAN_WORD_LENGTH = re.compile(r'\s|-|/|\\|\.') 72 | def mean_word_length(text: str): 73 | if text: 74 | words = REGEX_MEAN_WORD_LENGTH.split(text) 75 | return sum(len(word) for word in words) / len(words) 76 | return 0.0 77 | 78 | 79 | REGEX_NON_ALPHANUMERIC = re.compile(r'[^\w\s]') 80 | def fraction_non_alphanumeric(text: str): 81 | if text: 82 | return len(REGEX_NON_ALPHANUMERIC.findall(text)) / len(text) 83 | return 0.0 84 | 85 | 86 | REGEX_COUNT_NUMERICS = re.compile(r'\d') 87 | def fraction_numerical(text: str): 88 | if text: 89 | return len(REGEX_COUNT_NUMERICS.findall(text)) / len(text) 90 | return 0.0 91 | 92 | 93 | def count_substrings(text: str, allowed_num_repeats: int = 7): 94 | substrings = re.findall(r'(\w)\1{%d,}' % (allowed_num_repeats - 1), text) 95 | return len(substrings) 96 | 97 | 98 | def count_pattern(text: str, pattern: str): 99 | return len(re.findall(pattern, text)) 100 | 101 | 102 | def count_word_list(text: str, word_list: str): 103 | return sum([count_pattern(text, word) for word in word_list]) 104 | 105 | REGEX_EMAIL = re.compile(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b') 106 | REGEX_PHONE_NUMBER = re.compile(r'(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9]{1,2})\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9]{1,2})') 107 | def count_PII_items(input_string): 108 | num_email = len(REGEX_EMAIL.findall(input_string)) 109 | num_phone_number = len(REGEX_PHONE_NUMBER.findall(input_string)) 110 | return num_email + num_phone_number 111 | 112 | 113 | def transform( 114 | text: str, 115 | 
chars_with_thresholds: dict =CHARS_FOR_TRANSFORM, 116 | ) -> str: 117 | new_text = text 118 | for char, threshold in chars_with_thresholds.items(): 119 | pattern = char * threshold + '+' 120 | char = '?' if char == r'\?' else char 121 | char = '.' if char == r'\.' else char 122 | new_text = re.sub(pattern, char, new_text) 123 | return new_text 124 | 125 | 126 | def preprocess( 127 | batch, 128 | indices, 129 | key: str, 130 | name: str, 131 | shard: int, 132 | offset: int, 133 | patterns: list = PATTERNS, 134 | word_lists: dict = WORD_LISTS, 135 | ) -> Dict[str, List]: 136 | texts = batch[key] 137 | features = defaultdict(list) 138 | for ind, text in zip(indices, texts): 139 | features["dataset_name"].append(name) 140 | features["shard"].append(shard) 141 | features["shard_index"].append(ind) 142 | features["global_index"].append(ind + offset) 143 | features["mean_word_length"].append(mean_word_length(text)) 144 | features["fraction_non_alphanumeric"].append(fraction_non_alphanumeric(text)) 145 | features["fraction_numerical"].append(fraction_numerical(text)) 146 | features["pii_count"].append(count_PII_items(text)) 147 | 148 | pattern_counts = {} 149 | for pattern in patterns: 150 | pattern_counts[pattern] = count_pattern(text, pattern) 151 | features["pattern_counts"].append(pattern_counts) 152 | 153 | word_list_counts = {} 154 | for word_list_key, word_list in word_lists.items(): 155 | word_list_counts[word_list_key] = count_word_list(text, word_list) 156 | features["word_list_counts"].append(word_list_counts) 157 | 158 | tokenized = TOKENIZERS["neox"].encode(text) 159 | features["n_tokens_neox"].append(len(tokenized)) 160 | 161 | words = get_normalized_words(text) 162 | features["n_words"].append(len(words)) 163 | 164 | transformed_text = transform(text) 165 | features["transformed_text"].append(transformed_text) 166 | features["substrings_counts"].append(count_substrings(transformed_text)) 167 | 168 | return features 169 | 170 | 171 | if __name__ == '__main__': 172 | parser = argparse.ArgumentParser() 173 | parser.add_argument('--hf-path', type=str, required=True, help='Path of HF dataset') 174 | parser.add_argument('--hf-dir', type=str, default=None, help='Dir in HF dataset') 175 | parser.add_argument('--name', type=str, required=True, help='Descriptor for identifying the dataset') 176 | parser.add_argument('--load-from-disk', action='store_true', help='Use datasets.load_from_disk() to load the dataset') 177 | parser.add_argument('--num-proc', type=int, default=1, help='Number of processes for HF processing') 178 | parser.add_argument('--num-records-per-shard', type=int, default=20_000_000, help='Approximate number of records per shard') 179 | parser.add_argument('--key', type=str, default='text', help='Key to extract') 180 | parser.add_argument('--keep-key', action='store_true', help='If specified, key column will be saved in shards') 181 | parser.add_argument('--save-path', type=str, required=True, help='Folder to save processed HF dataset to') 182 | parser.add_argument('--from-scratch', action='store_true', help='If specified, will forcefully do every shard regardless of previous progress') 183 | 184 | args = parser.parse_args() 185 | print() 186 | 187 | logging.info(f"Loading {args.hf_path}, dir={args.hf_dir}") 188 | if args.load_from_disk: 189 | dataset = datasets.load_from_disk(args.hf_path) 190 | else: 191 | dataset = datasets.load_dataset(args.hf_path, args.hf_dir, num_proc=args.num_proc, split='train', trust_remote_code=True) 192 | 193 | logging.info(f"Loaded 
dataset:\n{dataset}") 194 | logging.info(f"Cache cleaned: {dataset.cleanup_cache_files()}") 195 | 196 | num_shards = 1 197 | if args.num_records_per_shard: 198 | num_shards = 1 + len(dataset) // args.num_records_per_shard 199 | 200 | offset = 0 201 | for i in range(num_shards): 202 | print() 203 | logging.info(f"Processing shard {i + 1} / {num_shards}. Current offset = {offset}.") 204 | save_path = f"{args.save_path}/shard_{i:02d}" 205 | if os.path.exists(save_path) and not args.from_scratch: 206 | logging.info(f"Already processed!") 207 | offset += len(datasets.load_from_disk(save_path)) 208 | continue 209 | 210 | ds_shard = dataset.shard(num_shards=num_shards, index=i, contiguous=True) 211 | logging.info(f"Cache cleaned: {ds_shard.cleanup_cache_files()}") 212 | 213 | logging.info(f"Preprocessing...") 214 | ds_shard_post = ds_shard.map( 215 | lambda batch, indices: preprocess(batch, indices, shard=i, offset=offset, key=args.key, name=args.name), 216 | batched=True, 217 | with_indices=True, 218 | remove_columns=None if args.keep_key else args.key, 219 | num_proc=args.num_proc, 220 | ) 221 | 222 | if "starcoder" in args.name: 223 | logging.info(f"Starcoder detected: skipping filtering") 224 | else: 225 | logging.info(f"Filtering...") 226 | ds_shard_post = ds_shard_post.filter(lambda row: filter(row), num_proc=args.num_proc) 227 | 228 | offset += len(ds_shard_post) 229 | ds_shard_post.save_to_disk(save_path, max_shard_size="8GB") 230 | logging.info(f"Cache cleaned: {ds_shard.cleanup_cache_files()}") 231 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2024 Zyphra Technologies 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. --------------------------------------------------------------------------------