├── .gitignore ├── LICENSE ├── README.md ├── app ├── .dockerignore ├── Dockerfile ├── requirements.txt └── src │ ├── __init__.py │ ├── artifacts │ ├── __init__.py │ ├── downloaders │ │ ├── __init__.py │ │ ├── books_downloader.py │ │ ├── ccnet_downloader.py │ │ ├── openwebtext_downloader.py │ │ └── wikipedia_downloader.py │ ├── ft_trainer.py │ ├── hash_dist.py │ ├── update_resources.py │ └── utils │ │ ├── __init__.py │ │ ├── data_utils.py │ │ └── logging_utils.py │ ├── bloomfilter.py │ ├── core │ ├── __init__.py │ ├── constants.py │ ├── data_types.py │ ├── document.py │ ├── exceptions.py │ ├── quality_signals │ │ ├── __init__.py │ │ ├── base.py │ │ ├── classifiers.py │ │ ├── content.py │ │ ├── importance_weights.py │ │ ├── lines.py │ │ ├── natural_language.py │ │ ├── repetitions.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── classifiers.py │ │ │ ├── content.py │ │ │ ├── dsir.py │ │ │ └── stop_words.py │ ├── schema │ │ ├── __init__.py │ │ └── rp.py │ └── worker.py │ ├── dedupe │ ├── __init__.py │ ├── minhash.py │ └── utils.py │ ├── pipeline.py │ ├── prep_artifacts.py │ ├── run_lsh.py │ ├── token_count.py │ └── utilities │ ├── __init__.py │ ├── io │ ├── __init__.py │ ├── reader.py │ ├── s3.py │ └── writer.py │ ├── logging │ ├── __init__.py │ ├── configure.py │ ├── format.py │ ├── mp.py │ └── trackers.py │ ├── register │ ├── __init__.py │ └── registry_utils.py │ └── text │ ├── __init__.py │ ├── ngrams.py │ ├── normalization.py │ └── util.py ├── configs └── rp_v2.0.conf ├── docs └── rpv2.png └── scripts ├── apptainer_run_lsh.sh ├── apptainer_run_quality_signals.sh └── run_prep_artifacts.sh /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.pyc 6 | *.DS_Store 7 | 8 | # data folders 9 | data/* 10 | !data/.gitkeep 11 | 12 | # notebooks 13 | notebooks/* 14 | .ipynb_checkpoints 15 | 16 | # Environments 17 | .env 18 | .venv 19 | env/ 20 | venv/ 21 | ENV/ 22 | env.bak/ 23 | venv.bak/ 24 | 25 | # ides 26 | .idea/ 27 | .vscode/ 28 | 29 | # distribution 30 | *.egg-info/ 31 | dist/ 32 | build/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /app/.dockerignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.pyc 6 | *.DS_Store 7 | 8 | # Environments 9 | .env 10 | .venv 11 | env/ 12 | venv/ 13 | ENV/ 14 | env.bak/ 15 | venv.bak/ 16 | 17 | # ides 18 | .idea/ 19 | .vscode/ 20 | 21 | # jupyter notebook 22 | notebooks/ 23 | .ipynb_checkpoints/ 24 | *.ipynb 25 | 26 | # debugging 27 | debugging/ -------------------------------------------------------------------------------- /app/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.11.5-slim-bookworm 2 | 3 | WORKDIR /usr/app 4 | 5 | RUN apt-get update && apt-get install -y \ 6 | build-essential \ 7 | isal \ 8 | git 9 | 10 | RUN pip3 install --no-cache-dir --upgrade pip 11 | RUN pip3 install cmake cython 12 | 13 | # copy requirements.txt to the working directory 14 | COPY requirements.txt requirements.txt 15 | 16 | # install python dependencies 17 | RUN pip3 install --no-cache-dir -r requirements.txt 18 | 19 | # install mwparserfromhell (v >= 0.7.0 for spanish) 20 | RUN pip3 install apache_beam 21 | RUN pip3 install git+https://github.com/earwig/mwparserfromhell.git@0f89f44 22 | 23 | # set python hash seed 24 | ENV PYTHONHASHSEED 42 25 | 26 | # copy the local files to the working directory 27 | COPY . . 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /app/requirements.txt: -------------------------------------------------------------------------------- 1 | datasets 2 | s3fs 3 | dill 4 | tqdm 5 | nltk 6 | fasttext 7 | polars 8 | pyarrow 9 | numpy 10 | msgspec 11 | boto3 12 | networkit 13 | scrubadub 14 | textstat 15 | xopen 16 | rich 17 | progiter 18 | pybloomfiltermmap3 19 | transformers -------------------------------------------------------------------------------- /app/src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/togethercomputer/RedPajama-Data/6d2cee9df2b0204dd2bcb00bf06b5a7b1d7432d7/app/src/__init__.py -------------------------------------------------------------------------------- /app/src/artifacts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/togethercomputer/RedPajama-Data/6d2cee9df2b0204dd2bcb00bf06b5a7b1d7432d7/app/src/artifacts/__init__.py -------------------------------------------------------------------------------- /app/src/artifacts/downloaders/__init__.py: -------------------------------------------------------------------------------- 1 | from .books_downloader import BooksDownloader 2 | from .openwebtext_downloader import OpenWebTextDownloader 3 | from .wikipedia_downloader import WikipediaDownloader 4 | from .ccnet_downloader import CCNetDownloader 5 | -------------------------------------------------------------------------------- /app/src/artifacts/downloaders/books_downloader.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | from pathlib import Path 3 | import random 4 | from tqdm import tqdm 5 | 6 | from utilities.io import Writer 7 | from utilities.text.util import generate_paragraphs 8 | 9 | 10 | class BooksDownloader: 11 | r""" Loads the RedPajama Books dataset from HuggingFace Datasets and 
saves 12 | it to disk """ 13 | 14 | dataset_name = "books" 15 | output_fp = "books/en-books.jsonl.gz" 16 | 17 | def __init__( 18 | self, lang, out_dir, overwrite, cache_dir, max_samples, 19 | max_paragraphs_per_sample=200, max_samples_per_book=500 20 | ): 21 | self._lang = lang 22 | self._out_dir = out_dir 23 | self._overwrite = overwrite 24 | self._cache_dir = cache_dir 25 | self._max_samples = max_samples 26 | self._max_paragraphs_per_sample = max_paragraphs_per_sample 27 | self._max_samples_per_book = max_samples_per_book 28 | self._filepath = None 29 | 30 | def __str__(self): 31 | return f"{self.__class__.__name__}(lang={self._lang})" 32 | 33 | @property 34 | def filepath(self): 35 | return self._filepath 36 | 37 | def __generate_chunks(self, text: str): 38 | if self._max_paragraphs_per_sample is None: 39 | yield text 40 | return 41 | 42 | n_samples = 0 43 | buffer = [] 44 | buffer_size = random.randint(1, self._max_paragraphs_per_sample) 45 | for par in generate_paragraphs(text): 46 | buffer.append(par) 47 | if len(buffer) >= buffer_size: 48 | yield "\n".join(buffer) 49 | 50 | buffer_size = random.randint( 51 | 1, self._max_paragraphs_per_sample 52 | ) 53 | buffer = [] 54 | n_samples += 1 55 | 56 | if n_samples >= self._max_samples_per_book > 0: 57 | break 58 | 59 | def run(self, logger): 60 | if self._lang != "en": 61 | logger.info(f"{str(self)} Skipping {self._lang}") 62 | return 63 | 64 | self._filepath = Path(self._out_dir) / self.output_fp 65 | logger.info(f"{str(self)} Output file: {self._filepath}") 66 | logger.info(f"{str(self)} max_samples: {self._max_samples}") 67 | 68 | if self._filepath.exists(): 69 | if not self._overwrite: 70 | raise FileExistsError(f"File {self._filepath} already exists.") 71 | else: 72 | self._filepath.unlink() 73 | logger.info(f"{str(self)} Deleted {self._filepath}") 74 | 75 | out_uri = "file://" + str(self._filepath) 76 | writer = Writer(uri=out_uri, schema=[("text", str)]) 77 | 78 | logger.info(f"{str(self)} Download start.") 79 | pbar = tqdm(desc="writing progress", total=self._max_samples) 80 | flush_every = 5_000 81 | 82 | n_docs = 0 83 | for book in load_dataset( 84 | "togethercomputer/RedPajama-Data-1T", name="book", 85 | cache_dir=self._cache_dir, 86 | split="train", streaming=True 87 | ): 88 | for chunk in self.__generate_chunks(book["text"]): 89 | n_docs += 1 90 | if n_docs > self._max_samples > 0: 91 | break 92 | 93 | writer.write( 94 | data_obj={"text": chunk}, 95 | flush=n_docs % flush_every == 0 96 | ) 97 | pbar.update(1) 98 | 99 | else: 100 | continue 101 | break 102 | 103 | pbar.close() 104 | writer.close() 105 | logger.info(f"{str(self)} Download finished; num_samples={n_docs - 1}") 106 | -------------------------------------------------------------------------------- /app/src/artifacts/downloaders/ccnet_downloader.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import subprocess 3 | from pathlib import Path 4 | import re 5 | import functools 6 | from tqdm import tqdm 7 | import multiprocessing as mp 8 | from multiprocessing.pool import Pool 9 | from urllib.parse import urlparse 10 | import os 11 | 12 | from utilities.io import Reader, Writer 13 | from utilities.io.s3 import init_client 14 | 15 | 16 | class CCNetDownloader(object): 17 | r""" 18 | This class downloads / loads ccnet data and writes it to a jsonl file. 
19 | """ 20 | 21 | dataset_name = "ccnet" 22 | 23 | # extension of the cc input files 24 | cc_ext = ".json.gz" 25 | 26 | def __init__( 27 | self, 28 | artifacts_dir, 29 | cc_input, 30 | cc_input_base_uri, 31 | lang, num_samples, max_workers, 32 | endpoint_url 33 | ): 34 | # write args to class variables 35 | self._lang = lang 36 | self._num_samples = num_samples 37 | self._cc_input = cc_input 38 | self._cc_input_base_uri = cc_input_base_uri 39 | self._endpoint_url = endpoint_url 40 | 41 | # parallel readers 42 | if max_workers is not None: 43 | self._parallel_readers = max_workers 44 | else: 45 | self._parallel_readers = mp.cpu_count() - 2 46 | 47 | # build output path 48 | self._output_fp = Path(artifacts_dir) / "datasets" / \ 49 | self._lang / "ccnet" / "ccnet.jsonl" 50 | self._output_fp.parent.mkdir(parents=True, exist_ok=True) 51 | 52 | def __str__(self): 53 | return f"{self.__class__.__name__}({self._lang})" 54 | 55 | @property 56 | def filepath(self): 57 | return self._output_fp 58 | 59 | def __ccnet_file_filter(self, fp: str) -> bool: 60 | r""" function to filter commoncrawl input files. """ 61 | # we only keep files in the target language 62 | if not Path(fp).name.startswith(f"{self._lang}_"): 63 | return False 64 | 65 | # check extension 66 | if not fp.endswith(self.cc_ext): 67 | return False 68 | 69 | return True 70 | 71 | def run(self, logger): 72 | if not Path(self._cc_input).exists(): 73 | raise ValueError( 74 | f"Listings file {self._cc_input} does not exist" 75 | ) 76 | 77 | # read the listings file and return the relative paths listed in the 78 | # file. 79 | logger.info(f"{str(self)} Start loading input listings...") 80 | with open(self._cc_input) as f: 81 | input_listings = list(map( 82 | lambda _fp: os.path.join(self._cc_input_base_uri, _fp), 83 | filter(self.__ccnet_file_filter, map(str.strip, f.readlines())) 84 | )) 85 | 86 | # partition cc input by snapshot id in order to ensure that we have a 87 | # balanced number of samples per snapshot. This is to avoid bias due 88 | # to distribution shifts over time. 
89 | logger.info(f"{str(self)} Partitioning inputs by snapshot...") 90 | snapsh_re = re.compile(r'\b\d{4}-\d{2}\b') 91 | inputs_by_snapsh = defaultdict(list) 92 | for listing in input_listings: 93 | if (dump_id := snapsh_re.search(listing).group()) is None: 94 | continue 95 | inputs_by_snapsh[dump_id].append(listing) 96 | 97 | samples_per_snapshot = max( 98 | 1, self._num_samples // len(inputs_by_snapsh) 99 | ) 100 | 101 | # kick off processes 102 | manager = mp.Manager() 103 | data_queue = manager.Queue(maxsize=128 * self._parallel_readers) 104 | 105 | # writer 106 | writer_proc = mp.Process( 107 | target=self._writer_worker, args=(data_queue,) 108 | ) 109 | writer_proc.start() 110 | 111 | logger.info(f"{str(self)} Start loading {self._num_samples} samples " 112 | f"from {len(inputs_by_snapsh)} snapshots") 113 | 114 | with Pool(processes=self._parallel_readers) as pool: 115 | counts_per_snapsh = pool.starmap( 116 | functools.partial(self._load_snapshot, data_queue=data_queue), 117 | [ 118 | (snpsh_id, snpsh_files, samples_per_snapshot) 119 | for snpsh_id, snpsh_files in inputs_by_snapsh.items() 120 | ] 121 | ) 122 | 123 | total_samples = 0 124 | for counts, snapshot_id in counts_per_snapsh: 125 | logger.info(f"{str(self)} Snapshot {snapshot_id}: " 126 | f"loaded {counts} samples.") 127 | total_samples += counts 128 | 129 | logger.info(f"{str(self)} Total: loaded {total_samples} samples.") 130 | logger.info(f"{str(self)} Shuffling...") 131 | subprocess.run(["shuf", self._output_fp, "-o", self._output_fp]) 132 | logger.info(f"{str(self)} Done. Output: {self._output_fp}") 133 | 134 | # send kill signal to writer 135 | data_queue.put_nowait(None) 136 | writer_proc.join() 137 | manager.shutdown() 138 | 139 | def _load_snapshot( 140 | self, snapshot_id, input_uris, num_samples, data_queue: mp.Queue, 141 | ): 142 | # partition input files into head, middle and tail 143 | head_uris = list(filter(lambda _u: "_head" in _u, input_uris)) 144 | middle_uris = list(filter(lambda _u: "_middle" in _u, input_uris)) 145 | tail_uris = list(filter(lambda _u: "_tail" in _u, input_uris)) 146 | 147 | # compute number of samples to load from each bucket 148 | samples_per_bucket = { 149 | "head": int(num_samples * 0.1), 150 | "middle": int(num_samples * 0.2), 151 | "tail": int(num_samples * 0.7) 152 | } 153 | 154 | if urlparse(self._cc_input_base_uri).scheme == "s3": 155 | s3_client = init_client( 156 | endpoint_url=self._endpoint_url, 157 | signature_version="s3v4", 158 | aws_access_key_id=os.environ.get("AWS_ACCESS_KEY_ID"), 159 | aws_secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY") 160 | ) 161 | else: 162 | s3_client = None 163 | 164 | reader = Reader( 165 | schema=[("raw_content", str), ("language", str)], 166 | s3_client=s3_client 167 | ) 168 | 169 | total_samples = 0 170 | 171 | for bucket, bucket_list in zip( 172 | ["head", "middle", "tail"], [head_uris, middle_uris, tail_uris] 173 | ): 174 | samples_retrieved = 0 175 | target_samples = samples_per_bucket[bucket] 176 | 177 | for uri in bucket_list: 178 | 179 | samples_to_retrieve = target_samples - samples_retrieved 180 | if samples_to_retrieve <= 0: 181 | break 182 | 183 | for idx, record in reader.read( 184 | uri=uri, max_samples=samples_to_retrieve 185 | ): 186 | data_queue.put({ 187 | "text": record.raw_content, 188 | "lang": record.language, 189 | "source": uri 190 | }) 191 | 192 | samples_retrieved += 1 193 | total_samples += 1 194 | 195 | return total_samples, snapshot_id 196 | 197 | def _writer_worker(self, data_queue: mp.Queue): 198 | 199 
| writer = Writer( 200 | uri="file://" + str(self._output_fp), 201 | schema=[("text", str), ("lang", str), ("source", str)] 202 | ) 203 | 204 | flush_every = 10_000 205 | 206 | pbar = tqdm(desc="writing progress") 207 | 208 | num_recs = 0 209 | while True: 210 | data = data_queue.get() 211 | 212 | if data is None: 213 | break 214 | 215 | num_recs += 1 216 | writer.write(data, flush=num_recs % flush_every == 0) 217 | pbar.update(1) 218 | 219 | pbar.close() 220 | writer.close() 221 | -------------------------------------------------------------------------------- /app/src/artifacts/downloaders/openwebtext_downloader.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | from pathlib import Path 3 | from tqdm import tqdm 4 | 5 | from utilities.io import Writer 6 | 7 | 8 | class OpenWebTextDownloader: 9 | r""" Loads the Openwebtext dataset from HuggingFace Datasets and saves it 10 | to disk """ 11 | 12 | dataset_name = "openwebtext" 13 | output_fp = "openwebtext/en-openwebtext.jsonl.gz" 14 | 15 | def __init__(self, lang, out_dir, overwrite, cache_dir, max_samples): 16 | self._lang = lang 17 | self._out_dir = out_dir 18 | self._overwrite = overwrite 19 | self._cache_dir = cache_dir 20 | self._max_samples = max_samples 21 | self._filepath = None 22 | 23 | def __str__(self): 24 | return f"{self.__class__.__name__}(lang={self._lang})" 25 | 26 | @property 27 | def filepath(self): 28 | return self._filepath 29 | 30 | def run(self, logger): 31 | if self._lang != "en": 32 | logger.info(f"{str(self)} Skipping {self._lang}") 33 | return 34 | 35 | self._filepath = Path(self._out_dir) / self.output_fp 36 | logger.info(f"{str(self)} Output file: {self._filepath}") 37 | logger.info(f"{str(self)} max_samples: {self._max_samples}") 38 | 39 | if self._filepath.exists(): 40 | if not self._overwrite: 41 | raise FileExistsError(f"File {self._filepath} already exists.") 42 | else: 43 | self._filepath.unlink() 44 | logger.info(f"{str(self)} Deleted {self._filepath}") 45 | 46 | out_uri = "file://" + str(self._filepath) 47 | writer = Writer(uri=out_uri, schema=[("text", str)]) 48 | 49 | logger.info(f"{str(self)} Download start.") 50 | pbar = tqdm(desc="writing progress") 51 | flush_every = 10_000 52 | 53 | n_docs = 0 54 | for record in load_dataset( 55 | "openwebtext", cache_dir=self._cache_dir, split="train", 56 | streaming=True 57 | ): 58 | n_docs += 1 59 | if n_docs > self._max_samples > 0: 60 | break 61 | writer.write( 62 | data_obj={"text": record["text"]}, 63 | flush=n_docs % flush_every == 0 64 | ) 65 | pbar.update(1) 66 | 67 | pbar.close() 68 | writer.close() 69 | logger.info(f"{str(self)} Download finished; num_samples={n_docs - 1}") 70 | -------------------------------------------------------------------------------- /app/src/artifacts/downloaders/wikipedia_downloader.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | from pathlib import Path 3 | from tqdm import tqdm 4 | 5 | from utilities.io import Writer 6 | 7 | 8 | class WikipediaDownloader: 9 | r""" Loads the Wikipedia dataset from HuggingFace Datasets and saves it to 10 | disk """ 11 | 12 | dataset_name = "wikipedia" 13 | output_pattern = "wikipedia/{lang}-wikipedia.jsonl.gz" 14 | 15 | def __init__(self, lang, out_dir, overwrite, cache_dir, max_samples): 16 | self._lang = lang 17 | self._out_dir = out_dir 18 | self._overwrite = overwrite 19 | self._cache_dir = cache_dir 20 | self._max_samples = max_samples 
21 | self._filepath = None 22 | 23 | def __str__(self): 24 | return f"{self.__class__.__name__}({self._lang})" 25 | 26 | @property 27 | def filepath(self): 28 | return self._filepath 29 | 30 | def run(self, logger): 31 | output_fn = self.output_pattern.format(lang=self._lang) 32 | self._filepath = Path(self._out_dir) / output_fn 33 | 34 | logger.info(f"{str(self)} Output file: {self._filepath}") 35 | logger.info(f"{str(self)} max_samples: {self._max_samples}") 36 | 37 | if self._filepath.exists(): 38 | if not self._overwrite: 39 | raise FileExistsError(f"File {self._filepath} already exists.") 40 | else: 41 | self._filepath.unlink() 42 | logger.info(f"{str(self)} Deleted {self._filepath}") 43 | 44 | out_uri = "file://" + str(self._filepath) 45 | writer = Writer(uri=out_uri, schema=[("text", str)]) 46 | 47 | logger.info(f"{str(self)} Download start.") 48 | pbar = tqdm(desc="writing progress") 49 | flush_every = 10_000 50 | 51 | try: 52 | # try to load wikipedia data from preprocessed huggingface dataset 53 | ds_iterator = load_dataset( 54 | "wikipedia", f"20220301.{self._lang}", streaming=True, 55 | split="train" 56 | ) 57 | logger.info(f"{str(self)} Load {self._lang}-wiki from 20220301") 58 | except Exception as _: 59 | # if that fails, load from original huggingface dataset and process 60 | ds_iterator = load_dataset( 61 | "wikipedia", language=self._lang, date="20230801", 62 | cache_dir=self._cache_dir, beam_runner="DirectRunner", 63 | split="train" 64 | ) 65 | logger.info(f"{str(self)} Load {self._lang}-wiki from 20230801") 66 | 67 | n_docs = 0 68 | for record in ds_iterator: 69 | n_docs += 1 70 | if n_docs > self._max_samples > 0: 71 | break 72 | writer.write( 73 | data_obj={"text": record["text"]}, 74 | flush=n_docs % flush_every == 0 75 | ) 76 | pbar.update(1) 77 | 78 | pbar.close() 79 | writer.close() 80 | logger.info(f"{str(self)} Download finished; num_samples={n_docs - 1}") 81 | -------------------------------------------------------------------------------- /app/src/artifacts/ft_trainer.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from tqdm import tqdm 3 | import fasttext 4 | import subprocess 5 | 6 | from core.document import Document 7 | from core.quality_signals.utils.classifiers import \ 8 | preprocess_quality_classifier 9 | from core.constants import CCNET_LABEL 10 | from utilities.io import Reader 11 | 12 | 13 | class FastTextTrainer: 14 | # cc label 15 | cc_label = CCNET_LABEL 16 | 17 | # output file naming convention 18 | output_fmt = "{dataset}.model.bin" 19 | 20 | def __init__( 21 | self, artifacts_dir, ccnet_data, target_data, target_name, 22 | samples_per_class, lang 23 | ): 24 | 25 | # write args to class variables 26 | self._ccnet_data = ccnet_data 27 | self._target_data = target_data 28 | self._samples_per_class = samples_per_class 29 | self._lang = lang 30 | self._target_label = f"__label__{target_name}" 31 | 32 | # build output directory 33 | out_dir = Path(artifacts_dir) / "classifiers" / self._lang 34 | out_dir.mkdir(parents=True, exist_ok=True) 35 | self._output = out_dir / self.output_fmt.format(dataset=target_name) 36 | self._train_data = out_dir / f"{target_name}.data.train" 37 | 38 | def run(self, logger): 39 | log_prefix = f"{self.__class__.__name__}(" \ 40 | f"lang={self._lang}, ccdata={self._ccnet_data}, " \ 41 | f"target_data={self._target_data}, " \ 42 | f"target_label={self._target_label})" 43 | 44 | train_data_fh = open(self._train_data, "w") 45 | 46 | 
logger.info(f"{log_prefix} Start building fasttext classifier") 47 | 48 | # write target data 49 | samples_per_slice = self._samples_per_class // len(self._target_data) 50 | total_target_samples = 0 51 | 52 | for target_data_fp in self._target_data: 53 | reader = Reader(schema=[("text", str)]) 54 | total_target_samples += self.__write_train_chunk( 55 | uri="file://" + str(target_data_fp), 56 | reader=reader, 57 | writer=train_data_fh, 58 | max_samples=samples_per_slice, 59 | target_label=self._target_label 60 | ) 61 | 62 | logger.info(f"{log_prefix} Number of target " 63 | f"samples found: {total_target_samples}") 64 | 65 | # write ccnet data 66 | reader = Reader(schema=[("text", str)]) 67 | ccnet_samples = self.__write_train_chunk( 68 | uri="file://" + str(self._ccnet_data), 69 | reader=reader, 70 | writer=train_data_fh, 71 | max_samples=total_target_samples, 72 | target_label=self.cc_label 73 | ) 74 | train_data_fh.close() 75 | logger.info(f"{log_prefix} Total ccnet samples: {ccnet_samples}") 76 | 77 | # shuffle train data 78 | logger.info(f"{log_prefix} Shuffling train data") 79 | subprocess.run( 80 | ["shuf", "-o", str(self._train_data), str(self._train_data)] 81 | ) 82 | 83 | # train fasttext classifier 84 | model = fasttext.train_supervised( 85 | input=str(self._train_data), verbose=2 86 | ) 87 | model.save_model(str(self._output)) 88 | logger.info(f"{log_prefix} Saved model to {self._output}") 89 | 90 | @staticmethod 91 | def __write_train_chunk( 92 | uri, reader: Reader, writer, max_samples, target_label 93 | ): 94 | num_samples = 0 95 | 96 | for record in tqdm( 97 | reader.read(uri, max_samples=max_samples, return_idx=False), 98 | total=max_samples 99 | ): 100 | doc = Document(record.text, domain=None) 101 | text = preprocess_quality_classifier(document=doc) 102 | writer.write(f"{target_label} {text}\n") 103 | num_samples += 1 104 | 105 | writer.flush() 106 | 107 | return num_samples 108 | -------------------------------------------------------------------------------- /app/src/artifacts/hash_dist.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import multiprocessing as mp 3 | from multiprocessing.pool import Pool 4 | import numpy as np 5 | from pathlib import Path 6 | from tqdm import tqdm 7 | 8 | from core.document import Document 9 | from utilities.io import Reader 10 | 11 | 12 | def _compute_hash_features(text: str, buckets: int): 13 | r""" Compute the hash features for a given text """ 14 | # compute hash features directly in the document class 15 | # for consistency 16 | document = Document( 17 | content=text, domain=None, precompute_ngrams=False, 18 | precompute_hash_features=True, dsir_buckets=buckets 19 | ) 20 | return document.hash_features, document.num_raw_words 21 | 22 | 23 | class HashDist: 24 | # output file naming convention 25 | output_file_fmt_counts = "{dataset}.{lang}.{buckets}.counts.npy" 26 | output_file_fmt_lambda = "{dataset}.{lang}.lambda.npy" 27 | 28 | def __init__( 29 | self, artifacts_dir, num_samples, buckets, max_workers, logger 30 | ): 31 | self._artifacts_dir = artifacts_dir 32 | self._num_samples = num_samples 33 | self._buckets = buckets 34 | self._max_workers = max_workers 35 | self._logger = logger 36 | 37 | def run(self, lang, datafile, dataset): 38 | log_prefix = f"{self.__class__.__name__}(" \ 39 | f"lang={lang}, datafile={datafile}, dataset={dataset})" 40 | datafile = str(Path(datafile).absolute()) 41 | 42 | out_dir = Path(self._artifacts_dir) / "dsir" / f"{lang}" 43 | 
out_dir.mkdir(parents=True, exist_ok=True) 44 | out_fp_dist = out_dir / self.output_file_fmt_counts.format( 45 | dataset=dataset, lang=lang, buckets=self._buckets 46 | ) 47 | out_fp_lambda = out_dir / self.output_file_fmt_lambda.format( 48 | dataset=dataset, lang=lang 49 | ) 50 | self._logger.info( 51 | f"{log_prefix} Start dsir computation for {lang}-{dataset}" 52 | ) 53 | self._logger.info(f"{log_prefix} Reading data from {datafile}") 54 | self._logger.info(f"{log_prefix} Write distribution to {out_fp_dist}") 55 | self._logger.info(f"{log_prefix} Write lambda to {out_fp_lambda}") 56 | 57 | if self._max_workers is not None: 58 | if self._max_workers < 0: 59 | raise ValueError("max_workers must be >= 0") 60 | max_proc = min(self._max_workers, mp.cpu_count() - 1) 61 | else: 62 | max_proc = mp.cpu_count() - 1 63 | 64 | self._logger.info(f"{log_prefix} Using {max_proc} processes") 65 | reader = Reader(schema=[("text", str)]) 66 | 67 | def _wrap_reader(): 68 | r""" wrap reader so that it can be used with multiprocessing. 69 | Otherwise, pickling of records fails. """ 70 | for record in reader.read( 71 | uri="file://" + datafile, 72 | max_samples=self._num_samples, 73 | return_idx=False 74 | ): 75 | yield record.text 76 | 77 | global_dist = np.zeros(self._buckets, dtype=np.int64) 78 | 79 | # MLE estimator for lambda of Poisson distribution 80 | lambda_mle = 0 81 | num_samples = 0 82 | 83 | with Pool(max_proc) as pool: 84 | for dist, dlen in tqdm( 85 | pool.imap_unordered( 86 | functools.partial( 87 | _compute_hash_features, 88 | buckets=self._buckets 89 | ), 90 | _wrap_reader() 91 | ), 92 | total=self._num_samples, 93 | desc=f"Reading {datafile}" 94 | ): 95 | global_dist += dist 96 | lambda_mle += dlen 97 | num_samples += 1 98 | 99 | # save lambda 100 | np.save(file=str(out_fp_lambda), arr=lambda_mle / num_samples) 101 | self._logger.info(f"{log_prefix} Saved lambda to {out_fp_lambda}") 102 | 103 | # save distribution 104 | np.save(file=str(out_fp_dist), arr=global_dist) 105 | self._logger.info(f"{log_prefix} Saved distribution to {out_fp_dist}") 106 | 107 | self._logger.info(f"{log_prefix} Finished dsir for {lang}-{dataset}.") 108 | -------------------------------------------------------------------------------- /app/src/artifacts/update_resources.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from collections import defaultdict 3 | import itertools 4 | import json 5 | from pathlib import Path 6 | import time 7 | import urllib.request 8 | import requests 9 | import tarfile 10 | from typing import Dict, List, Tuple 11 | 12 | _UT1_BLACKLIST_URL = "http://dsi.ut-capitole.fr" \ 13 | "/blacklists/download/blacklists.tar.gz" 14 | _LDNOOBW_URL = "https://raw.githubusercontent.com/LDNOOBW/List-of-Dirty-" \ 15 | "Naughty-Obscene-and-Otherwise-Bad-Words/master/{lang}" 16 | 17 | 18 | def _build_category_index(raw_categories) -> Dict[Tuple[str], int]: 19 | r""" Build a mapping with a list of categories, corresponding to a unique 20 | combination of categories, to a category ID. 
21 | """ 22 | categories = sorted(raw_categories) 23 | 24 | category_index = {} 25 | for i, category in enumerate(itertools.chain.from_iterable( 26 | itertools.combinations(categories, r) for r in 27 | range(1, len(categories) + 1) 28 | )): 29 | category_index[tuple(str(s) for s in sorted(category))] = i 30 | 31 | return category_index 32 | 33 | 34 | def _domain_to_category_list_mapping(bad_urls_dir, raw_categories): 35 | domain_to_category = defaultdict(list) 36 | 37 | for category_dir in (bad_urls_dir / "blacklists").iterdir(): 38 | if not category_dir.is_dir(): 39 | continue 40 | 41 | category = category_dir.name 42 | 43 | if category not in raw_categories: 44 | continue 45 | 46 | with open(category_dir / "domains", "r") as f: 47 | for dom in map(str.strip, f.readlines()): 48 | domain_to_category[dom].append(category) 49 | 50 | # postprocess 51 | domain_to_category = { 52 | dom: tuple(str(s) for s in sorted(set(categories))) 53 | for dom, categories in domain_to_category.items() 54 | } 55 | 56 | return domain_to_category 57 | 58 | 59 | def create_bad_urls_index(artifacts_dir: Path, raw_categories: List[str]): 60 | r""" update the URL blacklists from the University of Toulouse: 61 | 62 | @param artifacts_dir: (Path) The path to the resources directory 63 | @param raw_categories: (List[str]) The domain categories 64 | 65 | """ 66 | 67 | ut1_blacklist_dir = artifacts_dir / "bad_urls" 68 | 69 | if not ut1_blacklist_dir.exists(): 70 | ut1_blacklist_dir.mkdir(parents=True) 71 | 72 | print(f"fetching UT1 blacklist from {_UT1_BLACKLIST_URL}...") 73 | 74 | with urllib.request.urlopen(_UT1_BLACKLIST_URL) as response: 75 | with tarfile.open(fileobj=response, mode="r|gz") as tar: 76 | tar.extractall(path=ut1_blacklist_dir) 77 | 78 | with open(ut1_blacklist_dir / "_FETCH_TIMESTAMP", "w") as f: 79 | f.write(str(int(time.time()))) 80 | 81 | print(f"raw UT1 list fetched.") 82 | 83 | category_index = _build_category_index(raw_categories) 84 | 85 | # convert the raw UT1 blacklist to a domain -> category_id mapping where 86 | # a category corresponds to any combination of raw categories. 
87 | domain_to_category_list = _domain_to_category_list_mapping( 88 | ut1_blacklist_dir, raw_categories 89 | ) 90 | 91 | domain_to_category_id = { 92 | dom: category_index[categories] 93 | for dom, categories in domain_to_category_list.items() 94 | } 95 | 96 | with open(ut1_blacklist_dir / "domain_to_category_id.json", "w") as f: 97 | json.dump(domain_to_category_id, f) 98 | 99 | # save the category index as int -> category mapping 100 | category_index = { 101 | i: categories for categories, i in category_index.items() 102 | } 103 | with open(ut1_blacklist_dir / "category_index.json", "w") as f: 104 | json.dump(category_index, f) 105 | 106 | 107 | def create_bad_words_list(artifacts_dir: Path, lang: str): 108 | r""" Fetch the LDNOOBW word list 109 | 110 | Args: 111 | artifacts_dir (Path): The path to the resources directory 112 | lang (str): The language to fetch the word list for 113 | """ 114 | 115 | ldnoobw_dir = artifacts_dir / "bad_words" 116 | 117 | if not ldnoobw_dir.exists(): 118 | ldnoobw_dir.mkdir(parents=True) 119 | 120 | word_list_fp = ldnoobw_dir / f"{lang}.txt" 121 | url = _LDNOOBW_URL.format(lang=lang) 122 | 123 | print(f"fetching bad words list from {url}...") 124 | 125 | response = requests.get(url) 126 | if response.status_code != 200: 127 | raise Exception(f"{response.status_code} -- {url}.") 128 | 129 | data = response.content.decode('utf-8') 130 | 131 | with open(ldnoobw_dir / f"_{lang}_FETCH_TIMESTAMP", "w") as f: 132 | f.write(str(int(time.time()))) 133 | 134 | data = set(w for w in data.splitlines() if w is not None) 135 | 136 | with open(word_list_fp, 'w') as f: 137 | f.write('\n'.join(data)) 138 | 139 | print(f"bad words list ({lang}) updated.") 140 | 141 | 142 | def main(): 143 | parser = argparse.ArgumentParser() 144 | parser.add_argument("--langs", type=str, nargs="+") 145 | parser.add_argument("--artifacts_dir", type=str) 146 | parser.add_argument("--block_categories", type=str, nargs="+") 147 | args = parser.parse_args() 148 | 149 | artifacts_dir = Path(args.artifacts_dir) 150 | artifacts_dir.mkdir(parents=True, exist_ok=True) 151 | 152 | # fetch ut1 blacklist 153 | create_bad_urls_index(artifacts_dir=artifacts_dir, 154 | raw_categories=args.block_categories) 155 | 156 | # fetch ldnoobw 157 | langs = set(args.langs) 158 | for lang in langs: 159 | try: 160 | create_bad_words_list(lang=lang, artifacts_dir=artifacts_dir) 161 | except Exception as e: 162 | print(f"Failed to fetch LDNOOBW {lang}: {e}") 163 | 164 | 165 | if __name__ == '__main__': 166 | main() 167 | -------------------------------------------------------------------------------- /app/src/artifacts/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/togethercomputer/RedPajama-Data/6d2cee9df2b0204dd2bcb00bf06b5a7b1d7432d7/app/src/artifacts/utils/__init__.py -------------------------------------------------------------------------------- /app/src/artifacts/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import string 3 | 4 | # text normalization: lowercasing and removing punctuation 5 | TRANSLATION_TABLE = str.maketrans( 6 | string.ascii_lowercase + string.ascii_uppercase + "\n", 7 | string.ascii_lowercase * 2 + " ", 8 | string.punctuation 9 | ) 10 | 11 | 12 | def normalize_text(text: str, max_words: int = -1): 13 | r""" Normalize text by lowercasing and removing punctuation; if max words 14 | is larger than 0, then a random but contiguous span 
of max_words is 15 | selected from the text. 16 | 17 | Args: 18 | text: text to normalize 19 | max_words: maximum number of words to keep in text 20 | 21 | Returns: 22 | normalized text 23 | """ 24 | text = text.translate(TRANSLATION_TABLE) 25 | text = text.split() 26 | num_words = len(text) 27 | 28 | # pick a random span inside the text if it is too long 29 | if num_words > max_words > 0: 30 | start = random.randint(0, num_words - max_words) 31 | text = text[start:start + max_words] 32 | 33 | return " ".join(text) 34 | -------------------------------------------------------------------------------- /app/src/artifacts/utils/logging_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from logging.handlers import QueueHandler 3 | import multiprocessing as mp 4 | 5 | __all__ = [ 6 | "worker_logger_configurer", 7 | "listener_logger_configurer", 8 | "LOG_FMT" 9 | ] 10 | 11 | LOG_FMT = '[%(asctime)s]::(PID %(process)d)::%(levelname)-2s::%(message)s' 12 | 13 | 14 | def worker_logger_configurer(queue: mp.Queue, level=logging.DEBUG): 15 | root = logging.getLogger() 16 | 17 | if not root.hasHandlers(): 18 | h = logging.handlers.QueueHandler(queue) 19 | root.addHandler(h) 20 | 21 | root.setLevel(level) 22 | 23 | 24 | def listener_logger_configurer(logfile, level=logging.DEBUG): 25 | root = logging.getLogger() 26 | formatter = logging.Formatter(LOG_FMT) 27 | 28 | # write to log file 29 | if logfile is not None: 30 | if not logfile.parent.exists(): 31 | logfile.parent.mkdir(parents=True) 32 | file_handler = logging.FileHandler(logfile) 33 | file_handler.setFormatter(formatter) 34 | root.addHandler(file_handler) 35 | 36 | # write to stdout 37 | stream_handler = logging.StreamHandler() 38 | stream_handler.setFormatter(formatter) 39 | root.addHandler(stream_handler) 40 | 41 | root.setLevel(level) 42 | -------------------------------------------------------------------------------- /app/src/bloomfilter.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | 4 | import boto3 5 | import concurrent.futures 6 | from dataclasses import dataclass 7 | from datetime import datetime as dt 8 | import gzip 9 | import io 10 | import logging 11 | import msgspec 12 | import os 13 | from pathlib import Path 14 | import polars as pl 15 | import progiter 16 | import pyarrow as pa 17 | import pybloomfilter 18 | import re 19 | from typing import Tuple 20 | from urllib.parse import urlparse 21 | from typing import Dict, List 22 | 23 | from utilities.logging import configure_logger 24 | from utilities.io import ParquetBatchWriter 25 | 26 | 27 | @dataclass 28 | class ReadStatus: 29 | is_success: bool 30 | msg: str 31 | uri: str 32 | 33 | 34 | class Deduper: 35 | r""" Bloom filter for exact deduplication of ccnet shards. Based on 36 | document contents. 
""" 37 | __slots__ = ( 38 | "_args", "_logger", "_job_id", "_input_base_uri", "_scheme", 39 | "_output_fp", "_out_schema", "_bloom_fp" 40 | ) 41 | 42 | # regex to extract filepaths from source file listings 43 | input_patterns = [ 44 | re.compile(r".*/[a-z]{2}_middle\.json\.gz"), 45 | re.compile(r".*/[a-z]{2}_head\.json\.gz") 46 | ] 47 | 48 | output_pattern = "duplicates-{timestamp}-{snapshot}.parquet" 49 | 50 | def __init__(self): 51 | self._job_id = dt.now().strftime("%Y%m%d_%H%M%S") 52 | self._args = self.__parse_arguments() 53 | 54 | # set random seed 55 | random.seed(self._args.seed) 56 | 57 | # parse args 58 | self._input_base_uri = self._args.input_base_uri 59 | self._scheme = urlparse(self._input_base_uri).scheme 60 | 61 | # init logging 62 | logfile = Path(self._args.output_dir) / "logs" / f"{self._job_id}.log" 63 | configure_logger(logfile=logfile, level=logging.INFO, stream=False) 64 | self._logger = logging.getLogger() 65 | 66 | # output writer 67 | self._output_fp = Path(self._args.output_dir) / "duplicates.parquet" 68 | self._out_schema = pa.schema([ 69 | ("shard_id", pa.string()), 70 | ("doc_id", pa.string()), 71 | ("digest", pa.string()) 72 | ]) 73 | 74 | # log setup 75 | for attr in [ 76 | "listings", "input_base_uri", "output_dir", "parallel_readers", 77 | "capacity", "error_rate", "seed", "max_inputs", "batch_size" 78 | ]: 79 | self._logger.info(f"{attr}: {getattr(self._args, attr)}") 80 | 81 | def __parse_arguments(self) -> argparse.Namespace: 82 | 83 | if self.__doc__ is not None: 84 | description = " - " + self.__doc__ 85 | else: 86 | description = self.__class__.__name__ 87 | 88 | parser = argparse.ArgumentParser( 89 | prog=self.__class__.__name__, description=description 90 | ) 91 | 92 | # io 93 | parser.add_argument( 94 | "--listings", type=str, default=None, 95 | help="Path to a file containing paths to ccnet shards; needs to " 96 | "match with the input_base_uri argument." 97 | ) 98 | parser.add_argument( 99 | "--input_base_uri", type=str, default=None, 100 | help="base uri of the input files." 101 | ) 102 | parser.add_argument( 103 | "--output_dir", type=str, default=None, 104 | help="directory where the output will be stored." 105 | ) 106 | parser.add_argument( 107 | "--s3_profile", type=str, default="default", 108 | help="profile name of the s3 client." 109 | ) 110 | parser.add_argument( 111 | "--endpoint_url", type=str, default=None, 112 | help="S3 bucket endpoint url." 113 | ) 114 | parser.add_argument( 115 | "--parallel_readers", type=int, default=1, 116 | help="number of parallel reader processes. Defaults to 1." 117 | ) 118 | parser.add_argument( 119 | "--max_inputs", type=int, default=None, 120 | help="maximum number of inputs to process. For debugging." 121 | ) 122 | parser.add_argument( 123 | "--batch_size", type=int, default=None, 124 | help="number of listings to be processed per process." 125 | ) 126 | 127 | parser.add_argument( 128 | "--seed", type=int, default=42, 129 | help="random seed." 130 | ) 131 | 132 | # dedup params 133 | parser.add_argument( 134 | "--capacity", type=int, default=None, 135 | help="Capacity of the bloom filter. This is the maximum number of " 136 | "unique documents that can be stored in the filter while " 137 | "keeping the error rate under `error_rate`." 138 | ) 139 | parser.add_argument( 140 | "--error_rate", type=float, default=0.01, 141 | help="false positive probability that will hold given that " 142 | "'capacity' is not exceeded. 
Defaults to 0.001" 143 | ) 144 | args = parser.parse_args() 145 | 146 | return args 147 | 148 | def __init_client(self): 149 | if self._scheme != "s3": 150 | return None 151 | 152 | session = boto3.Session(profile_name=self._args.s3_profile) 153 | return session.client( 154 | service_name='s3', 155 | endpoint_url=self._args.endpoint_url, 156 | config=boto3.session.Config( 157 | signature_version="s3v4", 158 | retries={'max_attempts': 5, 'mode': 'standard'} 159 | ) 160 | ) 161 | 162 | def __filter_listings(self, obj_key: str): 163 | for pat in self.input_patterns: 164 | if pat.search(obj_key) is not None: 165 | return True 166 | 167 | return False 168 | 169 | def __parse_listings(self): 170 | # build input uris 171 | with open(self._args.listings, "r") as f: 172 | uris = list( 173 | map(lambda ls: os.path.join(self._input_base_uri, ls.strip()), 174 | filter(self.__filter_listings, f.readlines())) 175 | ) 176 | 177 | return uris 178 | 179 | @staticmethod 180 | def __load_from_s3(uri, client): 181 | try: 182 | streaming_body = client.get_object( 183 | Bucket=uri.netloc, Key=uri.path.lstrip("/") 184 | )["Body"] 185 | buffer = io.BytesIO(streaming_body.read()) 186 | msg = f"__S3_URI_READ_SUCCESS__ success reading {uri.path}" 187 | is_success = True 188 | except Exception as e: 189 | msg = ( 190 | f"__S3_URI_READ_ERROR__ failed reading {uri.path}: " 191 | f"caught exception {e.__class__.__name__}: {e}" 192 | ) 193 | buffer = None 194 | is_success = False 195 | 196 | return is_success, msg, buffer 197 | 198 | @staticmethod 199 | def __load_from_disk(uri): 200 | try: 201 | with open(uri.path, "rb") as f: 202 | buffer = io.BytesIO(f.read()) 203 | msg = f"__DISK_URI_READ_SUCCESS__ success reading {uri.path}" 204 | is_success = True 205 | except Exception as e: 206 | msg = ( 207 | f"__DISK_URI_READ_ERROR__ failed reading {uri.path}: " 208 | f"caught exception {e.__class__.__name__}: {e}" 209 | ) 210 | buffer = None 211 | is_success = False 212 | 213 | return is_success, msg, buffer 214 | 215 | def _load_file(self, uri, client) -> Tuple[ReadStatus, io.BytesIO]: 216 | if uri.scheme == "s3": 217 | is_success, msg, buffer = self.__load_from_s3(uri, client) 218 | elif uri.scheme == "file": 219 | is_success, msg, buffer = self.__load_from_disk(uri) 220 | else: 221 | raise ValueError(f"Unknown scheme {uri.scheme}") 222 | 223 | read_status = ReadStatus( 224 | is_success=is_success, msg=msg, uri=uri.geturl() 225 | ) 226 | return read_status, buffer 227 | 228 | def _load_and_parse_inputs( 229 | self, input_chunk 230 | ) -> Dict[str, Tuple[ReadStatus, List[Dict]]]: 231 | # build msgspec decoder 232 | decoder = msgspec.json.Decoder( 233 | type=msgspec.defstruct(name="Record", fields=[("digest", str)]) 234 | ) 235 | 236 | client = self.__init_client() 237 | data = {} 238 | 239 | for uri in input_chunk: 240 | read_status, buffer = self._load_file( 241 | uri=urlparse(uri), client=client 242 | ) 243 | 244 | if not read_status.is_success: 245 | data[uri] = (read_status, []) 246 | continue 247 | 248 | shard_id = read_status.uri.replace( 249 | self._input_base_uri, "" 250 | ).lstrip("/") 251 | 252 | uri_data = [] 253 | 254 | try: 255 | with gzip.open(buffer, "rb") as f: 256 | for idx, obj in enumerate(f): 257 | rec = decoder.decode(obj) 258 | digest = str(getattr(rec, "digest")).replace( 259 | "sha1:", "" 260 | ) 261 | uri_data.append({ 262 | "shard_id": shard_id, 263 | "doc_id": f"{shard_id}/{idx}", 264 | "digest": digest 265 | }) 266 | except Exception as e: 267 | uri_data = [] 268 | read_status.msg = ( 269 | 
f"__S3_URI_DECODE_ERROR__ failed decoding {uri}: " 270 | f"caught exception {e.__class__.__name__}: {e}" 271 | ) 272 | read_status.is_success = False 273 | 274 | data[uri] = (read_status, uri_data) 275 | 276 | del buffer 277 | 278 | return data 279 | 280 | def __parallel_run(self, input_uris): 281 | # shuffle input uris 282 | random.shuffle(input_uris) 283 | 284 | if self._args.max_inputs is not None: 285 | self._logger.info(f"Limiting inputs to {self._args.max_inputs}") 286 | input_uris = input_uris[:self._args.max_inputs] 287 | 288 | # divide input uris into snapshots 289 | snapsh_re = re.compile(r'\b\d{4}-\d{2}\b') 290 | snapshots = {} 291 | for uri in input_uris: 292 | snapshot = snapsh_re.search(uri).group(0) 293 | if snapshot not in snapshots: 294 | snapshots[snapshot] = [uri] 295 | else: 296 | snapshots[snapshot].append(uri) 297 | 298 | snapshot_ids_sorted = sorted(snapshots.keys(), reverse=True) 299 | 300 | # init bloomfilter 301 | bloomfilter = pybloomfilter.BloomFilter( 302 | capacity=self._args.capacity, 303 | error_rate=self._args.error_rate 304 | ) 305 | 306 | self._logger.info(f"Filter capacity: {bloomfilter.capacity}") 307 | self._logger.info(f"Filter error rate: {bloomfilter.error_rate}") 308 | self._logger.info(f"Filter hash seeds: {bloomfilter.hash_seeds}") 309 | 310 | num_docs, num_dupes = 0, 0 311 | 312 | # progress bars 313 | pman = progiter.ProgressManager(backend="rich") 314 | total_progress = pman.progiter( 315 | total=self._args.capacity, postfix_str="Duplicates: --" 316 | ) 317 | download_progress = pman.progiter( 318 | total=len(input_uris), desc="Download" 319 | ) 320 | 321 | num_failed_uri = 0 322 | num_succ_uri = 0 323 | 324 | for snapsh_id in snapshot_ids_sorted: 325 | uri_list = snapshots[snapsh_id] 326 | random.shuffle(uri_list) 327 | 328 | uri_list_partitioned = [ 329 | uri_list[i:i + self._args.batch_size] 330 | for i in range(0, len(uri_list), self._args.batch_size) 331 | ] 332 | 333 | self._logger.info(f"__SNAPSHOT_START__ {snapsh_id}") 334 | 335 | # output writer 336 | timestamp = dt.now().strftime("%Y%m%d_%H%M%S") 337 | output_fp = ( 338 | Path(self._args.output_dir) / 339 | self.output_pattern.format( 340 | timestamp=timestamp, snapshot=snapsh_id 341 | ) 342 | ) 343 | out_writer = ParquetBatchWriter( 344 | output_fp=output_fp, schema=self._out_schema 345 | ) 346 | 347 | try: 348 | with concurrent.futures.ProcessPoolExecutor( 349 | max_workers=self._args.parallel_readers 350 | ) as executor: 351 | futures = { 352 | executor.submit( 353 | self._load_and_parse_inputs, input_chunk 354 | ): i 355 | for i, input_chunk in enumerate(uri_list_partitioned) 356 | } 357 | 358 | for future in concurrent.futures.as_completed(futures): 359 | data_chunks = future.result() 360 | del futures[future] 361 | download_progress.step(len(data_chunks)) 362 | 363 | for ( 364 | uri, (read_status, uri_data) 365 | ) in data_chunks.items(): 366 | 367 | if not read_status.is_success: 368 | self._logger.error(read_status.msg) 369 | num_failed_uri += 1 370 | continue 371 | 372 | num_succ_uri += 1 373 | download_progress.set_postfix_str( 374 | f"success: {num_succ_uri} " 375 | f"({num_failed_uri} failed)" 376 | ) 377 | 378 | self._logger.info(read_status.msg) 379 | 380 | for record in uri_data: 381 | digest = record["digest"] 382 | 383 | if bloomfilter.add(digest): 384 | out_writer.update_batch(obj=record) 385 | num_dupes += 1 386 | 387 | num_docs += 1 388 | total_progress.step(1) 389 | 390 | if num_docs % (1024 ** 2) == 0: 391 | out_writer.write_batch() 392 | 393 | dupe_prop = 
round(100 * num_dupes / num_docs, 2) 394 | total_progress.set_postfix_str( 395 | f"Duplicates: {num_dupes} / {num_docs}" 396 | f" ({dupe_prop:.2f}%)" 397 | ) 398 | 399 | except KeyboardInterrupt: 400 | self._logger.info("Keyboard interrupt. Stopping.") 401 | executor.shutdown(wait=False, cancel_futures=True) 402 | out_writer.close() 403 | break 404 | except Exception as e: 405 | self._logger.error( 406 | f"Caught exception {e.__class__.__name__}: {e}" 407 | ) 408 | executor.shutdown(wait=False, cancel_futures=True) 409 | out_writer.close() 410 | self._logger.info(f"__SNAPSHOT_FAIL__ {snapsh_id}") 411 | continue 412 | 413 | out_writer.close() 414 | self._logger.info(f"__SNAPSHOT_FINISH__ {snapsh_id}") 415 | 416 | pman.stop() 417 | bloomfilter.close() 418 | 419 | self._logger.info(f"Filtering complete.") 420 | 421 | def run(self): 422 | start_time = dt.now() 423 | print(f"start @ {start_time.strftime('%Y-%m-%d %H:%M:%S')}") 424 | self.__parallel_run(input_uris=self.__parse_listings()) 425 | end_time = dt.now() 426 | print(f"end @ {end_time.strftime('%Y-%m-%d %H:%M:%S')}") 427 | end_str = f"Total time: {end_time - start_time}" 428 | print(end_str) 429 | self._logger.info(end_str) 430 | 431 | def __result_summary(self): 432 | dump_reg = "(\d{4}-\d{2})\/" 433 | # read duplicates 434 | query = ( 435 | pl.scan_parquet(self._output_fp) 436 | .with_columns( 437 | pl.col("shard_id").str.extract(dump_reg, 1).alias("snapshot") 438 | ) 439 | .group_by("snapshot") 440 | .agg(pl.count()) 441 | ) 442 | 443 | stats = query.collect() 444 | 445 | with pl.Config(fmt_str_lengths=1000, tbl_rows=100): 446 | print(stats) 447 | 448 | 449 | if __name__ == '__main__': 450 | deduper = Deduper() 451 | deduper.run() 452 | -------------------------------------------------------------------------------- /app/src/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/togethercomputer/RedPajama-Data/6d2cee9df2b0204dd2bcb00bf06b5a7b1d7432d7/app/src/core/__init__.py -------------------------------------------------------------------------------- /app/src/core/constants.py: -------------------------------------------------------------------------------- 1 | PRECISION = 8 2 | CCNET_LABEL = "__label__cc" 3 | -------------------------------------------------------------------------------- /app/src/core/data_types.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from msgspec import Struct 3 | 4 | from typing import List, Tuple, Optional, Dict 5 | from typing_extensions import TypeAlias 6 | 7 | ScoreType: TypeAlias = Tuple[int, int, Optional[float]] 8 | SignalType: TypeAlias = List[ScoreType] 9 | 10 | 11 | @dataclass 12 | class TextSlice: 13 | text: str 14 | start: int 15 | end: int 16 | 17 | def __len__(self): 18 | return len(self.text) 19 | 20 | 21 | class InputSpec(Struct): 22 | raw_content: str 23 | url: str 24 | nlines: int 25 | original_nlines: int 26 | source_domain: str 27 | length: int 28 | original_length: int 29 | language: str 30 | language_score: float 31 | perplexity: float 32 | bucket: str 33 | digest: str 34 | cc_segment: str 35 | date_download: str 36 | 37 | 38 | class OutputSpec(Struct): 39 | id: str 40 | id_int: int 41 | metadata: Dict[str, str] 42 | quality_signals: Dict[str, List[Tuple[int, int, Optional[float]]]] 43 | 44 | 45 | 46 | -------------------------------------------------------------------------------- /app/src/core/document.py: 
-------------------------------------------------------------------------------- 1 | from nltk.tokenize import WordPunctTokenizer 2 | import re 3 | from typing import Optional, Tuple, Callable 4 | 5 | from utilities.text import normalize, form_ngrams 6 | from core.data_types import TextSlice 7 | from core.quality_signals.utils.dsir import hash_feature 8 | 9 | _word_tokenizer = WordPunctTokenizer() 10 | 11 | 12 | def _compute_ngrams(text_seq, n): 13 | return tuple(form_ngrams(iter(text_seq), n)) 14 | 15 | 16 | def split_paragraphs( 17 | text: str, normalizer: Callable[[str], str], remove_empty: bool = True 18 | ) -> Tuple[TextSlice]: 19 | """ 20 | This function is adapted from dolma: https://github.com/allenai/dolma 21 | 22 | Split a string into paragraphs. A paragraph is defined as a sequence of 23 | zero or more characters, followed by a newline character, or a sequence 24 | of one or more characters, followed by the end of the string. 25 | """ 26 | text_slices = tuple( 27 | TextSlice(normalizer(text[match.start():match.end()]), match.start(), 28 | match.end()) 29 | for match in re.finditer(r"([^\n]*\n|[^\n]+$)", text) 30 | ) 31 | 32 | if remove_empty is True: 33 | text_slices = tuple( 34 | text_slice for text_slice in text_slices if text_slice[0].strip() 35 | ) 36 | 37 | return text_slices 38 | 39 | 40 | class Document: 41 | __slots__ = ( 42 | "_raw_content", "_normalized_content", "_raw_lines", 43 | "_normalized_lines", "_raw_words", "_normalized_words", 44 | "_num_raw_words", "_num_normalized_words", "_domain", "_raw_2grams", 45 | "_raw_3grams", "_norm_2grams", "_norm_3grams", "_norm_4grams", 46 | "_hash_features" 47 | ) 48 | 49 | def __init__( 50 | self, content: str, domain: Optional[str], 51 | precompute_ngrams: bool = False, 52 | precompute_hash_features: bool = False, 53 | dsir_buckets: Optional[int] = None 54 | ): 55 | self._raw_content = content 56 | self._domain = domain 57 | 58 | # the normalized content: lowercased and punctuation removed 59 | self._normalized_content = normalize(content) 60 | 61 | # the lines of the document (split by newline) 62 | self._raw_lines: Tuple[TextSlice] = split_paragraphs( 63 | text=content, normalizer=lambda x: x, remove_empty=False 64 | ) 65 | 66 | # the lines of the document (split by newline), normalized 67 | self._normalized_lines: Tuple[TextSlice] = split_paragraphs( 68 | text=content, normalizer=normalize, remove_empty=False 69 | ) 70 | 71 | # the words of the document after normalization 72 | self._raw_words = tuple(_word_tokenizer.tokenize(self._raw_content)) 73 | 74 | # the normalized words of the document (split by whitespace) 75 | self._normalized_words = tuple(self._normalized_content.split()) 76 | 77 | # get number of words before and after normalization 78 | self._num_raw_words = len(self._raw_words) 79 | self._num_normalized_words = len(self._normalized_words) 80 | 81 | # precompute ngrams 82 | if precompute_ngrams: 83 | # raw grams 84 | self._raw_2grams = _compute_ngrams(self._raw_words, 2) 85 | self._raw_3grams = _compute_ngrams(self._raw_words, 3) 86 | 87 | # normalized grams 88 | self._norm_2grams = _compute_ngrams(self._normalized_words, 2) 89 | self._norm_3grams = _compute_ngrams(self._normalized_words, 3) 90 | self._norm_4grams = _compute_ngrams(self._normalized_words, 4) 91 | else: 92 | self._raw_2grams = None 93 | self._raw_3grams = None 94 | self._norm_2grams = None 95 | self._norm_3grams = None 96 | self._norm_4grams = None 97 | 98 | # precomupte hash features 99 | if precompute_hash_features: 100 | bigrams = 
self._raw_2grams or _compute_ngrams(self._raw_words, 2) 101 | self._hash_features = hash_feature( 102 | unigrams=self._raw_words, 103 | bigrams=bigrams, 104 | buckets=dsir_buckets 105 | ) 106 | else: 107 | self._hash_features = None 108 | 109 | def __len__(self): 110 | return len(self._raw_content) 111 | 112 | @property 113 | def raw_content(self): 114 | return self._raw_content 115 | 116 | @property 117 | def normalized_content(self): 118 | return self._normalized_content 119 | 120 | @property 121 | def raw_lines(self): 122 | return self._raw_lines 123 | 124 | @property 125 | def normalized_lines(self): 126 | return self._normalized_lines 127 | 128 | @property 129 | def raw_words(self): 130 | return self._raw_words 131 | 132 | @property 133 | def normalized_words(self): 134 | return self._normalized_words 135 | 136 | @property 137 | def num_raw_words(self): 138 | return self._num_raw_words 139 | 140 | @property 141 | def num_normalized_words(self): 142 | return self._num_normalized_words 143 | 144 | @property 145 | def domain(self): 146 | return self._domain 147 | 148 | @property 149 | def raw_1grams(self): 150 | return self._raw_words 151 | 152 | @property 153 | def raw_2grams(self): 154 | return self._raw_2grams 155 | 156 | @property 157 | def raw_3grams(self): 158 | return self._raw_3grams 159 | 160 | @property 161 | def norm_1grams(self): 162 | return self._normalized_words 163 | 164 | @property 165 | def norm_2grams(self): 166 | return self._norm_2grams 167 | 168 | @property 169 | def norm_3grams(self): 170 | return self._norm_3grams 171 | 172 | @property 173 | def norm_4grams(self): 174 | return self._norm_4grams 175 | 176 | @property 177 | def hash_features(self): 178 | return self._hash_features 179 | -------------------------------------------------------------------------------- /app/src/core/exceptions.py: -------------------------------------------------------------------------------- 1 | class S3ReadError(Exception): 2 | def __init__(self, message): 3 | super().__init__(message) 4 | 5 | 6 | class S3WriteError(Exception): 7 | def __init__(self, message): 8 | super().__init__(message) 9 | 10 | 11 | class LocalReadError(Exception): 12 | def __init__(self, message): 13 | super().__init__(message) 14 | 15 | 16 | class UnknownReadError(Exception): 17 | def __init__(self, message): 18 | super().__init__(message) 19 | -------------------------------------------------------------------------------- /app/src/core/quality_signals/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/togethercomputer/RedPajama-Data/6d2cee9df2b0204dd2bcb00bf06b5a7b1d7432d7/app/src/core/quality_signals/__init__.py -------------------------------------------------------------------------------- /app/src/core/quality_signals/base.py: -------------------------------------------------------------------------------- 1 | from core.document import Document 2 | from core.data_types import SignalType 3 | 4 | 5 | class RPSBase: 6 | r""" Base class for RP signal functions. Each child class must implement 7 | the __call__ method. The __call__ method takes a document as input and 8 | returns a score. 
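    An illustrative sketch of a subclass (the class below is hypothetical
    and not part of the shipped signal set; the real signals live in the
    quality_signals submodules):

        class RPS_Doc_Char_Count(RPSBase):
            __slots__ = ()

            def __call__(self, document: Document) -> SignalType:
                # a single document-level span covering the whole text
                return [(0, len(document), float(len(document)))]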
""" 9 | DATA_TYPE = SignalType 10 | 11 | RPS_PREFIX: str = "RPS_" 12 | 13 | __slots__ = ["__field_name"] 14 | 15 | def __init__(self, *args, **kwargs): # noqa 16 | # make sure all classes start with RPS_; this is to ensure that 17 | # the get_rule_based_signals function works correctly when new signal 18 | # functions are added 19 | assert self.__class__.__name__.startswith(self.RPS_PREFIX), \ 20 | f"Name of signal function must" \ 21 | f" start with {self.RPS_PREFIX}; got {self.__class__.__name__}" 22 | 23 | self.__field_name = self.__class__.__name__.lower() 24 | 25 | def __call__(self, document: Document): 26 | raise NotImplementedError 27 | 28 | @property 29 | def field_name(self): 30 | return self.__field_name 31 | -------------------------------------------------------------------------------- /app/src/core/quality_signals/classifiers.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from typing import List, Tuple, Type 3 | import fasttext 4 | 5 | from core.constants import PRECISION, CCNET_LABEL 6 | from core.quality_signals.base import RPSBase 7 | from core.document import Document 8 | from core.data_types import SignalType 9 | from core.quality_signals.utils.classifiers import \ 10 | preprocess_quality_classifier 11 | from utilities.register.registry_utils import * 12 | 13 | __all__ = [ 14 | "register_classifier_callables", "classifier_schema" 15 | ] 16 | 17 | 18 | def classifier_schema() -> List[Tuple[str, Type]]: 19 | r""" Returns a list of signal names and their data types """ 20 | return signal_schema(module=sys.modules[__name__]) 21 | 22 | 23 | def register_classifier_callables( 24 | wikiref_model: str, 25 | palm_model: str, 26 | wikipedia_model: str 27 | ) -> List[RPSBase]: 28 | r""" Returns a list of signal functions (i.e., RPSBase instances) that 29 | are used to extract content signals from a document. 30 | 31 | Args: 32 | wikiref_model: A fasttext model trained on Wikipedia references. 33 | palm_model: A fasttext model trained on ccnet vs 34 | {books, openwebtext, wikipedia}. 35 | wikipedia_model: A fasttext model trained on Wikipedia articles. 36 | 37 | Returns: 38 | A list of signal function class instances. 
39 | """ 40 | return list(map( 41 | lambda cls: cls( 42 | wikiref_model=wikiref_model, 43 | palm_model=palm_model, 44 | wikipedia_model=wikipedia_model, 45 | ), 46 | get_callables_from_module(module=sys.modules[__name__]) 47 | )) 48 | 49 | 50 | class BaseMLSignal(RPSBase): 51 | __slots__ = "_ft_model" 52 | 53 | def __init__(self, ft_model_file: str): 54 | super(BaseMLSignal, self).__init__() 55 | if ft_model_file is None: 56 | self._ft_model = None 57 | else: 58 | self._ft_model = fasttext.load_model(str(ft_model_file)) 59 | 60 | def __call__(self, document: Document) -> SignalType: 61 | if self._ft_model is None: 62 | return [(0, len(document), None)] 63 | 64 | if len(document.raw_content) == 0: 65 | return [(0, len(document), None)] 66 | 67 | text = preprocess_quality_classifier(document=document) 68 | pred = self._ft_model.predict(text=text) 69 | 70 | (pred_label, pred_prob) = pred 71 | pred_label = pred_label[0] 72 | pred_prob = pred_prob[0] 73 | 74 | if pred_label == CCNET_LABEL: 75 | high_quality_score = 1 - pred_prob 76 | else: 77 | high_quality_score = pred_prob 78 | 79 | score = round(float(high_quality_score), PRECISION) 80 | return [(0, len(document), score)] 81 | 82 | 83 | class RPS_Doc_ML_Wikiref_Score(BaseMLSignal): # noqa 84 | r""" Fasttext classifier prediction for the document being a Wikipedia 85 | reference. This is the same fasttext model as in the RedPajama-1T 86 | dataset.""" 87 | __slots__ = () 88 | 89 | def __init__(self, wikiref_model: str, *args, **kwargs): # noqa 90 | super(RPS_Doc_ML_Wikiref_Score, self).__init__( 91 | ft_model_file=wikiref_model 92 | ) 93 | 94 | 95 | class RPS_Doc_ML_Palm_Score(BaseMLSignal): # noqa 96 | r""" Fasttext classifier prediction for the document being a Wikipedia 97 | article, OpenWebText sample or a RedPajama-V1 book.""" 98 | __slots__ = () 99 | 100 | def __init__(self, palm_model: str, *args, **kwargs): # noqa 101 | super(RPS_Doc_ML_Palm_Score, self).__init__( 102 | ft_model_file=palm_model 103 | ) 104 | 105 | 106 | class RPS_Doc_ML_Wikipedia_Score(BaseMLSignal): # noqa 107 | r""" Fasttext classifier prediction for the document being a Wikipedia 108 | article.""" 109 | __slots__ = () 110 | 111 | def __init__(self, wikipedia_model: str, *args, **kwargs): # noqa 112 | super(RPS_Doc_ML_Wikipedia_Score, self).__init__( 113 | ft_model_file=wikipedia_model 114 | ) 115 | -------------------------------------------------------------------------------- /app/src/core/quality_signals/content.py: -------------------------------------------------------------------------------- 1 | import re 2 | import sys 3 | import operator 4 | from pathlib import Path 5 | from typing import List, Tuple, Type 6 | 7 | from core.constants import PRECISION 8 | from core.quality_signals.base import RPSBase 9 | from core.quality_signals.utils.stop_words import get_stop_words 10 | from core.document import Document 11 | from core.data_types import SignalType 12 | from core.quality_signals.utils.content import \ 13 | load_bad_words, load_bad_urls_index 14 | from utilities.register.registry_utils import * 15 | from utilities.text import form_ngrams 16 | 17 | __all__ = ["register_content_callables", "content_schema"] 18 | 19 | 20 | def content_schema() -> List[Tuple[str, Type]]: 21 | r""" Returns a list of signal names and their data types """ 22 | return signal_schema(module=sys.modules[__name__]) 23 | 24 | 25 | def register_content_callables( 26 | language: str, bad_urls_dir: str, bad_words_dir: str 27 | ) -> List[RPSBase]: 28 | r""" Returns a list of signal 
functions (i.e., RPSBase instances) that 29 | are used to extract content signals from a document. 30 | 31 | Args: 32 | language: The language of the document. 33 | bad_urls_dir: directory containing the UT1 blacklist. 34 | bad_words_dir: directory containing the LDNOOBW blacklist. 35 | 36 | Returns: 37 | A list of signal function class instances. 38 | """ 39 | return list(map( 40 | lambda cls: cls( 41 | language=language, 42 | bad_urls_dir=bad_urls_dir, 43 | bad_words_dir=bad_words_dir 44 | ), 45 | get_callables_from_module(module=sys.modules[__name__]) 46 | )) 47 | 48 | 49 | class RPS_Doc_LDNOOBW_Words(RPSBase): # noqa 50 | r""" The number of sequences of words that are contained in the 51 | List-of-Dirty-Naughty-Obscene-and-Otherwise-Bad-Words blocklist. The 52 | blocklist is obtained from 53 | https://github.com/LDNOOBW/List-of-Dirty-Naughty-Obscene-and-Otherwise-Bad-Words 54 | """ 55 | __slots__ = ["_block_words", "_gram_vals"] 56 | 57 | def __init__( 58 | self, bad_words_dir: str, language: str, *args, **kwargs # noqa 59 | ): 60 | super(RPS_Doc_LDNOOBW_Words, self).__init__() 61 | self._block_words = load_bad_words( 62 | bad_words_dir=Path(bad_words_dir), lang=language 63 | ) 64 | 65 | # cache the number of words in each block list entry 66 | self._gram_vals = set(map( 67 | lambda w: 1 + operator.countOf(w, " "), self._block_words 68 | )) 69 | 70 | def __call__(self, document: Document) -> SignalType: 71 | if len(document.normalized_content) == 0: 72 | return [(0, len(document), .0)] 73 | 74 | num_dirty = 0 75 | 76 | # for each ngram value, count the number of ngrams in the document 77 | # which are also in the block words list 78 | for n in self._gram_vals: 79 | if n == 1: 80 | num_dirty += sum( 81 | 1 for _ in filter( 82 | lambda w: w in self._block_words, 83 | document.normalized_words 84 | ) 85 | ) 86 | continue 87 | 88 | num_dirty += sum( 89 | 1 for _ in filter( 90 | lambda t: " ".join(t) in self._block_words, 91 | # try to fetch the cached ngrams, otherwise compute them 92 | # on the fly 93 | getattr(document, f"norm_{n}grams", None) 94 | or 95 | form_ngrams(iter(document.normalized_words), n) 96 | ) 97 | ) 98 | 99 | score = float(num_dirty) 100 | return [(0, len(document), score)] 101 | 102 | 103 | class RPS_Doc_Lorem_Ipsum(RPSBase): # noqa 104 | r""" The ratio between the number of occurences of 'lorem ipsum' 105 | and the number of characters in the text after normalization. Text is 106 | normalized by lowercasing and removing punctuation. """ 107 | SEARCH_TEXT = "lorem ipsum" 108 | SEARCH_REGEX = re.compile(r"lorem ipsum", re.IGNORECASE) 109 | 110 | __slots__ = () 111 | 112 | def __call__(self, document: Document) -> SignalType: 113 | if len(document.normalized_content) == 0: 114 | return [(0, len(document), 0.0)] 115 | 116 | if self.SEARCH_TEXT not in document.normalized_content: 117 | return [(0, len(document), .0)] 118 | 119 | num_occurences = len(self.SEARCH_REGEX.findall( 120 | document.normalized_content 121 | )) 122 | 123 | score = float(num_occurences) / len(document.normalized_content) 124 | score = round(score, PRECISION) 125 | 126 | return [(0, len(document), score)] 127 | 128 | 129 | class RPS_Doc_Curly_Bracket(RPSBase): # noqa 130 | r""" The ratio between the number of occurences of '{' or '}' and the 131 | number of characters in the raw text. 
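    For example, the raw text "{x} and {y}" has 11 characters and contains
    4 curly brackets, so the score would be 4 / 11 ≈ 0.36363636.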
""" 132 | SEARCH_TEXT = ("{", "}") 133 | __slots__ = () 134 | 135 | def __call__(self, document: Document) -> SignalType: 136 | if len(document.raw_content) == 0: 137 | return [(0, len(document), .0)] 138 | 139 | if all(map(lambda x: x not in document.raw_content, self.SEARCH_TEXT)): 140 | return [(0, len(document), .0)] 141 | 142 | num_occurences = sum( 143 | map(lambda x: operator.countOf(document.raw_content, x), 144 | self.SEARCH_TEXT) 145 | ) 146 | 147 | score = float(num_occurences) / len(document.raw_content) 148 | score = round(score, PRECISION) 149 | 150 | return [(0, len(document), score)] 151 | 152 | 153 | class RPS_Doc_UT1_Blacklist(RPSBase): # noqa 154 | r""" An categorical id of the list of categories of the domain of the 155 | document. Categories are obtained from the UT1 blacklist. 156 | """ 157 | __slots__ = ["_ut1_mapping"] 158 | 159 | def __init__(self, bad_urls_dir: str, *args, **kwargs): # noqa 160 | super(RPS_Doc_UT1_Blacklist, self).__init__() 161 | self._ut1_mapping = load_bad_urls_index(Path(bad_urls_dir)) 162 | 163 | def __call__(self, document: Document) -> SignalType: 164 | score: int = self._ut1_mapping.get(document.domain, None) 165 | return [(0, len(document), score)] 166 | 167 | 168 | class RPS_Doc_Stop_Word_Fraction(RPSBase): # noqa 169 | r""" The ratio between the number of stop words and the number of words in 170 | the document. """ 171 | __slots__ = ["_stop_words"] 172 | 173 | def __init__(self, language: str, *args, **kwargs): # noqa 174 | super(RPS_Doc_Stop_Word_Fraction, self).__init__() 175 | self._stop_words = get_stop_words(language) 176 | 177 | def __call__(self, document: Document) -> SignalType: 178 | if len(document.normalized_words) == 0: 179 | return [(0, len(document), .0)] 180 | 181 | num_stop_words = sum( 182 | map(lambda w: w in self._stop_words, document.raw_words) 183 | ) 184 | 185 | score = float(num_stop_words) / document.num_raw_words 186 | score = round(score, PRECISION) 187 | 188 | return [(0, len(document), score)] 189 | -------------------------------------------------------------------------------- /app/src/core/quality_signals/importance_weights.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.stats as stats 3 | import sys 4 | from typing import List, Tuple, Type, Optional 5 | from pathlib import Path 6 | 7 | from core.constants import PRECISION 8 | from core.quality_signals.base import RPSBase 9 | from core.quality_signals.utils.dsir import hash_feature 10 | from core.document import Document 11 | from core.data_types import SignalType 12 | 13 | from utilities.register.registry_utils import * 14 | from utilities.text import form_ngrams 15 | 16 | __all__ = [ 17 | "register_importance_weights_callables", 18 | "importance_weights_schema" 19 | ] 20 | 21 | 22 | def importance_weights_schema() -> List[Tuple[str, Type]]: 23 | r""" Returns a list of signal names and their data types """ 24 | return signal_schema(module=sys.modules[__name__]) 25 | 26 | 27 | def register_importance_weights_callables( 28 | source_fps: Optional[Tuple[str]], 29 | wiki_fps: Optional[Tuple[str]], 30 | openwebtext_fps: Optional[Tuple[str]], 31 | books_fps: Optional[Tuple[str]], 32 | language: str 33 | ) -> List[RPSBase]: 34 | r""" Returns a list of signal functions (i.e., RPSBase instances) that 35 | are used to extract content signals from a document. 36 | 37 | Returns: 38 | A list of signal function class instances. 
39 | """ 40 | return list(map( 41 | lambda cls: cls( 42 | language=language, 43 | source_fps=source_fps, 44 | wiki_fps=wiki_fps, 45 | openwebtext_fps=openwebtext_fps, 46 | books_fps=books_fps 47 | ), 48 | get_callables_from_module(module=sys.modules[__name__]) 49 | )) 50 | 51 | 52 | class Base_Importance(RPSBase): # noqa 53 | r""" Base class for functions which return the log ratio of the likelihood 54 | of the document's features with respect to the target domain 55 | versus the source domain. """ 56 | 57 | __slots__ = ( 58 | "_log_diff_dist", "_feature_dim", "_target_lambda", 59 | "_source_lambda", "_length_correction" 60 | ) 61 | 62 | def __init__( 63 | self, 64 | target_fps: Tuple[str, str], 65 | source_fps: Tuple[str, str], 66 | language: str, 67 | length_correction: bool = False 68 | ): 69 | super(Base_Importance, self).__init__() 70 | self._length_correction = length_correction 71 | 72 | if target_fps is None or source_fps is None: 73 | self._log_diff_dist = None 74 | self._feature_dim = None 75 | return 76 | 77 | target_count_fp, target_lambbda_fp = target_fps 78 | source_count_fp, source_lambda_fp = source_fps 79 | 80 | assert language == Path(target_count_fp).stem.split(".")[1], \ 81 | f"Language mismatch between {target_count_fp} and {language}" 82 | 83 | assert language == Path(source_count_fp).stem.split(".")[1], \ 84 | f"Language mismatch between {target_count_fp} and {language}" 85 | 86 | # load hash counts 87 | target_counts = np.load(target_count_fp) 88 | target_dist = target_counts / target_counts.sum() 89 | source_counts = np.load(source_count_fp) 90 | source_dist = source_counts / source_counts.sum() 91 | 92 | if length_correction: 93 | self._target_lambda = np.load(target_lambbda_fp) 94 | self._source_lambda = np.load(source_lambda_fp) 95 | else: 96 | self._target_lambda = None 97 | self._source_lambda = None 98 | 99 | # compute log diff dist 100 | self._feature_dim = target_counts.shape[0] 101 | self._log_diff_dist = np.array( 102 | np.log(target_dist + 1e-8) - np.log(source_dist + 1e-8) 103 | ) 104 | 105 | def __call__(self, document: Document) -> SignalType: 106 | if self._log_diff_dist is None: 107 | return [(0, len(document), None)] 108 | 109 | doc_len = len(document) 110 | 111 | if doc_len == 0: 112 | return [(0, doc_len, None)] 113 | 114 | # try to fetch cached features, if not compute them 115 | features = ( 116 | document.hash_features 117 | if document.hash_features is not None 118 | else 119 | hash_feature( 120 | unigrams=document.raw_words, 121 | # fetch cached bigrams, otherwise comptue them 122 | bigrams=( 123 | document.raw_2grams 124 | or 125 | tuple(form_ngrams(iter(document.raw_words), 2)) 126 | ), 127 | buckets=self._feature_dim 128 | ) 129 | ) 130 | 131 | logratio = np.inner(features, self._log_diff_dist) 132 | score = float(logratio) 133 | 134 | if not self._length_correction: 135 | score = round(score, PRECISION) 136 | return [(0, doc_len, score)] 137 | 138 | # correct for the length assuming a Poisson distribution 139 | return self.__add_length_penalty(score, doc_len) 140 | 141 | def __add_length_penalty(self, score, doc_len): 142 | # correct for the length assuming a Poisson distribution 143 | len_prob_source = stats.poisson.pmf(doc_len, self._source_lambda) 144 | len_prob_target = stats.poisson.pmf(doc_len, self._target_lambda) 145 | 146 | len_correction = np.log(len_prob_target + 1e-8) - \ 147 | np.log(len_prob_source + 1e-8) 148 | 149 | score += float(len_correction) 150 | score = round(score, PRECISION) 151 | return [(0, doc_len, score)] 152 
| 153 | 154 | class RPS_Doc_Wikipedia_Importance(Base_Importance): # noqa 155 | r""" Given a bag of {1,2}-wordgram model trained on Wikipedia articles p, 156 | and a model trained on the source domain q. This is the logarithm of the 157 | ratio p(doc)/q(doc). If length_correction is enabled, then the length of 158 | score is adjusted by adding the term log(p_poisson(len) / q_poisson(len)) 159 | to the final score. 160 | """ 161 | __slots__ = () 162 | 163 | def __init__( 164 | self, 165 | wiki_fps: Tuple[str, str], 166 | source_fps: Tuple[str, str], 167 | language: str, 168 | *args, **kwargs # noqa 169 | ): 170 | super(RPS_Doc_Wikipedia_Importance, self).__init__( 171 | target_fps=wiki_fps, 172 | source_fps=source_fps, 173 | language=language, 174 | length_correction=False 175 | ) 176 | 177 | 178 | class RPS_Doc_Wikipedia_Importance_Length_Correction( # noqa 179 | Base_Importance 180 | ): 181 | r""" Given a bag of {1,2}-wordgram model trained on Wikipedia articles p, 182 | and a model trained on the source domain q. This is the logarithm of the 183 | ratio p(doc)/q(doc). If length_correction is enabled, then the length of 184 | score is adjusted by adding the term log(p_poisson(len) / q_poisson(len)) 185 | to the final score. Corrects for length by adding a length penalty term. 186 | """ 187 | __slots__ = () 188 | 189 | def __init__( 190 | self, 191 | wiki_fps: Tuple[str, str], 192 | source_fps: Tuple[str, str], 193 | language: str, 194 | *args, **kwargs # noqa 195 | ): 196 | super(RPS_Doc_Wikipedia_Importance_Length_Correction, 197 | self).__init__( 198 | target_fps=wiki_fps, 199 | source_fps=source_fps, 200 | language=language, 201 | length_correction=True 202 | ) 203 | 204 | 205 | class RPS_Doc_Books_Importance(Base_Importance): # noqa 206 | r""" Given a bag of {1,2}-wordgram model trained on Books p, 207 | and a model trained on the source domain q. This is the logarithm of the 208 | ratio p(doc)/q(doc). If length_correction is enabled, then the length of 209 | score is adjusted by adding the term log(p_poisson(len) / q_poisson(len)) 210 | to the final score. 211 | """ 212 | __slots__ = () 213 | 214 | def __init__( 215 | self, 216 | books_fps: Tuple[str, str], 217 | source_fps: Tuple[str, str], 218 | language: str, 219 | *args, **kwargs # noqa 220 | ): 221 | super(RPS_Doc_Books_Importance, self).__init__( 222 | target_fps=books_fps, 223 | source_fps=source_fps, 224 | language=language, 225 | length_correction=False 226 | ) 227 | 228 | 229 | class RPS_Doc_Books_Importance_Length_Correction( # noqa 230 | Base_Importance 231 | ): # noqa 232 | r""" Given a bag of {1,2}-wordgram model trained on Books p, 233 | and a model trained on the source domain q. This is the logarithm of the 234 | ratio p(doc)/q(doc). If length_correction is enabled, then the length of 235 | score is adjusted by adding the term log(p_poisson(len) / q_poisson(len)) 236 | to the final score. Corrects for length by adding a length penalty term. 
237 | """ 238 | __slots__ = () 239 | 240 | def __init__( 241 | self, 242 | books_fps: Tuple[str, str], 243 | source_fps: Tuple[str, str], 244 | language: str, 245 | *args, **kwargs # noqa 246 | ): 247 | super(RPS_Doc_Books_Importance_Length_Correction, self).__init__( 248 | target_fps=books_fps, 249 | source_fps=source_fps, 250 | language=language, 251 | length_correction=True 252 | ) 253 | 254 | 255 | class RPS_Doc_OpenWebText_Importance(Base_Importance): # noqa 256 | r""" Given a bag of {1,2}-wordgram model trained on OpenWebText p, 257 | and a model trained on the source domain q. This is the logarithm of the 258 | ratio p(doc)/q(doc). If length_correction is enabled, then the length of 259 | score is adjusted by adding the term log(p_poisson(len) / q_poisson(len)) 260 | to the final score. 261 | """ 262 | __slots__ = () 263 | 264 | def __init__( 265 | self, 266 | openwebtext_fps: Tuple[str, str], 267 | source_fps: Tuple[str, str], 268 | language: str, 269 | *args, **kwargs # noqa 270 | ): 271 | super(RPS_Doc_OpenWebText_Importance, self).__init__( 272 | target_fps=openwebtext_fps, 273 | source_fps=source_fps, 274 | language=language, 275 | length_correction=False 276 | ) 277 | 278 | 279 | class RPS_Doc_OpenWebText_Importance_Length_Correction( # noqa 280 | Base_Importance): # noqa 281 | r""" Given a bag of {1,2}-wordgram model trained on OpenWebText p, 282 | and a model trained on the source domain q. This is the logarithm of the 283 | ratio p(doc)/q(doc). If length_correction is enabled, then the length of 284 | score is adjusted by adding the term log(p_poisson(len) / q_poisson(len)) 285 | to the final score. Corrects for length by adding a length penalty term. 286 | """ 287 | __slots__ = () 288 | 289 | def __init__( 290 | self, 291 | openwebtext_fps: Tuple[str, str], 292 | source_fps: Tuple[str, str], 293 | language: str, 294 | *args, **kwargs # noqa 295 | ): 296 | super( 297 | RPS_Doc_OpenWebText_Importance_Length_Correction, self 298 | ).__init__( 299 | target_fps=openwebtext_fps, 300 | source_fps=source_fps, 301 | language=language, 302 | length_correction=True 303 | ) 304 | -------------------------------------------------------------------------------- /app/src/core/quality_signals/lines.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from typing import List, Tuple, Type 3 | 4 | from core.constants import PRECISION 5 | from core.quality_signals.base import RPSBase 6 | from core.data_types import SignalType, ScoreType, TextSlice 7 | from core.document import Document 8 | from utilities.register.registry_utils import * 9 | 10 | __all__ = [ 11 | "register_lines_callables", "lines_schema" 12 | ] 13 | 14 | 15 | def lines_schema() -> List[Tuple[str, Type]]: 16 | r""" Returns a list of signal names and their data types """ 17 | return signal_schema(module=sys.modules[__name__]) 18 | 19 | 20 | def register_lines_callables() -> List[RPSBase]: 21 | r""" Returns a list of signal functions (i.e., RPSBase instances) that 22 | are used to extract line signals from a document. 23 | 24 | Returns: 25 | A list of signal function class instances. 26 | """ 27 | return list(map( 28 | lambda cls: cls(), 29 | get_callables_from_module(module=sys.modules[__name__]) 30 | )) 31 | 32 | 33 | class RPS_Lines_Javascript_Counts(RPSBase): # noqa 34 | r""" The number of occurences of the word "javascript" in each line. 
""" 35 | SEARCH_TEXT = "javascript" 36 | __slots__ = () 37 | 38 | def _process_line(self, text_slice: TextSlice) -> ScoreType: 39 | if len(text_slice.text) == 0: 40 | return tuple((text_slice.start, text_slice.end, 0.0)) 41 | 42 | score = float(sum( 43 | 1 for w in text_slice.text.split() if w == self.SEARCH_TEXT 44 | )) 45 | 46 | return tuple((text_slice.start, text_slice.end, score)) 47 | 48 | def __call__(self, document: Document) -> SignalType: 49 | return list(map(self._process_line, document.normalized_lines)) 50 | 51 | 52 | class RPS_Lines_Ending_With_Terminal_Punctution_Mark(RPSBase): # noqa 53 | r""" A list of integers indicating whether (1) or not (0) a line ends with 54 | a terminal punctuation mark. A terminal punctation mark is defined as 55 | one of the following: ".", "!", "?", "”" """ 56 | TERMINAL_PUNCTUATION_MARKS = (".", "!", "?", "”") 57 | __slots__ = () 58 | 59 | def _process_line(self, text_slice: TextSlice) -> ScoreType: 60 | score = text_slice.text.rstrip().endswith( 61 | self.TERMINAL_PUNCTUATION_MARKS 62 | ) 63 | score = float(score) 64 | return tuple((text_slice.start, text_slice.end, score)) 65 | 66 | def __call__(self, document: Document) -> SignalType: 67 | return list(map(self._process_line, document.raw_lines)) 68 | 69 | 70 | class RPS_Lines_Num_Words(RPSBase): # noqa 71 | r""" The number of words in each line. This is computed based on the 72 | normalied text. Normalization is done by lowercasing the text and 73 | removing punctuation.""" 74 | __slots__ = () 75 | 76 | def _process_line(self, text_slice: TextSlice) -> ScoreType: # noqa 77 | score = len(text_slice.text.split()) 78 | return tuple((text_slice.start, text_slice.end, score)) 79 | 80 | def __call__(self, document: Document) -> SignalType: 81 | return list(map(self._process_line, document.normalized_lines)) 82 | 83 | 84 | class RPS_Lines_Uppercase_Letter_Fraction(RPSBase): # noqa 85 | r""" The ratio between number of uppercase letters and total number of 86 | characters in each line. This is based on the raw text. """ 87 | __slots__ = () 88 | 89 | def _process_line(self, text_slice: TextSlice) -> ScoreType: # noqa 90 | if len(text_slice) == 0: 91 | return tuple((text_slice.start, text_slice.end, 0.0)) 92 | 93 | score = sum(map(str.isupper, text_slice.text)) / len(text_slice) 94 | score = round(score, PRECISION) 95 | return tuple((text_slice.start, text_slice.end, score)) 96 | 97 | def __call__(self, document: Document) -> SignalType: 98 | return list(map(self._process_line, document.raw_lines)) 99 | 100 | 101 | class RPS_Lines_Numerical_Chars_Fraction(RPSBase): # noqa 102 | r""" The ratio between number of numerical characters and total number of 103 | characters in each line. This is based on text after lowercasing and 104 | removing punctuation.""" 105 | __slots__ = () 106 | 107 | def _process_line(self, text_slice: TextSlice) -> ScoreType: # noqa 108 | if len(text_slice) == 0: 109 | return tuple((text_slice.start, text_slice.end, 0.0)) 110 | 111 | score = sum(map(str.isnumeric, text_slice.text)) / len(text_slice) 112 | score = round(score, PRECISION) 113 | return tuple((text_slice.start, text_slice.end, score)) 114 | 115 | def __call__(self, document: Document) -> SignalType: 116 | return list(map(self._process_line, document.normalized_lines)) 117 | 118 | 119 | class RPS_Lines_Start_With_Bulletpoint(RPSBase): # noqa 120 | r""" Whether the lines that start with a bullet point symbol. 
The 121 | following set of unicodes are considered a bullet point: 122 | \u2022 (bullet point), \u2023 (triangular bullet point), \u25B6 (black 123 | right pointing triangle), \u25C0 (black left pointing triangle), 124 | \u25E6 (white bullet point), \u25A0 (black square), \u25A1 (white 125 | square), \u25AA (black small square), \u25AB (white small square), 126 | \u2013 (en dash).""" 127 | BULLET_POINT_SYMBOLS = ( 128 | "\u2022", # bullet point 129 | "\u2023", # triangular bullet point 130 | "\u25B6", # black right pointing triangle 131 | "\u25C0", # black left pointing triangle 132 | "\u25E6", # white bullet point 133 | "\u25A0", # black square 134 | "\u25A1", # white square 135 | "\u25AA", # black small square 136 | "\u25AB", # white small square 137 | "\u2013", # en dash 138 | ) 139 | 140 | __slots__ = () 141 | 142 | def _process_line(self, text_slice: TextSlice) -> ScoreType: # noqa 143 | score = text_slice.text.lstrip().startswith(self.BULLET_POINT_SYMBOLS) 144 | score = float(score) 145 | return tuple((text_slice.start, text_slice.end, score)) 146 | 147 | def __call__(self, document: Document) -> SignalType: 148 | num_lines = len(document.raw_lines) 149 | 150 | if num_lines == 0: 151 | return [(0, len(document), None)] 152 | 153 | return list(map(self._process_line, document.raw_lines)) 154 | -------------------------------------------------------------------------------- /app/src/core/quality_signals/natural_language.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | import math 3 | import re 4 | import sys 5 | from typing import List, Tuple, Type 6 | 7 | from core.constants import PRECISION 8 | from core.data_types import SignalType 9 | from core.quality_signals.base import RPSBase 10 | from core.document import Document 11 | from utilities.register.registry_utils import * 12 | 13 | __all__ = [ 14 | "register_natural_language_callables", 15 | "natural_language_schema" 16 | ] 17 | 18 | 19 | def natural_language_schema() -> List[Tuple[str, Type]]: 20 | r""" Returns a list of signal names and their data types """ 21 | return signal_schema(module=sys.modules[__name__]) 22 | 23 | 24 | def register_natural_language_callables() -> List[RPSBase]: 25 | r""" Returns a list of signal functions (i.e., RPSBase instances) that 26 | are used to extract natural language signals from a document. 27 | 28 | Returns: 29 | A list of signal function class instances. 30 | """ 31 | return list(map( 32 | lambda cls: cls(), 33 | get_callables_from_module(module=sys.modules[__name__]) 34 | )) 35 | 36 | 37 | class RPS_Doc_Num_Sentences(RPSBase): # noqa 38 | r""" The number of sentences in the content. This is calculated using 39 | the regex r'\b[^.!?]+[.!?]*' """ 40 | SENT_PATTERN = re.compile(r'\b[^.!?]+[.!?]*', flags=re.UNICODE) 41 | 42 | __slots__ = () 43 | 44 | def __call__(self, document: Document) -> SignalType: 45 | r""" count the number of sentences in the content using regex""" 46 | score = float(len(self.SENT_PATTERN.findall(document.raw_content))) 47 | return [(0, len(document), score)] 48 | 49 | 50 | class RPS_Doc_Word_Count(RPSBase): # noqa 51 | r""" The number of words in the content after normalization. """ 52 | __slots__ = () 53 | 54 | def __call__(self, document: Document) -> SignalType: 55 | return [(0, len(document), document.num_normalized_words)] 56 | 57 | 58 | class RPS_Doc_Mean_Word_Length(RPSBase): # noqa 59 | r""" The mean length of words in the content normalization. 
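    For example, the normalized text "the cat sat on the mat" has 6 words
    totalling 17 characters, giving a mean word length of 17 / 6 ≈ 2.83.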
""" 60 | __slots__ = () 61 | 62 | def __call__(self, document: Document) -> SignalType: 63 | if document.num_normalized_words == 0: 64 | return [(0, len(document), None)] 65 | 66 | num_chars = float(sum(map(len, document.normalized_words))) 67 | score = num_chars / document.num_normalized_words 68 | score = round(score, PRECISION) 69 | return [(0, len(document), score)] 70 | 71 | 72 | class RPS_Doc_Symbol_To_Word_Ratio(RPSBase): # noqa 73 | r""" The ratio of symbols to words in the content. This is analogous to 74 | the signal used in Gopher. Symbols are defined "#", "...", and "…". """ 75 | SYMBOLS = ("#", "...", "…") 76 | 77 | __slots__ = () 78 | 79 | def __call__(self, document: Document) -> SignalType: 80 | num_words = document.num_raw_words 81 | 82 | if num_words == 0: 83 | return [(0, len(document), None)] 84 | 85 | # count the number of symbols in the content 86 | num_symbols = float(sum( 87 | document.raw_content.count(x) for x in self.SYMBOLS 88 | )) 89 | 90 | score = num_symbols / num_words 91 | score = round(score, PRECISION) 92 | return [(0, len(document), score)] 93 | 94 | 95 | class RPS_Doc_Frac_Lines_End_With_Ellipsis(RPSBase): # noqa 96 | r""" The fraction of lines that end with an ellipsis, where an ellipsis 97 | is defined as either "..." or "…". """ 98 | ELLIPSIS_SYMBOLS = ("...", "…") 99 | 100 | __slots__ = () 101 | 102 | def __call__(self, document: Document) -> SignalType: 103 | num_lines = len(document.raw_lines) 104 | 105 | if num_lines == 0: 106 | return [(0, len(document), None)] 107 | 108 | total_ellipsis_lines = float(sum( 109 | text_slice.text.rstrip().endswith(self.ELLIPSIS_SYMBOLS) 110 | for text_slice in document.raw_lines 111 | )) 112 | 113 | score = total_ellipsis_lines / num_lines 114 | score = round(score, PRECISION) 115 | return [(0, len(document), score)] 116 | 117 | 118 | class RPS_Doc_Frac_No_Alph_Words(RPSBase): # noqa 119 | r""" The fraction of words that contain no alphabetical character. 120 | This is based on the raw content. """ 121 | ALPH_REGEX = re.compile(r"[a-zA-Z]") 122 | 123 | __slots__ = () 124 | 125 | def __call__(self, document: Document) -> SignalType: 126 | num_words = document.num_raw_words 127 | 128 | if num_words == 0: 129 | return [(0, len(document), None)] 130 | 131 | num_words_with_alpha = float(sum( 132 | int(self.ALPH_REGEX.search(word) is not None) 133 | for word in document.raw_words 134 | )) 135 | 136 | score = 1.0 - num_words_with_alpha / num_words 137 | score = round(score, PRECISION) 138 | return [(0, len(document), score)] 139 | 140 | 141 | class RPS_Doc_Frac_Unique_Words(RPSBase): # noqa 142 | r""" The fraction of unique words in the content. This is also known as 143 | the degeneracy of a text sample. Calculated based on the normalized 144 | content. """ 145 | __slots__ = () 146 | 147 | def __call__(self, document: Document) -> SignalType: 148 | num_words = document.num_normalized_words 149 | 150 | if num_words == 0: 151 | return [(0, len(document), None)] 152 | 153 | score = float(len(set(document.normalized_words))) / num_words 154 | score = round(score, PRECISION) 155 | return [(0, len(document), score)] 156 | 157 | 158 | class RPS_Doc_Unigram_Entropy(RPSBase): # noqa 159 | r""" The entropy of the unigram distribution of the 160 | content. 
This measures the diversity of the content and is computed 161 | using sum(-x / total * log(x / total)) where the sum is taken over 162 | counts of unique words in the normalized (punctuation removed, 163 | lowercased) content.""" 164 | __slots__ = () 165 | 166 | def __call__(self, document: Document) -> SignalType: 167 | if len(document.normalized_words) == 0: 168 | return [(0, len(document), None)] 169 | 170 | # count the number of times each word appears in the content 171 | counter = Counter(document.normalized_words) 172 | 173 | # calculate the entropy of the unigram distribution 174 | total = sum(counter.values()) 175 | entropy = sum(map( 176 | lambda x: -x / total * math.log(x / total) if x > 0 else 0.0, 177 | counter.values() 178 | )) 179 | 180 | score = round(entropy, PRECISION) 181 | return [(0, len(document), score)] 182 | 183 | 184 | class RPS_Doc_Frac_All_Caps_Words(RPSBase): # noqa 185 | r""" The fraction of words in the content that only consist of uppercase 186 | letters. This is based on the raw content.""" 187 | __slots__ = () 188 | 189 | def __call__(self, document: Document) -> SignalType: 190 | num_words = document.num_raw_words 191 | 192 | if num_words == 0: 193 | return [(0, len(document), None)] 194 | 195 | score = float(sum(map(str.isupper, document.raw_words))) / num_words 196 | score = round(score, PRECISION) 197 | return [(0, len(document), score)] 198 | -------------------------------------------------------------------------------- /app/src/core/quality_signals/repetitions.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | import numpy as np 3 | import sys 4 | from typing import List, Tuple, Type 5 | 6 | from core.constants import PRECISION 7 | from core.quality_signals.base import RPSBase 8 | from core.document import Document 9 | from core.data_types import SignalType 10 | from utilities.register.registry_utils import * 11 | from utilities.text import form_ngrams 12 | 13 | __all__ = [ 14 | "register_repetitions_callables", 15 | "repetitions_schema" 16 | ] 17 | 18 | 19 | def repetitions_schema() -> List[Tuple[str, Type]]: 20 | r""" Returns a list of signal names and their data types """ 21 | return signal_schema(module=sys.modules[__name__]) 22 | 23 | 24 | def register_repetitions_callables() -> List[RPSBase]: 25 | r""" Returns a list of signal functions (i.e., RPSBase instances) that 26 | are used to extract repetition related signals from a document. 27 | 28 | Returns: 29 | A list of signal function class instances. 30 | """ 31 | return list(map( 32 | lambda cls: cls(), 33 | get_callables_from_module(module=sys.modules[__name__]) 34 | )) 35 | 36 | 37 | class Base_RPS_Frac_Chars_In_Top_NGram(RPSBase): # noqa 38 | r""" Base class for calculating the fraction of characters in the 39 | top N-gram.
This operates on the lower-cased, punctation removed 40 | content.""" 41 | NGRAM_SIZE: int = None 42 | 43 | __slots__ = [] 44 | 45 | def __call__(self, document: Document) -> SignalType: 46 | if self.NGRAM_SIZE is None: 47 | raise NotImplementedError( 48 | "NGRAM_SIZE must be set in the subclass" 49 | ) 50 | 51 | # get the most common ngram 52 | most_common_ngram = Counter( 53 | # fetch the ngrams from the document if they exist, otherwise 54 | # compute them 55 | getattr(document, f"norm_{self.NGRAM_SIZE}grams", None) 56 | or 57 | form_ngrams(iter(document.normalized_words), self.NGRAM_SIZE) 58 | ).most_common(1) 59 | 60 | if len(most_common_ngram) == 0: 61 | return [(0, len(document), 0.0)] 62 | 63 | ngram, count = most_common_ngram[0] 64 | 65 | if count <= 1: 66 | return [(0, len(document), 0.0)] 67 | 68 | total_chars = sum(len(w) for w in document.normalized_words) 69 | score = sum(len(w) for w in ngram) * count / total_chars 70 | score = round(score, PRECISION) 71 | return [(0, len(document), score)] 72 | 73 | 74 | class RPS_Doc_Frac_Chars_Top_2gram(Base_RPS_Frac_Chars_In_Top_NGram): # noqa 75 | r""" The fraction of characters in the top word Bigram. Operates on the 76 | lower-cased, punctation removed content.""" 77 | NGRAM_SIZE = 2 78 | __slots__ = [] 79 | 80 | 81 | class RPS_Doc_Frac_Chars_Top_3gram(Base_RPS_Frac_Chars_In_Top_NGram): # noqa 82 | r""" The fraction of characters in the top word Trigram. Operates on the 83 | lower-cased, punctation removed content.""" 84 | NGRAM_SIZE = 3 85 | __slots__ = [] 86 | 87 | 88 | class RPS_Doc_Frac_Chars_Top_4gram(Base_RPS_Frac_Chars_In_Top_NGram): # noqa 89 | r""" The fraction of characters in the top word 4gram. Operates on the 90 | lower-cased, punctation removed content.""" 91 | NGRAM_SIZE = 4 92 | __slots__ = [] 93 | 94 | 95 | class Base_RPS_Frac_Chars_In_Dupe_NGrams(RPSBase): # noqa 96 | r""" Base class for calculating the fraction of characters in 97 | duplicate word N-grams. This operates on the lower-cased, punctation 98 | removed content. 
The function also ensures that characters in overlapping 99 | ngrams are only counted once.""" 100 | NGRAM_SIZE: int = None 101 | __slots__ = [] 102 | 103 | def __call__(self, document: Document) -> SignalType: 104 | if self.NGRAM_SIZE is None: 105 | raise NotImplementedError( 106 | "NGRAM_SIZE must be set in the subclass" 107 | ) 108 | 109 | if len(document.normalized_words) < self.NGRAM_SIZE: 110 | return [(0, len(document), 0.0)] 111 | 112 | # fetch the ngrams from the document if they exist, otherwise 113 | # compute them 114 | doc_n_grams = ( 115 | getattr(document, f"norm_{self.NGRAM_SIZE}grams", None) 116 | or 117 | tuple(form_ngrams( 118 | iter(document.normalized_words), self.NGRAM_SIZE 119 | )) 120 | ) 121 | 122 | # keep only ngrams which occur at least twice 123 | ngram_dupes = { 124 | ngram for ngram, count in Counter(doc_n_grams).items() if count > 1 125 | } 126 | 127 | duplicated_grams = np.zeros(len(document.normalized_words), dtype=int) 128 | 129 | i = 0 130 | for ngram in doc_n_grams: 131 | if ngram in ngram_dupes: 132 | duplicated_grams[i: i + self.NGRAM_SIZE] = 1 133 | 134 | i += 1 135 | 136 | word_lengths = np.array(list(map(len, document.normalized_words))) 137 | chars_duped = np.sum(word_lengths * duplicated_grams) 138 | total_chars = np.sum(word_lengths) 139 | 140 | if total_chars == 0: 141 | return [(0, len(document), 0.0)] 142 | 143 | score = float(chars_duped / total_chars) 144 | score = round(score, PRECISION) 145 | return [(0, len(document), score)] 146 | 147 | 148 | class RPS_Doc_Frac_Chars_Dupe_5Grams( # noqa 149 | Base_RPS_Frac_Chars_In_Dupe_NGrams 150 | ): 151 | r""" The fraction of characters in duplicate word 5grams. This operates on 152 | the lower-cased, punctation removed content. It is also ensured that 153 | characters in overlapping ngrams are only counted once. """ 154 | NGRAM_SIZE = 5 155 | __slots__ = [] 156 | 157 | 158 | class RPS_Doc_Frac_Chars_Dupe_6Grams( # noqa 159 | Base_RPS_Frac_Chars_In_Dupe_NGrams 160 | ): 161 | r""" The fraction of characters in duplicate word 6grams. This operates on 162 | the lower-cased, punctation removed content. It is also ensured that 163 | characters in overlapping ngrams are only counted once. """ 164 | NGRAM_SIZE = 6 165 | __slots__ = [] 166 | 167 | 168 | class RPS_Doc_Frac_Chars_Dupe_7Grams( # noqa 169 | Base_RPS_Frac_Chars_In_Dupe_NGrams 170 | ): 171 | r""" The fraction of characters in duplicate word 7grams. This operates on 172 | the lower-cased, punctation removed content. It is also ensured that 173 | characters in overlapping ngrams are only counted once. """ 174 | NGRAM_SIZE = 7 175 | __slots__ = [] 176 | 177 | 178 | class RPS_Doc_Frac_Chars_Dupe_8Grams( # noqa 179 | Base_RPS_Frac_Chars_In_Dupe_NGrams 180 | ): 181 | r""" The fraction of characters in duplicate word 8grams. This operates on 182 | the lower-cased, punctation removed content. It is also ensured that 183 | characters in overlapping ngrams are only counted once. """ 184 | NGRAM_SIZE = 8 185 | __slots__ = [] 186 | 187 | 188 | class RPS_Doc_Frac_Chars_Dupe_9Grams( # noqa 189 | Base_RPS_Frac_Chars_In_Dupe_NGrams 190 | ): 191 | r""" The fraction of characters in duplicate word 9grams. This operates on 192 | the lower-cased, punctation removed content. It is also ensured that 193 | characters in overlapping ngrams are only counted once. 
""" 194 | NGRAM_SIZE = 9 195 | __slots__ = [] 196 | 197 | 198 | class RPS_Doc_Frac_Chars_Dupe_10Grams( # noqa 199 | Base_RPS_Frac_Chars_In_Dupe_NGrams 200 | ): 201 | r""" The fraction of characters in duplicate word 10grams. This operates on 202 | the lower-cased, punctation removed content. It is also ensured that 203 | characters in overlapping ngrams are only counted once. """ 204 | NGRAM_SIZE = 10 205 | __slots__ = [] 206 | -------------------------------------------------------------------------------- /app/src/core/quality_signals/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/togethercomputer/RedPajama-Data/6d2cee9df2b0204dd2bcb00bf06b5a7b1d7432d7/app/src/core/quality_signals/utils/__init__.py -------------------------------------------------------------------------------- /app/src/core/quality_signals/utils/classifiers.py: -------------------------------------------------------------------------------- 1 | from core.document import Document 2 | 3 | 4 | def preprocess_quality_classifier(document: Document): 5 | r""" Preprocesses a document for quality classification. This function 6 | removes all newlines and trailing whitespaces from the document. 7 | 8 | Args: 9 | document: A document. 10 | 11 | Returns: 12 | A string. 13 | """ 14 | # remove newlines and trailing and leading whitespaces 15 | return " ".join(document.raw_content.splitlines()).strip() 16 | -------------------------------------------------------------------------------- /app/src/core/quality_signals/utils/content.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | from typing import Dict, Set 4 | 5 | _DEFAULT_LANGS = ("en", "fr", "it", "es", "de") 6 | 7 | 8 | def load_bad_urls_index(bad_urls_dir: Path) -> Dict[str, int]: 9 | with open(bad_urls_dir / "domain_to_category_id.json", "r") as f: 10 | domain_to_category_id = json.load(f) 11 | return domain_to_category_id 12 | 13 | 14 | def load_bad_words(bad_words_dir: Path, lang: str) -> Set[str]: 15 | r""" load the LDNOOBW word list for a given language 16 | 17 | Source: 18 | https://github.com/LDNOOBW/List-of-Dirty-Naughty-Obscene-and-Otherwise-Bad-Words 19 | 20 | Args: 21 | bad_words_dir (Path): The path to the resources directory where the 22 | list is stored 23 | lang (str): The language for which to fetch the word list 24 | 25 | Returns: 26 | A set of words 27 | """ 28 | if lang not in _DEFAULT_LANGS: 29 | return set() 30 | 31 | ldnoobw_fp = bad_words_dir / f"{lang}.txt" 32 | 33 | if not ldnoobw_fp.exists(): 34 | raise FileNotFoundError(f"LDNOOBW word list {ldnoobw_fp} not found!") 35 | 36 | with open(ldnoobw_fp, 'r') as f: 37 | data = set(ln.strip() for ln in f.readlines()) 38 | 39 | return data 40 | -------------------------------------------------------------------------------- /app/src/core/quality_signals/utils/dsir.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import Tuple 3 | 4 | 5 | def compute_hash(ngram: str, buckets: int): 6 | return int(abs(hash(ngram)) % buckets) 7 | 8 | 9 | def hash_feature( 10 | unigrams: Tuple[str], bigrams: Tuple[str], buckets: int 11 | ) -> np.ndarray: 12 | counts = np.zeros(buckets, dtype=np.int64) 13 | 14 | for unigram in unigrams: 15 | counts[compute_hash(unigram, buckets=buckets)] += 1 16 | 17 | for bigram in bigrams: 18 | counts[compute_hash(bigram, buckets=buckets)] += 1 19 | 20 | 
return counts 21 | -------------------------------------------------------------------------------- /app/src/core/schema/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/togethercomputer/RedPajama-Data/6d2cee9df2b0204dd2bcb00bf06b5a7b1d7432d7/app/src/core/schema/__init__.py -------------------------------------------------------------------------------- /app/src/core/schema/rp.py: -------------------------------------------------------------------------------- 1 | r""" This module contains definitions of the schemas used in the project. These 2 | are used to build msgspec writers and readers. 3 | 4 | The schemas are defined as lists of tuples, where each tuple contains the name 5 | and type of the field. 6 | 7 | """ 8 | 9 | from core.data_types import SignalType 10 | from core.quality_signals.content import content_schema 11 | from core.quality_signals.repetitions import repetitions_schema 12 | from core.quality_signals.natural_language import natural_language_schema 13 | from core.quality_signals.lines import lines_schema 14 | from core.quality_signals.classifiers import classifier_schema 15 | from core.quality_signals.importance_weights import importance_weights_schema 16 | 17 | METADATA_SCHEMA = [ 18 | ("cc_net_source", str), 19 | ("cc_segment", str), 20 | ("shard_id", str), 21 | ("url", str), 22 | ("source_domain", str), 23 | ("language", str), 24 | ("snapshot_id", str) 25 | ] 26 | 27 | QUALITY_SIGNALS_SCHEMA = [ 28 | ("ccnet_length", SignalType), 29 | ("ccnet_original_length", SignalType), 30 | ("ccnet_nlines", SignalType), 31 | ("ccnet_original_nlines", SignalType), 32 | ("ccnet_language_score", SignalType), 33 | ("ccnet_perplexity", SignalType), 34 | ("ccnet_bucket", SignalType), 35 | *content_schema(), 36 | *natural_language_schema(), 37 | *repetitions_schema(), 38 | *lines_schema(), 39 | *classifier_schema(), 40 | *importance_weights_schema(), 41 | ] 42 | 43 | RP_SIGNAL_SCHEMA = [ 44 | ("id", str), 45 | ("id_int", int), 46 | ("metadata", METADATA_SCHEMA), 47 | ("quality_signals", QUALITY_SIGNALS_SCHEMA) 48 | ] 49 | -------------------------------------------------------------------------------- /app/src/core/worker.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import fasttext 3 | import gc 4 | import hashlib 5 | import logging 6 | import logging.handlers 7 | import multiprocessing as mp 8 | import os 9 | from pathlib import Path 10 | import re 11 | from typing import List, Dict, Callable, Optional 12 | from urllib.parse import urlparse 13 | import urllib3 14 | import pyarrow as pa 15 | import uuid 16 | 17 | from core.document import Document 18 | from core.quality_signals.content import register_content_callables 19 | from core.quality_signals.lines import register_lines_callables 20 | from core.quality_signals.natural_language import \ 21 | register_natural_language_callables 22 | from core.quality_signals.repetitions import register_repetitions_callables 23 | from core.quality_signals.classifiers import register_classifier_callables 24 | from core.quality_signals.importance_weights import \ 25 | register_importance_weights_callables 26 | from core.data_types import InputSpec 27 | from core.schema.rp import RP_SIGNAL_SCHEMA 28 | from dedupe.minhash import MinHash 29 | from utilities.io import Reader, Writer, ParquetBatchWriter 30 | from utilities.io.s3 import init_client 31 | from utilities.logging.mp import configure_worker_logger 32 | 33 | # 
disable warnings 34 | fasttext.FastText.eprint = lambda x: None 35 | urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) # noqa 36 | 37 | _BYTE_ORDER = sys.byteorder 38 | 39 | 40 | def _ccnet_bucket_to_int(bucket: str) -> Optional[float]: 41 | r""" ccnet bucket name to float mapping """ 42 | if bucket == "head": 43 | return 0.0 44 | elif bucket == "middle": 45 | return 1.0 46 | elif bucket == "tail": 47 | return 2.0 48 | else: 49 | return None 50 | 51 | 52 | class Worker: 53 | # output file pattern 54 | shard_pattern_signals = "{shard_id}.signals.json.gz" 55 | shard_pattern_minhash = "{shard_id}.minhash.parquet" 56 | 57 | # regex to extract snapshot id from uri 58 | snapsh_re = re.compile(r'\b\d{4}-\d{2}\b') 59 | uri_id_re = re.compile(r'\b\d{4}-\d{2}\b/.*') 60 | 61 | def __init__( 62 | self, language: str, 63 | snapshot_id: str, 64 | input_listings: List[str], 65 | input_base_uri: str, 66 | output_base_uri: str, 67 | log_dir: str, 68 | classifier_files: Dict[str, str], 69 | dsir_files: Dict[str, str], 70 | dsir_bucket: int, 71 | ldnoobw_dir: Path, 72 | ut1_dir: Path, 73 | minhash_similarities: List[float], 74 | minhash_ngram_size: int, 75 | minhash_num_permutations: int, 76 | monitor_queue: mp.Queue, 77 | logging_queue: mp.Queue, 78 | seed: int, 79 | endpoint_url: str = None, 80 | max_docs: int = -1, 81 | flush_interval=1000 82 | ): 83 | self._lang = language 84 | self._snapshot_id = snapshot_id 85 | self._input_base_uri = input_base_uri 86 | self._output_base_uri = output_base_uri 87 | self._dsir_files = dsir_files 88 | self._dsir_buckets = dsir_bucket 89 | self._flush_interval = flush_interval 90 | 91 | # init logger 92 | configure_worker_logger(logging_queue, level=logging.INFO) 93 | self._logger = logging.getLogger() 94 | 95 | # minhash setup 96 | self._minhash = MinHash( 97 | similarity_thresholds=minhash_similarities, 98 | ngram_size=minhash_ngram_size, 99 | num_permutations=minhash_num_permutations, 100 | seed=seed 101 | ) 102 | 103 | self._logger.info(f"__MINHASH_PERM_CHECKSUM__ " 104 | f"{self._minhash.checksum}") 105 | 106 | self._max_docs = max_docs 107 | self._monitor_queue = monitor_queue 108 | self._endpoint_url = endpoint_url 109 | 110 | self._job_id = str(uuid.uuid4()) 111 | 112 | # build input paths 113 | self._input_uri_list = list(map( 114 | lambda x: os.path.join(self._input_base_uri, x), 115 | input_listings 116 | )) 117 | 118 | # init file to keep track of failed input files 119 | self._failed_input_file = os.path.join( 120 | log_dir, f"{language}-inputs.{self._job_id}.FAIL" 121 | ) 122 | 123 | # init file to keep track of successful input files 124 | self._success_input_file = os.path.join( 125 | log_dir, f"{language}-inputs.{self._job_id}.SUCCESS" 126 | ) 127 | 128 | # setup input file reader 129 | read_scheme = urlparse(self._input_base_uri).scheme 130 | if read_scheme == "s3": 131 | client = init_client( 132 | endpoint_url=self._endpoint_url, 133 | aws_access_key_id=os.environ.get("AWS_ACCESS_KEY_ID"), 134 | aws_secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY"), 135 | signature_version="s3v4" 136 | ) 137 | else: 138 | client = None 139 | 140 | self._reader = Reader( 141 | input_spec=InputSpec, threads=1, s3_client=client, 142 | logger=self._logger 143 | ) 144 | 145 | # classifier model filepaths 146 | self._palm_model_file = classifier_files.get("palm") 147 | self._wikiref_model_file = classifier_files.get("wikiref") 148 | self._wikipedia_model_file = classifier_files.get("wikipedia") 149 | 150 | # initialize signal functions 151 | 
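        # Each registered callable exposes a `field_name` attribute and, when
        # applied to a Document, returns its score as a tuple of
        # (start, end, score) spans; document-level signals use a single span
        # covering the whole text, mirroring the ccnet signals built in
        # __process_record below. A minimal sketch of that protocol
        # (illustrative only; the exact base-class API is an assumption):
        #   class RPS_Doc_Word_Count(RPSBase):  # field_name "rps_doc_word_count"
        #       def __call__(self, document):
        #           score = float(len(document.normalized_words))
        #           return ((0, len(document), score),)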
self._quality_signals = self.__init_quality_signals( 152 | ldnoobw_dir=ldnoobw_dir, ut1_dir=ut1_dir 153 | ) 154 | 155 | # minhash_schema 156 | self._minhash_schema = pa.schema([ 157 | ("shard_id", pa.string()), 158 | ("id", pa.string()), 159 | ("id_int", pa.uint64()), 160 | *[ 161 | ( 162 | "signature_sim{s}".format(s=s), pa.list_(pa.binary()) 163 | ) 164 | for s in minhash_similarities 165 | ] 166 | ]) 167 | 168 | @property 169 | def job_id(self): 170 | return self._job_id 171 | 172 | def __init_quality_signals(self, ldnoobw_dir, ut1_dir) -> List[Callable]: 173 | callables = [] 174 | 175 | # initialize content signal functions 176 | self._logger.info(f"Registering content signals for {self._lang}..") 177 | callables += register_content_callables( 178 | language=self._lang, 179 | bad_urls_dir=ut1_dir, 180 | bad_words_dir=ldnoobw_dir 181 | ) 182 | 183 | # initialize repetition removal signal functions 184 | self._logger.info(f"Registering repetition signals for {self._lang}..") 185 | callables += register_repetitions_callables() 186 | 187 | # initialize natural language signal functions 188 | self._logger.info(f"Registering natlang signals for {self._lang}..") 189 | callables += register_natural_language_callables() 190 | 191 | # initialize line signal functions 192 | self._logger.info(f"Registering line level signals for {self._lang}..") 193 | callables += register_lines_callables() 194 | 195 | # initialize ml heuristics signal functions 196 | self._logger.info(f"Registering classifier signals for {self._lang}..") 197 | callables += register_classifier_callables( 198 | wikiref_model=self._wikiref_model_file, 199 | palm_model=self._palm_model_file, 200 | wikipedia_model=self._wikipedia_model_file 201 | ) 202 | 203 | # initialize importance weights signal functions 204 | # hacky -- first index is the counts file, second is the lambda file 205 | # this is set in pipeline.py 206 | self._logger.info(f"Registering dsir signals for {self._lang}..") 207 | callables += register_importance_weights_callables( 208 | source_fps=self._dsir_files.get("ccnet"), 209 | wiki_fps=self._dsir_files.get("wikipedia"), 210 | openwebtext_fps=self._dsir_files.get("openwebtext"), 211 | books_fps=self._dsir_files.get("books"), 212 | language=self._lang 213 | ) 214 | 215 | return callables 216 | 217 | def __process_record( 218 | self, idx: int, record, uri_id: str, snapshot_id: str 219 | ): 220 | # Setup document; this precomputes ngrams and hash features 221 | document = Document( 222 | record.raw_content, 223 | domain=record.source_domain, 224 | precompute_ngrams=True, 225 | precompute_hash_features=True, 226 | dsir_buckets=self._dsir_buckets 227 | ) 228 | 229 | # compute signals 230 | rp_v2_signals = {} 231 | for func in self._quality_signals: 232 | rp_v2_signals[func.field_name] = func(document) # noqa 233 | 234 | # compute minhash signatures 235 | minhash_signatures = self._minhash.compute_banded_signatures( 236 | tokens=document.normalized_words 237 | ) 238 | 239 | # compute document ids 240 | doc_id = f"{uri_id}/{idx}" 241 | doc_id_int = int.from_bytes( 242 | hashlib.sha1(doc_id.encode("utf-8")).digest()[:8], # take 8 bytes 243 | byteorder=_BYTE_ORDER, signed=False 244 | ) 245 | 246 | record_data = { 247 | "id": f"{uri_id}/{idx}", 248 | "id_int": doc_id_int, 249 | } 250 | 251 | metadata = { 252 | "cc_segment": record.cc_segment, 253 | "cc_net_source": uri_id, 254 | "url": record.url, 255 | "source_domain": record.source_domain, 256 | "language": record.language, 257 | "snapshot_id": snapshot_id 258 | } 259 | 260 
| ccnet_quality_signals = { 261 | "ccnet_length": ( 262 | (0, len(document), float(record.length)), 263 | ), 264 | "ccnet_original_length": ( 265 | (0, len(document), float(record.original_length)), 266 | ), 267 | "ccnet_nlines": ( 268 | (0, len(document), float(record.nlines)), 269 | ), 270 | "ccnet_original_nlines": ( 271 | (0, len(document), float(record.original_nlines)), 272 | ), 273 | "ccnet_language_score": ( 274 | (0, len(document), float(record.language_score)), 275 | ), 276 | "ccnet_perplexity": ( 277 | (0, len(document), float(record.perplexity)), 278 | ), 279 | "ccnet_bucket": ( 280 | (0, len(document), _ccnet_bucket_to_int(record.bucket)), 281 | ), 282 | } 283 | 284 | record_data["metadata"] = metadata 285 | record_data["quality_signals"] = { 286 | **ccnet_quality_signals, **rp_v2_signals 287 | } 288 | 289 | return record_data, minhash_signatures, doc_id, doc_id_int 290 | 291 | def __process_uri(self, docs_to_fetch: int, uri: str): 292 | num_docs = 0 293 | docs_added = 0 294 | snapshot_id = self.snapsh_re.search(uri).group(0) 295 | uri_id = self.uri_id_re.search(uri).group(0) 296 | 297 | # signal writer 298 | signal_uri = os.path.join( 299 | self._output_base_uri, 300 | self.shard_pattern_signals.format(shard_id=uri_id.split(".")[0]), 301 | ) 302 | signal_writer = Writer(uri=signal_uri, schema=RP_SIGNAL_SCHEMA) 303 | self._logger.info(f"Initialized jsonl writer to {signal_uri}") 304 | 305 | # init minhash writer 306 | minhash_uri = os.path.join( 307 | self._output_base_uri, 308 | self.shard_pattern_minhash.format(shard_id=uri_id.split(".")[0]), 309 | ) 310 | minhash_writer = ParquetBatchWriter( 311 | output_fp=minhash_uri, schema=self._minhash_schema 312 | ) 313 | self._logger.info(f"Initialized parquet writer to {minhash_uri}") 314 | 315 | for idx, record in self._reader.read( 316 | uri=uri, max_samples=docs_to_fetch, return_idx=True 317 | ): 318 | # compute signals 319 | ( 320 | record_data, minhash_signatures, doc_id, doc_id_int 321 | ) = self.__process_record( 322 | idx=idx, record=record, uri_id=uri_id, snapshot_id=snapshot_id 323 | ) 324 | num_docs += 1 325 | docs_added += 1 326 | 327 | # write quality signals 328 | signal_writer.write(record_data) 329 | 330 | # record minhash signatures 331 | minhash_writer.update_batch( 332 | obj={"shard_id": uri_id, "id_int": doc_id_int, "id": doc_id, 333 | **minhash_signatures} 334 | ) 335 | 336 | # send to monitor 337 | if num_docs % self._flush_interval == 0: 338 | minhash_writer.write_batch() 339 | signal_writer.flush() 340 | self._monitor_queue.put({ 341 | "lang": self._lang, "num_docs": docs_added 342 | }) 343 | docs_added = 0 344 | 345 | if docs_added > 0: 346 | self._monitor_queue.put({ 347 | "lang": self._lang, "num_docs": docs_added 348 | }) 349 | 350 | # close writers 351 | signal_writer.close() 352 | minhash_writer.close() 353 | 354 | gc.collect() 355 | 356 | return num_docs 357 | 358 | def run(self): 359 | total_docs = 0 360 | 361 | for i, uri in enumerate(self._input_uri_list, start=1): 362 | docs_to_fetch = self._max_docs - total_docs 363 | if docs_to_fetch <= 0 < self._max_docs: 364 | self._logger.info( 365 | f"Reached max docs {self._max_docs} at {uri}") 366 | break 367 | 368 | # process file 369 | self._logger.info( 370 | f"Start processing {uri} ({i}/{len(self._input_uri_list)})" 371 | ) 372 | try: 373 | docs_in_uri = self.__process_uri(docs_to_fetch, uri) 374 | except Exception as e: 375 | with open(self._failed_input_file, "a+") as f: 376 | f.write(f"{uri}\n") 377 | self._logger.error(f"__URI_FAIL__ {uri} with 
exception: " 378 | f"{e.__class__.__name__}: {e} in " 379 | f"{self.__class__.__name__}.__process_uri") 380 | continue 381 | 382 | total_docs += docs_in_uri 383 | self._logger.info( 384 | f"__URI_SUCCESS__ {uri} ({i}/{len(self._input_uri_list)})" 385 | ) 386 | 387 | # send signal that a uri has been completed 388 | self._monitor_queue.put({ 389 | "lang": self._lang, "num_docs": None, "uri_complete": True 390 | }) 391 | 392 | # keep track of completed uris 393 | with open(self._success_input_file, "a+") as f: 394 | f.write(f"{uri}\n") 395 | 396 | self._logger.info(f"Worker {self._job_id} Completed. " 397 | f"Processed {total_docs} documents.") 398 | 399 | gc.collect() 400 | 401 | return total_docs, self._lang 402 | -------------------------------------------------------------------------------- /app/src/dedupe/__init__.py: -------------------------------------------------------------------------------- 1 | from .minhash import MinHash 2 | -------------------------------------------------------------------------------- /app/src/dedupe/minhash.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import numpy as np 3 | from typing import List, Dict, Optional, Tuple 4 | 5 | from dedupe.utils import optimal_param, generate_signature 6 | 7 | 8 | class MinHash: 9 | _sig_key_pat = "signature_sim{s}" 10 | 11 | def __init__( 12 | self, 13 | similarity_thresholds: List[float], 14 | ngram_size: int, 15 | num_permutations: int, 16 | seed: int, 17 | ): 18 | self._similarity_thresholds = similarity_thresholds 19 | self._rng = np.random.RandomState(seed) 20 | self._ngram_size = ngram_size 21 | 22 | self._bands_rows = { 23 | str(s): optimal_param(threshold=s, num_perm=num_permutations) 24 | for s in similarity_thresholds 25 | } 26 | 27 | self._hashranges = { 28 | self._sig_key_pat.format(s=s): self._get_hashrange(b, r) 29 | for s, (b, r) in self._bands_rows.items() 30 | } 31 | 32 | # init minhash artifacts 33 | self.__init_minhash(num_permutations) 34 | 35 | def __init_minhash(self, num_permutations): 36 | # minhash constants 37 | self._max_hash = np.uint64((1 << 32) - 1) 38 | self._mersenne_prime = np.uint64((1 << 61) - 1) 39 | self._permutations = np.array( 40 | [ 41 | ( 42 | self._rng.randint( 43 | 1, self._mersenne_prime, dtype=np.uint64 44 | ), 45 | self._rng.randint( 46 | 0, self._mersenne_prime, dtype=np.uint64 47 | ), 48 | ) 49 | for _ in range(num_permutations) 50 | ], 51 | dtype=np.uint64, 52 | ).T 53 | 54 | # compute checksum for permutations 55 | self._checksum = hashlib.sha256( 56 | self._permutations.tobytes() 57 | ).hexdigest() 58 | 59 | @staticmethod 60 | def _get_hashrange(b, r): 61 | return [(i * r, (i + 1) * r) for i in range(b)] 62 | 63 | @property 64 | def similarity_thresholds(self): 65 | return self._similarity_thresholds 66 | 67 | @property 68 | def checksum(self): 69 | return self._checksum 70 | 71 | def compute_banded_signatures( 72 | self, tokens: Tuple[str] 73 | ) -> Dict[str, Optional[List[bytes]]]: 74 | if len(tokens) < self._ngram_size: 75 | return {k: None for k in self._hashranges.keys()} 76 | 77 | # compute signature 78 | minhashes: np.ndarray = generate_signature( 79 | words_sequence=iter(tokens), 80 | ngram_size=self._ngram_size, 81 | permutations=self._permutations, 82 | max_hash=self._max_hash, 83 | mersenne_prime=self._mersenne_prime 84 | ) 85 | 86 | # partition signatures into bands 87 | signatures = { 88 | sig_key: [ 89 | bytes(minhashes[start:end].byteswap().data) 90 | for start, end in hashrange 91 | ] 92 | for 
sig_key, hashrange in self._hashranges.items() 93 | } 94 | 95 | return signatures 96 | -------------------------------------------------------------------------------- /app/src/dedupe/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | To reduce dependencies, the functions in this module are adapted from 3 | the `datasketch` library with minor modifications. 4 | """ 5 | 6 | from scipy.integrate import quad as integrate 7 | import hashlib 8 | import numpy as np 9 | import struct 10 | from typing import Iterable, List 11 | 12 | from utilities.text import form_ngrams 13 | 14 | 15 | def _false_positive_probability(threshold, b, r): 16 | def proba(s): 17 | return 1 - (1 - s ** float(r)) ** float(b) 18 | 19 | a, *_ = integrate(proba, 0.0, threshold) 20 | 21 | return a 22 | 23 | 24 | def _false_negative_probability(threshold, b, r): 25 | def proba(s): 26 | return 1 - (1 - (1 - s ** float(r)) ** float(b)) 27 | 28 | a, *_ = integrate(proba, threshold, 1.0) 29 | 30 | return a 31 | 32 | 33 | def optimal_param( 34 | threshold: float, 35 | num_perm: int, 36 | false_positive_weight: float = 0.5, 37 | false_negative_weight: float = 0.5 38 | ): 39 | r""" 40 | Compute the optimal `MinHashLSH` parameter that minimizes the weighted sum 41 | of probabilities of false positive and false negative. 42 | """ 43 | min_error = float("inf") 44 | opt = (0, 0) 45 | for b in range(1, num_perm + 1): 46 | max_r = int(num_perm / b) 47 | for r in range(1, max_r + 1): 48 | fp = _false_positive_probability(threshold, b, r) 49 | fn = _false_negative_probability(threshold, b, r) 50 | error = fp * false_positive_weight + fn * false_negative_weight 51 | if error < min_error: 52 | min_error = error 53 | opt = (b, r) 54 | return opt 55 | 56 | 57 | def sha1_hash32(data: bytes) -> int: 58 | """ 59 | A 32-bit hash function based on SHA1. 60 | 61 | Note: 62 | This implementation is copied from datasketch to avoid dependency. 63 | 64 | Args: 65 | data (bytes): the data to generate 32-bit integer hash from. 66 | 67 | Returns: 68 | int: an integer hash value that can be encoded using 32 bits. 69 | """ 70 | return struct.unpack(" np.ndarray: 80 | r""" 81 | Combined with some datasketch code to better parallelize computation. 82 | 83 | Note: 84 | This implementation is adapted from the near-dedupe implementation by 85 | the bigcode project. 86 | 87 | Parameters 88 | ---------- 89 | words_sequence : str 90 | A sequence of (normalized) words for which to generate a signature. 91 | ngram_size : int 92 | The size of n-grams. 93 | permutations : np.ndarray 94 | The permutations for the minhash. 95 | max_hash: int 96 | The maximum value for hashes. 97 | mersenne_prime: int 98 | The mersenne prime. 99 | 100 | Returns 101 | ------- 102 | List[np.uint32] 103 | The minhash signature. 
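
    Notes
    -----
    Each unique n-gram is hashed with ``sha1_hash32``, passed through the
    ``num_perm`` affine permutations ``(a * h + b) % mersenne_prime`` and
    masked to 32 bits via ``max_hash``; the signature is the element-wise
    minimum over all n-grams. A minimal call might look as follows
    (an illustrative sketch; the values are assumptions, not the pipeline
    defaults):

        perms = np.random.RandomState(0).randint(
            1, (1 << 61) - 1, size=(2, 128), dtype=np.uint64
        )
        sig = generate_signature(
            words_sequence=iter("a rose is a rose is a rose".split()),
            ngram_size=3,
            permutations=perms,
            max_hash=np.uint64((1 << 32) - 1),
            mersenne_prime=np.uint64((1 << 61) - 1),
        )
        assert sig.shape == (128,)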
104 | """ 105 | num_perm = permutations.shape[-1] 106 | hashvalues = np.ones(num_perm, dtype=np.uint64) * max_hash 107 | tokens = {" ".join(t) for t in form_ngrams(words_sequence, ngram_size)} 108 | h_vals = np.array( 109 | [sha1_hash32(token.encode("utf-8")) for token in tokens], 110 | dtype=np.uint64 111 | ) 112 | a, b = permutations 113 | phv = np.bitwise_and( 114 | ((h_vals * np.tile(a, (len(h_vals), 1)).T).T + b) % mersenne_prime, 115 | max_hash 116 | ) 117 | 118 | # compute the minhash 119 | signature = np.vstack([phv, hashvalues]).min(axis=0).astype(np.uint32) 120 | 121 | return signature 122 | -------------------------------------------------------------------------------- /app/src/prep_artifacts.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | from datetime import datetime as dt 4 | import os 5 | from pathlib import Path 6 | 7 | from artifacts.downloaders import ( 8 | WikipediaDownloader, 9 | OpenWebTextDownloader, 10 | BooksDownloader, 11 | CCNetDownloader 12 | ) 13 | from artifacts.hash_dist import HashDist 14 | from artifacts.ft_trainer import FastTextTrainer 15 | from utilities.logging.configure import configure_logger 16 | 17 | 18 | def parse_arguments(): 19 | def nullable_string(val): 20 | # converts empty string to None 21 | return None if not val else val 22 | 23 | parser = argparse.ArgumentParser() 24 | # input and outputs 25 | parser.add_argument( 26 | "--artifacts_dir", type=str, default=None, 27 | help="Directory where artifacts of the pipeline are stored" 28 | ) 29 | parser.add_argument( 30 | "--cc_input", type=str, default=None, 31 | help="cc_net output listings" 32 | ) 33 | parser.add_argument( 34 | "--cc_input_base_uri", type=str, default=None, 35 | help="Base URL (prefix) used for files list in input. Used to " 36 | "select the access method: s3:/// or file:///" 37 | ) 38 | parser.add_argument( 39 | "--cache_dir", type=str, default=None, 40 | help="huggingface cache directory" 41 | ) 42 | parser.add_argument( 43 | "--overwrite", action="store_true", 44 | help="Overwrite existing files" 45 | ) 46 | parser.add_argument( 47 | "--lang", type=str, default=None 48 | ) 49 | parser.add_argument( 50 | "--max_workers", type=int, default=None, 51 | help="Maximum number of workers to use" 52 | ) 53 | parser.add_argument( 54 | "--dsir_num_samples", type=int, default=None, 55 | help="Number of samples to use for dsir" 56 | ) 57 | parser.add_argument( 58 | "--dsir_feature_dim", type=int, default=None, 59 | help="Number of buckets to use for dsir" 60 | ) 61 | parser.add_argument( 62 | "--classifiers_num_samples", type=int, default=None, 63 | help="Number of samples to use for classifiers" 64 | ) 65 | parser.add_argument( 66 | "--endpoint_url", type=nullable_string, default=None, 67 | help="endpoint url where the s3 bucket is exposed." 
68 | ) 69 | 70 | # sampling 71 | parser.add_argument( 72 | "--max_samples_per_book", type=int, default=None, 73 | help="Maximum number of samples to use per book" 74 | ) 75 | parser.add_argument( 76 | "--max_paragraphs_per_book_sample", type=int, default=None, 77 | help="Maximum number of paragraphs to use per book sample" 78 | ) 79 | 80 | return parser.parse_args() 81 | 82 | 83 | def main(artifacts_dir: str, cc_input: str, cc_input_base_uri: str, 84 | cache_dir: str, overwrite: bool, lang: str, 85 | max_workers: int, endpoint_url: str, 86 | dsir_num_samples: int, dsir_feature_dim: int, 87 | classifiers_num_samples: int, max_samples_per_book: int, 88 | max_paragraphs_per_book_sample: int 89 | ): 90 | if max_workers is None: 91 | max_workers = os.cpu_count() - 2 92 | else: 93 | max_workers = min(max_workers, os.cpu_count() - 2) 94 | 95 | # parse config 96 | num_samples = max(dsir_num_samples, classifiers_num_samples) 97 | 98 | # build output directory 99 | datasets_dir = Path(artifacts_dir) / "datasets" / f"{lang}" 100 | datasets_dir.mkdir(exist_ok=True, parents=True) 101 | timestamp = dt.now().strftime("%Y%m%d-%H%M%S") 102 | logfile = Path(artifacts_dir) / f"logs/{lang}_artifacts@{timestamp}.log" 103 | logfile.parent.mkdir(exist_ok=True, parents=True) 104 | configure_logger(logfile=logfile, level=logging.INFO) 105 | logger = logging.getLogger() 106 | 107 | logger.info(f"Start preparing artifacts for {lang}") 108 | logger.info(f"num_samples: {num_samples}") 109 | logger.info(f"PYTHONHASHSEED: {os.environ.get('PYTHONHASHSEED')}") 110 | 111 | # download ccnet dataset 112 | ccnet = CCNetDownloader( 113 | lang=lang, artifacts_dir=artifacts_dir, cc_input=cc_input, 114 | cc_input_base_uri=cc_input_base_uri, num_samples=num_samples, 115 | max_workers=max_workers, endpoint_url=endpoint_url 116 | ) 117 | ccnet.run(logger=logger) 118 | 119 | # download wikipedia dataset 120 | wikipedia = WikipediaDownloader( 121 | lang=lang, out_dir=datasets_dir, 122 | overwrite=overwrite, cache_dir=cache_dir, 123 | max_samples=num_samples 124 | ) 125 | wikipedia.run(logger=logger) 126 | 127 | # download openwebtext dataset 128 | openwebtext = OpenWebTextDownloader( 129 | lang=lang, out_dir=datasets_dir, 130 | overwrite=overwrite, cache_dir=cache_dir, 131 | max_samples=num_samples 132 | ) 133 | openwebtext.run(logger=logger) 134 | 135 | # download books dataset 136 | books = BooksDownloader( 137 | lang=lang, out_dir=datasets_dir, 138 | overwrite=overwrite, cache_dir=cache_dir, 139 | max_samples=num_samples, 140 | max_paragraphs_per_sample=max_paragraphs_per_book_sample, 141 | max_samples_per_book=max_samples_per_book, 142 | ) 143 | books.run(logger=logger) 144 | 145 | # compute hash distributions 146 | hash_dist = HashDist( 147 | artifacts_dir=artifacts_dir, 148 | num_samples=num_samples, 149 | buckets=dsir_feature_dim, 150 | max_workers=max_workers, 151 | logger=logger 152 | ) 153 | 154 | # compute hash distribution for each dataset 155 | for obj in [wikipedia, openwebtext, books, ccnet]: 156 | fp = obj.filepath 157 | 158 | if fp is None: 159 | continue 160 | 161 | hash_dist.run(lang=lang, datafile=fp, dataset=obj.dataset_name) 162 | 163 | if lang == "en": 164 | # compute fasttext palm classifier 165 | target_name = "palm" 166 | target_data = [ 167 | wikipedia.filepath, books.filepath, openwebtext.filepath 168 | ] 169 | else: 170 | # for non english languages, we use wikipedia as target 171 | target_name = f"wikipedia" 172 | target_data = [wikipedia.filepath] 173 | 174 | trainer = FastTextTrainer( 175 | 
artifacts_dir=artifacts_dir, 176 | ccnet_data=ccnet.filepath, 177 | target_data=target_data, 178 | target_name=target_name, 179 | samples_per_class=classifiers_num_samples, 180 | lang=lang 181 | ) 182 | trainer.run(logger=logger) 183 | 184 | logger.info(f"Finished preparing artifacts for {lang}") 185 | 186 | 187 | if __name__ == '__main__': 188 | args = parse_arguments() 189 | main(artifacts_dir=args.artifacts_dir, 190 | cc_input=args.cc_input, 191 | cc_input_base_uri=args.cc_input_base_uri, 192 | cache_dir=args.cache_dir, 193 | overwrite=args.overwrite, 194 | lang=args.lang, 195 | max_workers=args.max_workers, 196 | endpoint_url=args.endpoint_url, 197 | dsir_num_samples=args.dsir_num_samples, 198 | dsir_feature_dim=args.dsir_feature_dim, 199 | classifiers_num_samples=args.classifiers_num_samples, 200 | max_samples_per_book=args.max_samples_per_book, 201 | max_paragraphs_per_book_sample=args.max_paragraphs_per_book_sample 202 | ) 203 | -------------------------------------------------------------------------------- /app/src/run_lsh.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from datetime import datetime as dt 3 | import gc 4 | import logging 5 | import networkit.components as nk_components 6 | import networkit.graph as nk_graph 7 | import numpy as np 8 | import os 9 | from pathlib import Path 10 | import polars as pl 11 | import pyarrow as pa 12 | import pyarrow.dataset as ds 13 | import s3fs 14 | import time 15 | from typing import Dict, Tuple, List 16 | from urllib.parse import urlparse 17 | 18 | from dedupe.utils import optimal_param 19 | 20 | LOG_FMT = '[%(asctime)s]::(PID %(process)d)::%(levelname)-2s::%(message)s' 21 | 22 | 23 | class LSH: 24 | r""" Locality Sensitive Hashing (LSH) algorithm for near deduplication. 
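
    Documents are clustered as follows: each document's minhash signature is
    split into b bands of r rows, where (b, r) is chosen by ``optimal_param``
    for the requested similarity threshold and number of permutations. Any
    two documents that share an identical band become an edge, and the
    connected components over these edges are the duplicate clusters that
    get written out per shard. Under the standard LSH analysis, a pair with
    Jaccard similarity s becomes a candidate with probability
    1 - (1 - s^r)^b, which is the trade-off ``optimal_param`` balances
    between false positives and false negatives.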
""" 25 | __slots__ = ( 26 | "_args", "_num_bands", "_job_id", "_logger", "_sig_key", "_schema" 27 | ) 28 | 29 | # regex to extract filepaths from source file listings 30 | fp_pattern = r'\/(\d{4}-\d{2}\/\d{4}\/.*\.json\.gz)$' 31 | 32 | # signature key 33 | sig_key_pat = "signature_sim{s}" 34 | 35 | def __init__(self): 36 | self._job_id = dt.now().strftime("%Y%m%d_%H%M%S") 37 | self._args = self.__parse_arguments() 38 | 39 | self._sig_key = self.sig_key_pat.format( 40 | s=str(self._args.similarity) 41 | ) 42 | 43 | # get number of bands and rows 44 | self._num_bands, _ = optimal_param( 45 | threshold=self._args.similarity, num_perm=self._args.num_perm 46 | ) 47 | 48 | # init schema 49 | self._schema = self.__init_schema() 50 | 51 | # init logging 52 | self.__init_logger() 53 | 54 | # log setup 55 | self._logger.info("=" * 80) 56 | self._logger.info("LSH config:") 57 | for k, v in vars(self._args).items(): 58 | self._logger.info(f"{k}: {v}") 59 | self._logger.info("=" * 80) 60 | 61 | def __init_schema(self) -> pa.Schema: 62 | return pa.schema([ 63 | ("id", pa.string()), 64 | ("shard_id", pa.string()), 65 | ("id_int", pa.uint64()), 66 | (self._sig_key, pa.list_(pa.binary())) 67 | ]) 68 | 69 | def __init_logger(self): 70 | self._logger = logging.getLogger(self._job_id) 71 | self._logger.setLevel(logging.DEBUG) 72 | 73 | # log to file 74 | logfile = Path(self._args.output_dir) / "logs" / f"{self._job_id}.log" 75 | if not logfile.parent.exists(): 76 | logfile.parent.mkdir(parents=True) 77 | filehandler = logging.FileHandler(logfile) 78 | filehandler.setFormatter(logging.Formatter(LOG_FMT)) 79 | self._logger.addHandler(filehandler) 80 | 81 | # log to stdout 82 | stream_handler = logging.StreamHandler() 83 | stream_handler.setFormatter(logging.Formatter(LOG_FMT)) 84 | self._logger.addHandler(stream_handler) 85 | 86 | def __parse_arguments(self) -> argparse.Namespace: 87 | 88 | if self.__doc__ is not None: 89 | description = " - " + self.__doc__ 90 | else: 91 | description = self.__class__.__name__ 92 | 93 | parser = argparse.ArgumentParser( 94 | prog=self.__class__.__name__, description=description 95 | ) 96 | parser.add_argument( 97 | "--listings", type=str, default=None, 98 | help="file containing paths to minhash parquet files. LSH will be" 99 | "run on the minhashes stored in these files." 100 | ) 101 | parser.add_argument( 102 | "--input_base_uri", type=str, default=None, 103 | help="base uri of the input files." 104 | ) 105 | parser.add_argument( 106 | "--output_dir", type=str, default=None, 107 | help="root directory where the output will be stored." 108 | ) 109 | parser.add_argument( 110 | "--similarity", type=float, default=None, 111 | help="similarity threshold for two documents to be considered near" 112 | " duplicates." 113 | ) 114 | parser.add_argument( 115 | "--num_perm", type=int, default=None, 116 | help="number of permutations used during minhashing." 117 | ) 118 | parser.add_argument( 119 | "--max_docs", type=int, default=-1, 120 | help="maximum number of documents to process. If set to -1, all " 121 | "documents will be processed." 122 | ) 123 | 124 | # s3 125 | parser.add_argument( 126 | "--s3_profile", type=str, default="default", 127 | help="aws profile to use when connecting to s3." 128 | ) 129 | parser.add_argument( 130 | "--endpoint_url", type=str, default=None, 131 | help="endpoint url of the s3 server." 
132 | ) 133 | 134 | args = parser.parse_args() 135 | 136 | return args 137 | 138 | def __build_dataset(self) -> pa.dataset.Dataset: 139 | base_uri = urlparse(self._args.input_base_uri) 140 | 141 | if base_uri.scheme == "file": 142 | return self.__buil_dataset_local(base_uri) 143 | elif base_uri.scheme == "s3": 144 | return self.__build_dataset_s3() 145 | else: 146 | raise ValueError(f"Invalid base uri: {base_uri}") 147 | 148 | def __buil_dataset_local(self, base_uri) -> pa.dataset.Dataset: 149 | root_path = Path(base_uri.path) 150 | 151 | # 1) get paths and build pyarrow dataset 152 | with open(self._args.listings, "r") as f: 153 | input_paths = [ 154 | root_path / Path(line.strip()) for line in f.readlines() 155 | ] 156 | 157 | return ds.dataset( 158 | source=input_paths, schema=self._schema, format="parquet" 159 | ) 160 | 161 | def __build_dataset_s3(self) -> pa.dataset.Dataset: 162 | fs = s3fs.S3FileSystem( 163 | profile=self._args.s3_profile, 164 | endpoint_url=self._args.endpoint_url 165 | ) 166 | 167 | # 1) get paths and build pyarrow dataset 168 | with open(self._args.listings, "r") as f: 169 | input_paths = list(map( 170 | lambda ln: os.path.join(self._args.input_base_uri, ln.strip()), 171 | f.readlines() 172 | )) 173 | 174 | return ds.dataset( 175 | source=input_paths, filesystem=fs, schema=self._schema, 176 | format="parquet" 177 | ) 178 | 179 | def run(self): 180 | global_start_time = time.time() 181 | 182 | # 1) build pyarrow dataset; this is a lazy operation pointing to a 183 | # collection of parquet files on disk or in an S3 bucket 184 | pa_dset = self.__build_dataset() 185 | 186 | # 2) build edges 187 | step_time = time.time() 188 | self._logger.info("Start building edges") 189 | edges = self.__build_edges(pa_dset=pa_dset) 190 | step_time = time.time() - step_time 191 | self._logger.info( 192 | f"Building edges complete. Shape={edges.shape}; Time={step_time}s" 193 | ) 194 | 195 | # 3) detect components 196 | step_time = time.time() 197 | self._logger.info("Start detecting components") 198 | ( 199 | components, num_nodes, reversed_mapper 200 | ) = self.__run_connected_components(edges=edges) 201 | step_time = time.time() - step_time 202 | self._logger.info( 203 | f"Connected compontents complete. Time={step_time}s" 204 | ) 205 | 206 | del edges 207 | gc.collect() 208 | 209 | # 4) collect cluster ids 210 | step_time = time.time() 211 | self._logger.info("Start collecting cluster ids") 212 | cluster_ids = self.__get_doc_to_cluster_array( 213 | components=components, reversed_mapper=reversed_mapper 214 | ) 215 | step_time = time.time() - step_time 216 | self._logger.info(f"Building doc->cluster index complete. " 217 | f"Time={step_time}s") 218 | 219 | # 5) build cluster dataframes 220 | step_time = time.time() 221 | self._logger.info("Start building final cluster dataframes") 222 | cluster_dataframes = self.__build_cluster_dataframes( 223 | pa_dset=pa_dset, doc_to_cluster=cluster_ids 224 | ) 225 | step_time = time.time() - step_time 226 | self._logger.info(f"Building final cluster dataframes complete. 
" 227 | f"Time={step_time}s") 228 | 229 | # 6) write cluster dataframes to disk 230 | out_root = Path(self._args.output_dir) 231 | for k, v in cluster_dataframes.items(): 232 | 233 | tag = Path(k.split(".")[0]).with_suffix(".clusters.parquet") 234 | if not (out_root / tag).parent.exists(): 235 | (out_root / tag).parent.mkdir(parents=True) 236 | 237 | # write to disk 238 | v.write_parquet(out_root / tag) 239 | self._logger.info(f"Wrote cluster data to {out_root / tag}") 240 | 241 | elapsed_time = time.time() - global_start_time 242 | self._logger.info(f"LSH complete. Total time: {elapsed_time}s") 243 | 244 | def __build_edges(self, pa_dset: pa.dataset.Dataset) -> np.ndarray: 245 | 246 | # build polars query plan 247 | query = pl.scan_pyarrow_dataset(pa_dset) 248 | 249 | if self._args.max_docs > 0: 250 | query = query.head(self._args.max_docs) 251 | 252 | query = ( 253 | query 254 | .select( 255 | pl.col(["id_int", self._sig_key]) 256 | ) 257 | .filter( 258 | ~pl.col(self._sig_key).is_null() 259 | ) 260 | .with_columns( 261 | pl.Series( 262 | name="band", 263 | values=[list(range(self._num_bands))], 264 | dtype=pl.List(pl.UInt8) 265 | ) 266 | ) 267 | .explode(self._sig_key, "band") 268 | .group_by(self._sig_key, "band") 269 | .agg(pl.col("id_int")) 270 | .filter( 271 | pl.col("id_int").list.lengths() > 1 272 | ) 273 | .select( 274 | pl.col("id_int"), 275 | pl.col("id_int").list.min().alias("min_node") 276 | ) 277 | .explode("id_int") 278 | .filter( 279 | pl.col("id_int") != pl.col("min_node") 280 | ) 281 | .select( 282 | pl.concat_list(["id_int", "min_node"]).alias("edges") 283 | ) 284 | .unique("edges") 285 | ) 286 | 287 | self._logger.debug(f"Query Plan:\n{query.explain()}") 288 | self._logger.debug(f"Start running query...") 289 | edges = query.collect(streaming=True).to_numpy().flatten() 290 | self._logger.debug(f"Completed running query.") 291 | gc.collect() 292 | 293 | return edges 294 | 295 | @staticmethod 296 | def __run_connected_components( 297 | edges: np.ndarray 298 | ) -> Tuple[List[List[int]], int, Dict[int, int]]: 299 | # build graph from edges 300 | graph = nk_graph.Graph() 301 | node_mapper = {} 302 | 303 | for row in edges: 304 | node_id1, node_id2 = row 305 | 306 | if node_id1 not in node_mapper: 307 | node_mapper[node_id1] = graph.addNode() 308 | 309 | if node_id2 not in node_mapper: 310 | node_mapper[node_id2] = graph.addNode() 311 | 312 | graph.addEdge(node_mapper[node_id1], node_mapper[node_id2]) 313 | 314 | reversed_mapper = {value: key for key, value in node_mapper.items()} 315 | 316 | # compute connected components 317 | cc = nk_components.ConnectedComponents(G=graph) 318 | cc.run() 319 | components = cc.getComponents() 320 | num_nodes = sum(cc.getComponentSizes().values()) 321 | 322 | return components, num_nodes, reversed_mapper 323 | 324 | @staticmethod 325 | def __get_doc_to_cluster_array( 326 | components: List[List[int]], reversed_mapper: Dict[int, int] 327 | ) -> np.ndarray: 328 | def __process_comp(comp) -> np.ndarray: 329 | nodes = np.array( 330 | list(map(reversed_mapper.get, comp)) 331 | ).reshape(-1, 1) 332 | cluster_id = min(map(reversed_mapper.get, comp)) 333 | cluster_id = np.repeat(cluster_id, len(nodes)).reshape(-1, 1) 334 | return np.hstack((nodes, cluster_id)) 335 | 336 | data = np.vstack(tuple(map(__process_comp, components))) 337 | 338 | return data 339 | 340 | def __build_cluster_dataframes( 341 | self, pa_dset: pa.dataset.Dataset, doc_to_cluster: np.ndarray 342 | ) -> Dict[str, pl.DataFrame]: 343 | cluster_df = pl.LazyFrame( 344 | 
data=doc_to_cluster, 345 | schema=[("id_int", pl.UInt64), ("cluster_id", pl.UInt64)] 346 | ) 347 | 348 | # build polars query plan 349 | query = pl.scan_pyarrow_dataset(pa_dset) 350 | 351 | if self._args.max_docs > 0: 352 | query = query.head(self._args.max_docs) 353 | 354 | partitioned_dfs = ( 355 | query 356 | .select(pl.col(["id", "id_int", "shard_id"])) 357 | .join(other=cluster_df, on="id_int", how="inner") 358 | .select(pl.col(["id", "id_int", "cluster_id", "shard_id"])) 359 | .collect() 360 | ) 361 | 362 | with pl.Config(set_fmt_str_lengths=5000, tbl_rows=20): 363 | self._logger.info( 364 | f"First 20 rows of minhash clusters:\n\n" 365 | f"{partitioned_dfs.sort(by='cluster_id').head(20)}" 366 | ) 367 | time.sleep(2) 368 | 369 | partitioned_dfs = partitioned_dfs.partition_by(by="shard_id", 370 | as_dict=True) 371 | 372 | return partitioned_dfs 373 | 374 | 375 | if __name__ == '__main__': 376 | job = LSH() 377 | job.run() 378 | -------------------------------------------------------------------------------- /app/src/token_count.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import boto3 3 | import botocore.client 4 | import concurrent.futures 5 | from dataclasses import dataclass 6 | from datetime import datetime as dt 7 | import gzip 8 | import io 9 | import logging 10 | import msgspec 11 | import os 12 | from pathlib import Path 13 | import progiter 14 | import pyarrow as pa 15 | import random 16 | import re 17 | from typing import Tuple, List 18 | from tokenizers import Tokenizer 19 | from urllib.parse import urlparse, ParseResult 20 | 21 | from utilities.logging import configure_logger 22 | from utilities.io.writer import ParquetBatchWriter 23 | 24 | 25 | class InputSpec(msgspec.Struct): 26 | raw_content: str 27 | 28 | 29 | @dataclass 30 | class DlStatus: 31 | is_success: bool 32 | msg: str 33 | uri: str 34 | 35 | 36 | @dataclass 37 | class InputResult: 38 | is_success: bool 39 | msg: str 40 | input_id: str 41 | num_docs: int = 0 42 | num_tokens: int = 0 43 | token_counts: List[Tuple[int, int]] = None 44 | 45 | 46 | TOKENIZER = Tokenizer.from_pretrained("mistralai/Mistral-7B-v0.1") 47 | 48 | 49 | class PostProcessor: 50 | listings_re = re.compile( 51 | r".*(\d{4}-\d{2}/\d{4}/(?:en|es|de|fr|it)_(?:tail|middle|head)).json.gz" 52 | ) 53 | 54 | def __parse_arguments(self) -> argparse.Namespace: 55 | 56 | if self.__doc__ is not None: 57 | description = " - " + self.__doc__ 58 | else: 59 | description = self.__class__.__name__ 60 | 61 | parser = argparse.ArgumentParser( 62 | prog=self.__class__.__name__, description=description 63 | ) 64 | 65 | # io 66 | parser.add_argument( 67 | "--snapshots", nargs="+", type=str, default=None, 68 | ) 69 | parser.add_argument( 70 | "--input_base_uri", type=str, default=None, 71 | help="base uri of the input files." 72 | ) 73 | parser.add_argument( 74 | "--logs_dir", type=str, default=None, 75 | help="directory to store logs." 76 | ) 77 | 78 | parser.add_argument( 79 | "--s3_profile", type=str, default="default", 80 | help="profile name of the s3 client." 81 | ) 82 | parser.add_argument( 83 | "--endpoint_url", type=str, default=None, 84 | help="S3 bucket endpoint url." 85 | ) 86 | parser.add_argument( 87 | "--parallelism", type=int, default=1, 88 | help="number of parallel processes. Defaults to 1." 89 | ) 90 | 91 | parser.add_argument( 92 | "--batch_size", type=int, default=1, 93 | help="batch size. Defaults to 1." 
94 | ) 95 | parser.add_argument( 96 | "--max_inputs", type=int, default=4, 97 | help="maximum number of inputs to process. For debugging." 98 | ) 99 | 100 | parser.add_argument( 101 | "--debug", default=0, choices=[0, 1], type=int, 102 | help="runs in debug mode if set to 1." 103 | ) 104 | parser.add_argument( 105 | "--input_listings", type=str, default="listings.txt", 106 | help="path to file containing input ids." 107 | ) 108 | parser.add_argument( 109 | "--seed", type=int, default=42 110 | ) 111 | 112 | args = parser.parse_args() 113 | 114 | return args 115 | 116 | def __init__(self): 117 | self._job_id = dt.now().strftime("%Y%m%d_%H%M%S") 118 | self._args = self.__parse_arguments() 119 | 120 | random.seed(self._args.seed) 121 | 122 | # i/o 123 | self._input_base_uri = self._args.input_base_uri 124 | self._logs_dir = self._args.logs_dir 125 | 126 | def __init_client(self): 127 | session = boto3.Session(profile_name=self._args.s3_profile) 128 | client = session.client( 129 | service_name='s3', 130 | endpoint_url=self._args.endpoint_url, 131 | config=boto3.session.Config( 132 | signature_version='s3v4', 133 | retries={'max_attempts': 10, 'mode': 'standard'} 134 | ) 135 | ) 136 | return session, client 137 | 138 | @staticmethod 139 | def _dload_file(uri: ParseResult, client) -> Tuple[DlStatus, io.BytesIO]: 140 | try: 141 | streaming_body = client.get_object( 142 | Bucket=uri.netloc, Key=uri.path.lstrip("/") 143 | )["Body"] 144 | buffer = io.BytesIO(streaming_body.read()) 145 | msg = f"__S3_URI_READ_SUCCESS__ success reading {uri.path}" 146 | is_success = True 147 | except Exception as e: 148 | msg = ( 149 | f"__S3_URI_READ_ERROR__ failed reading {uri.path}: " 150 | f"caught exception {e.__class__.__name__}: {e}" 151 | ) 152 | buffer = None 153 | is_success = False 154 | 155 | read_status = DlStatus(is_success=is_success, msg=msg, uri=str(uri)) 156 | return read_status, buffer 157 | 158 | def __load_input_ids( 159 | self, snapshot: str 160 | ) -> List[str]: 161 | 162 | assert self._args.input_listings is not None 163 | 164 | input_ids = [] 165 | with open(self._args.input_listings, "r") as fin: 166 | for ln in fin.readlines(): 167 | try: 168 | ln = self.listings_re.findall(ln.strip())[0] 169 | except IndexError: 170 | continue 171 | if f"{snapshot}/" not in ln: 172 | continue 173 | input_ids.append(ln) 174 | 175 | return input_ids 176 | 177 | def _process_listings(self, input_ids: List[str]) -> List[InputResult]: 178 | sess, client = self.__init_client() 179 | 180 | # decoding and encoding 181 | decoder = msgspec.json.Decoder(type=InputSpec) 182 | 183 | results = [] 184 | for input_id in input_ids: 185 | proc_res: InputResult = self._process_single_listing( 186 | client, input_id, decoder 187 | ) 188 | results.append(proc_res) 189 | 190 | return results 191 | 192 | def _process_single_listing( 193 | self, client, input_id, decoder 194 | ) -> InputResult: 195 | # handle signals 196 | result: InputResult = self._handle_documents( 197 | client, input_id, decoder 198 | ) 199 | if not result.is_success: 200 | result.msg = f"__FAIL__ {input_id} ({result.msg})" 201 | return result 202 | 203 | result.msg = f"__SUCCESS__ {input_id}" 204 | 205 | return result 206 | 207 | def _handle_documents( 208 | self, 209 | client: botocore.client.BaseClient, 210 | input_id: str, 211 | decoder 212 | ) -> InputResult: 213 | # download doc 214 | input_uri = urlparse( 215 | os.path.join( 216 | self._input_base_uri, f"{input_id}.json.gz" 217 | ) 218 | ) 219 | dl_status, input_buffer = self._dload_file(input_uri, 
client=client) 220 | 221 | # check if download was successful 222 | if not dl_status.is_success: 223 | return InputResult( 224 | is_success=False, msg=dl_status.msg, input_id=input_id 225 | ) 226 | 227 | num_docs = 0 228 | total_tokens = 0 229 | token_counts = [] 230 | 231 | try: 232 | with gzip.open(input_buffer, mode="rb") as in_fh: 233 | for idx, obj in enumerate(in_fh): 234 | record = decoder.decode(obj) 235 | 236 | # tokenize 237 | num_tokens = len( 238 | TOKENIZER.encode(record.raw_content).tokens 239 | ) 240 | token_counts.append((idx, num_tokens)) 241 | 242 | total_tokens += num_tokens 243 | num_docs += 1 244 | 245 | except Exception as e: 246 | msg = ( 247 | f"__DECODE_ENCODE_FAIL__ {input_id}: " 248 | f"caught exception {e.__class__.__name__}: {e}" 249 | ) 250 | return InputResult(is_success=False, msg=msg, input_id=input_id) 251 | 252 | return InputResult( 253 | is_success=True, 254 | msg="", 255 | input_id=input_id, 256 | num_docs=num_docs, 257 | num_tokens=total_tokens, 258 | token_counts=token_counts 259 | ) 260 | 261 | def run(self): 262 | # init logging 263 | logfile = Path(self._logs_dir) / f"{self._job_id}.log" 264 | configure_logger(logfile=logfile, level=logging.INFO, stream=False) 265 | logger = logging.getLogger() 266 | 267 | # log configs 268 | for attr in ( 269 | "snapshots", "input_base_uri", "batch_size", 270 | "parallelism", "max_inputs", "debug", "input_listings", "seed" 271 | ): 272 | logger.info(f"__CONFIG__ {attr}: {getattr(self._args, attr)}") 273 | 274 | for snapshot in self._args.snapshots: 275 | logger.info(f"__START_SNAPSHOT__ {snapshot}") 276 | try: 277 | self.run_snapshot(snapshot, logger=logger) 278 | except KeyboardInterrupt: 279 | break 280 | logger.info(f"__END_SNAPSHOT__ {snapshot}") 281 | 282 | def run_snapshot(self, snapshot_id, logger): 283 | # load input file ids 284 | input_ids = self.__load_input_ids(snapshot_id) 285 | msg = ( 286 | f"__INPUT_LISTINGS_LOADED__ " 287 | f"found {len(input_ids)} input files in {snapshot_id}" 288 | ) 289 | logger.info(msg) 290 | random.shuffle(input_ids) 291 | 292 | if self._args.max_inputs is not None: 293 | input_ids = input_ids[:self._args.max_inputs] 294 | 295 | input_ids_batches = [ 296 | input_ids[i:i + self._args.batch_size] 297 | for i in range(0, len(input_ids), self._args.batch_size) 298 | ] 299 | 300 | # init output writer 301 | out_fp = Path(self._logs_dir) / f"{snapshot_id}_counts.parquet" 302 | out_schema = pa.schema([ 303 | ("input_id", pa.string()), 304 | ("doc_id", pa.string()), 305 | ("snapshot_id", pa.string()), 306 | ("num_tokens", pa.int64()) 307 | ]) 308 | 309 | pq_writer = ParquetBatchWriter(output_fp=out_fp, schema=out_schema) 310 | 311 | if self._args.debug: 312 | self.__debug_run( 313 | input_ids_batches, logger=logger, snapshot_id=snapshot_id, 314 | pq_writer=pq_writer 315 | ) 316 | else: 317 | self.__parallel_run( 318 | input_ids_batches, logger=logger, snapshot_id=snapshot_id, 319 | pq_writer=pq_writer 320 | ) 321 | 322 | pq_writer.close() 323 | 324 | def __debug_run( 325 | self, 326 | input_ids_batches: List[List[str]], 327 | logger: logging.Logger, 328 | snapshot_id: str, 329 | pq_writer: ParquetBatchWriter 330 | ): 331 | num_docs = 0 332 | num_succ = 0 333 | num_fail = 0 334 | total_tokens = 0 335 | 336 | # progress bar 337 | total_inputs = sum(map(len, input_ids_batches)) 338 | pman = progiter.ProgressManager(backend="rich") 339 | pbar = pman.progiter( 340 | total=total_inputs, 341 | desc=f"Processing {snapshot_id}", 342 | backend="rich" 343 | ) 344 | 345 | for batch in 
input_ids_batches: 346 | inputs_results: List[InputResult] = self._process_listings(batch) 347 | 348 | for proc_res in inputs_results: 349 | if proc_res.is_success: 350 | num_succ += 1 351 | num_docs += proc_res.num_docs 352 | total_tokens += proc_res.num_tokens 353 | else: 354 | num_fail += 1 355 | 356 | logger.info(proc_res.msg) 357 | 358 | pbar.step(1) 359 | pbar.set_postfix_str( 360 | f"total_inputs: {num_succ:,} ({num_fail:,} fail); " 361 | f"num_docs: {num_docs:,} -- " 362 | f"num_tokens: {total_tokens:,}" 363 | ) 364 | 365 | if not proc_res.is_success: 366 | continue 367 | 368 | for idx, num_tokens in proc_res.token_counts: 369 | pq_writer.update_batch({ 370 | "input_id": proc_res.input_id, 371 | "doc_id": f"{proc_res.input_id}.json.gz/{idx}", 372 | "snapshot_id": snapshot_id, 373 | "num_tokens": num_tokens, 374 | }) 375 | 376 | pq_writer.write_batch() 377 | 378 | pman.stop() 379 | 380 | # log summary 381 | logger.info( 382 | f"__PROCESSING_COMPLETE__\n*******************\n" 383 | f"num_inputs_success: {num_succ:,}\n" 384 | f"num_inputs_failed: {num_fail:,}\n" 385 | f"num_docs: {num_docs:,}\n" 386 | f"num_tokens: {total_tokens:,}" 387 | ) 388 | 389 | def __parallel_run( 390 | self, 391 | input_ids_batches: List[List[str]], 392 | logger: logging.Logger, 393 | snapshot_id: str, 394 | pq_writer: ParquetBatchWriter 395 | ): 396 | num_docs = 0 397 | num_succ = 0 398 | num_fail = 0 399 | total_tokens = 0 400 | 401 | # progress bar 402 | total_inputs = sum(map(len, input_ids_batches)) 403 | pman = progiter.ProgressManager(backend="rich") 404 | pbar = pman.progiter( 405 | total=total_inputs, 406 | desc=f"Processing {snapshot_id}", 407 | backend="rich" 408 | ) 409 | 410 | # process listings 411 | try: 412 | with concurrent.futures.ProcessPoolExecutor( 413 | max_workers=self._args.parallelism 414 | ) as executor: 415 | futures = { 416 | executor.submit( 417 | self._process_listings, 418 | input_ids=batch, 419 | ): batch 420 | for batch in input_ids_batches 421 | } 422 | 423 | for future in concurrent.futures.as_completed(futures): 424 | proc_results: List[InputResult] = future.result() 425 | del futures[future] 426 | 427 | for proc_res in proc_results: 428 | if proc_res.is_success: 429 | num_succ += 1 430 | num_docs += proc_res.num_docs 431 | total_tokens += proc_res.num_tokens 432 | else: 433 | num_fail += 1 434 | 435 | logger.info(proc_res.msg) 436 | 437 | pbar.step(1) 438 | pbar.set_postfix_str( 439 | f"total_inputs: {num_succ:,} ({num_fail:,} fail); " 440 | f"num_docs: {num_docs:,} -- " 441 | f"num_tokens: {total_tokens:,}" 442 | ) 443 | 444 | if not proc_res.is_success: 445 | continue 446 | 447 | for idx, num_tokens in proc_res.token_counts: 448 | pq_writer.update_batch({ 449 | "input_id": proc_res.input_id, 450 | "doc_id": f"{proc_res.input_id}.json.gz/{idx}", 451 | "snapshot_id": snapshot_id, 452 | "num_tokens": num_tokens, 453 | }) 454 | 455 | pq_writer.write_batch() 456 | 457 | except KeyboardInterrupt: 458 | logger.info("KeyboardInterrupt caught. 
Terminating...") 459 | pman.stop() 460 | executor.shutdown(wait=False, cancel_futures=True) 461 | pq_writer.close() 462 | raise KeyboardInterrupt 463 | 464 | pman.stop() 465 | 466 | # log summary 467 | logger.info( 468 | f"__PROCESSING_COMPLETE__\n*******************\n" 469 | f"num_inputs_success: {num_succ:,}\n" 470 | f"num_inputs_failed: {num_fail:,}\n" 471 | f"num_docs: {num_docs:,}\n" 472 | f"num_tokens: {total_tokens:,}" 473 | ) 474 | 475 | 476 | if __name__ == '__main__': 477 | pp = PostProcessor() 478 | pp.run() 479 | -------------------------------------------------------------------------------- /app/src/utilities/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/togethercomputer/RedPajama-Data/6d2cee9df2b0204dd2bcb00bf06b5a7b1d7432d7/app/src/utilities/__init__.py -------------------------------------------------------------------------------- /app/src/utilities/io/__init__.py: -------------------------------------------------------------------------------- 1 | from .writer import Writer, ParquetBatchWriter 2 | from .reader import Reader 3 | -------------------------------------------------------------------------------- /app/src/utilities/io/reader.py: -------------------------------------------------------------------------------- 1 | import io 2 | import msgspec 3 | import boto3 4 | import gzip 5 | from urllib.parse import urlparse, ParseResult 6 | import pathlib 7 | from typing import Optional, Type, List, Tuple 8 | import xopen 9 | 10 | from core.data_types import InputSpec 11 | from core.exceptions import * 12 | 13 | 14 | class Reader: 15 | r""" Read plain jsonl, jsonl.zst and jsonl.gz files using msgspec """ 16 | 17 | def __init__( 18 | self, 19 | schema: List[Tuple[str, type]] = None, 20 | input_spec: Type[msgspec.Struct] = InputSpec, 21 | s3_client: Optional[boto3.client] = None, 22 | threads: int = 1, 23 | logger=None 24 | ): 25 | self._client = s3_client 26 | self._threads = threads 27 | 28 | if logger is None: 29 | self._print = print 30 | else: 31 | self._print = logger.error 32 | 33 | # msgspec decoder 34 | if schema is not None: 35 | input_type = msgspec.defstruct(name="Record", fields=schema) 36 | else: 37 | input_type = input_spec 38 | 39 | self._obj_decoder = msgspec.json.Decoder(type=input_type) 40 | 41 | self._total_consumed = 0 42 | 43 | def read(self, uri: str, max_samples: Optional[int] = -1, 44 | return_idx: bool = True): 45 | n_samples = 0 46 | 47 | try: 48 | with self.__get_filehandle(uri) as fh: 49 | for idx, obj in enumerate(fh): 50 | try: 51 | record = self._obj_decoder.decode(obj) 52 | if return_idx: 53 | yield idx, record 54 | else: 55 | yield record 56 | except Exception as e: 57 | self._print(f"__SAMPLE_READ_ERROR__ {uri}/{idx}: " 58 | f"{e.__class__.__name__}: {e}") 59 | continue 60 | 61 | n_samples += 1 62 | 63 | if n_samples >= max_samples > 0: 64 | break 65 | except S3ReadError as e: 66 | raise e 67 | except LocalReadError: 68 | raise e 69 | except Exception as e: 70 | raise UnknownReadError(f"unknown __URI_READ_ERROR__ {uri}: " 71 | f"{e.__class__.__name__}: {e}") 72 | 73 | def __get_filehandle(self, uri: str): 74 | uri = urlparse(uri) 75 | 76 | if uri.scheme == "s3": 77 | return self.__get_s3_filehandle(uri) 78 | 79 | if uri.scheme == "file": 80 | return self.__get_local_filehandle(uri) 81 | 82 | raise ValueError(f"Invalid uri: {uri}; must be of the form " 83 | f"s3:/// or file://") 84 | 85 | def __get_s3_filehandle(self, uri: ParseResult): 86 | assert self._client 
is not None, "S3 client not initialized" 87 | 88 | try: 89 | streaming_body = self._client.get_object( 90 | Bucket=uri.netloc, Key=uri.path.lstrip("/") 91 | )["Body"] 92 | buffer = io.BytesIO(streaming_body.read()) 93 | except Exception as e: 94 | raise S3ReadError( 95 | f"__S3_URI_READ_ERROR__ failed reading {uri.path}: " 96 | f"caught exception {e.__class__.__name__}: {e}" 97 | ) 98 | 99 | return gzip.open(buffer, mode="rb") 100 | 101 | def __get_local_filehandle(self, uri: ParseResult): 102 | fp = pathlib.Path(uri.path) 103 | 104 | try: 105 | if fp.suffix == ".gz": 106 | return xopen.xopen(fp, mode="rb", threads=self._threads) 107 | 108 | if fp.suffix == ".jsonl": 109 | return open(fp, mode="rb") 110 | except Exception as e: 111 | raise LocalReadError( 112 | f"__LOCAL_URI_READ_ERROR__ failed reading {uri.path}: " 113 | f"caught exception {e.__class__.__name__}: {e}" 114 | ) 115 | 116 | raise ValueError(f"File type of {fp} not supported.") 117 | -------------------------------------------------------------------------------- /app/src/utilities/io/s3.py: -------------------------------------------------------------------------------- 1 | import boto3 2 | from botocore import UNSIGNED 3 | 4 | 5 | def init_client( 6 | endpoint_url: str, 7 | aws_access_key_id: str = None, 8 | aws_secret_access_key: str = None, 9 | signature_version: str = UNSIGNED 10 | ): 11 | return boto3.client( 12 | service_name='s3', 13 | aws_access_key_id=aws_access_key_id, 14 | aws_secret_access_key=aws_secret_access_key, 15 | endpoint_url=endpoint_url, 16 | config=boto3.session.Config( 17 | signature_version=signature_version, 18 | retries={ 19 | 'max_attempts': 5, # this is the default in standard mode 20 | 'mode': 'standard' 21 | } 22 | ) 23 | ) 24 | -------------------------------------------------------------------------------- /app/src/utilities/io/writer.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import msgspec 3 | import gzip 4 | import pyarrow as pa 5 | import pyarrow.parquet as pq 6 | from urllib.parse import urlparse 7 | import boto3 8 | 9 | from typing import Type, Any, Dict, List, Tuple, Optional 10 | 11 | 12 | class Writer: 13 | def __init__( 14 | self, 15 | uri: str, 16 | schema: List[Tuple[str, type]], 17 | s3_client: Optional[boto3.client] = None 18 | ): 19 | self._client = s3_client 20 | uri = urlparse(uri) 21 | 22 | if uri.scheme == "s3": 23 | raise NotImplementedError("streaming to S3 not supported yet") 24 | 25 | elif uri.scheme == "file": 26 | fp = pathlib.Path(uri.path) 27 | 28 | if not fp.parent.exists(): 29 | fp.parent.mkdir(parents=True, exist_ok=True) 30 | 31 | if fp.suffix == ".gz": 32 | self._filehandle = gzip.open(fp, mode="wb") 33 | elif fp.suffix == ".jsonl": 34 | self._filehandle = open(fp, mode="wb") 35 | else: 36 | raise ValueError(f"File type of {fp} not supported.") 37 | else: 38 | raise ValueError(f"Invalid uri: {uri}; must be of the form " 39 | f"s3:/// or file://") 40 | 41 | # encode records using msgspec 42 | self._encoder = msgspec.json.Encoder() 43 | self._buffer = bytearray(64) 44 | 45 | # define record struct 46 | self._record: Type[msgspec.Struct] = msgspec.defstruct( 47 | name="Record", fields=schema 48 | ) 49 | 50 | def write(self, data_obj: Dict[str, Any], flush: bool = False): 51 | self._encoder.encode_into(self._record(**data_obj), self._buffer) 52 | self._buffer.extend(b"\n") 53 | self._filehandle.write(self._buffer) 54 | 55 | if flush: 56 | self.flush() 57 | 58 | def close(self): 59 | self.flush() 60 | 
self._filehandle.close() 61 | 62 | def flush(self): 63 | self._filehandle.flush() 64 | self._buffer.clear() 65 | 66 | 67 | class ParquetBatchWriter: 68 | 69 | def __init__(self, output_fp, schema: pa.Schema): 70 | self._schema = schema 71 | self._writer = pq.ParquetWriter(output_fp, self._schema) 72 | self.__init_batch() 73 | 74 | def close(self): 75 | if len(self._batch[self._schema.names[0]]) > 0: 76 | self.write_batch() 77 | self._writer.close() 78 | 79 | def update_batch(self, obj: Dict[str, Any]): 80 | for col in self._schema.names: 81 | self._batch[col].append(obj[col]) 82 | 83 | def write_batch(self): 84 | self._writer.write_batch(batch=pa.record_batch( 85 | data=[ 86 | pa.array(self._batch[field.name], type=field.type) 87 | for field in self._schema 88 | ], 89 | schema=self._schema 90 | )) 91 | self.__init_batch() 92 | 93 | def __init_batch(self): 94 | self._batch = {col: [] for col in self._schema.names} 95 | -------------------------------------------------------------------------------- /app/src/utilities/logging/__init__.py: -------------------------------------------------------------------------------- 1 | from .format import LOG_FMT 2 | from .configure import configure_logger 3 | -------------------------------------------------------------------------------- /app/src/utilities/logging/configure.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from pathlib import Path 3 | from typing import Optional 4 | 5 | from .format import LOG_FMT 6 | 7 | __all__ = [ 8 | "configure_logger", 9 | ] 10 | 11 | 12 | def configure_logger( 13 | logfile: Optional[Path] = None, level: int = logging.DEBUG, 14 | stream: bool = True 15 | ): 16 | root = logging.getLogger() 17 | formatter = logging.Formatter(LOG_FMT) 18 | 19 | # write to log file 20 | if logfile is not None: 21 | if not logfile.parent.exists(): 22 | logfile.parent.mkdir(parents=True) 23 | file_handler = logging.FileHandler(logfile) 24 | file_handler.setFormatter(formatter) 25 | root.addHandler(file_handler) 26 | 27 | # write to stdout 28 | if stream: 29 | stream_handler = logging.StreamHandler() 30 | stream_handler.setFormatter(formatter) 31 | root.addHandler(stream_handler) 32 | 33 | root.setLevel(level) 34 | -------------------------------------------------------------------------------- /app/src/utilities/logging/format.py: -------------------------------------------------------------------------------- 1 | LOG_FMT = '[%(asctime)s]::(PID %(process)d)::%(levelname)-2s::%(message)s' 2 | -------------------------------------------------------------------------------- /app/src/utilities/logging/mp.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from logging.handlers import QueueHandler 3 | import multiprocessing as mp 4 | from pathlib import Path 5 | from typing import Optional 6 | 7 | from .format import LOG_FMT 8 | 9 | __all__ = [ 10 | "configure_worker_logger", 11 | "configure_listener_logger", 12 | ] 13 | 14 | 15 | def configure_worker_logger( 16 | queue: Optional[mp.Queue] = None, level: int = logging.DEBUG 17 | ): 18 | root = logging.getLogger() 19 | 20 | if not root.hasHandlers() and queue is not None: 21 | h = logging.handlers.QueueHandler(queue) 22 | root.addHandler(h) 23 | 24 | root.setLevel(level) 25 | 26 | 27 | def configure_listener_logger( 28 | logfile: Optional[Path] = None, level: int = logging.DEBUG 29 | ): 30 | root = logging.getLogger() 31 | formatter = logging.Formatter(LOG_FMT) 32 | 33 | # write to 
log file 34 | if logfile is not None: 35 | if not logfile.parent.exists(): 36 | logfile.parent.mkdir(parents=True) 37 | file_handler = logging.FileHandler(logfile) 38 | file_handler.setFormatter(formatter) 39 | root.addHandler(file_handler) 40 | 41 | # write to stdout 42 | stream_handler = logging.StreamHandler() 43 | stream_handler.setFormatter(formatter) 44 | root.addHandler(stream_handler) 45 | 46 | root.setLevel(level) 47 | -------------------------------------------------------------------------------- /app/src/utilities/logging/trackers.py: -------------------------------------------------------------------------------- 1 | __all__ = ["RateTracker"] 2 | 3 | 4 | class RateTracker: 5 | def __init__(self, n=200): 6 | self._start_time_tracker = [] 7 | self._counts_tracker = [] 8 | self._n = n 9 | 10 | def update(self, count, start_time): 11 | if len(self._start_time_tracker) >= self._n: 12 | self._start_time_tracker.pop(0) 13 | self._counts_tracker.pop(0) 14 | 15 | self._start_time_tracker.append(start_time) 16 | self._counts_tracker.append(count) 17 | 18 | def get_rate(self, current_time: float): 19 | if len(self._start_time_tracker) == 0: 20 | return 0 21 | 22 | if current_time - self._start_time_tracker[0] < 1e-6: 23 | return 0 24 | 25 | start_time = self._start_time_tracker[0] 26 | pages = sum(self._counts_tracker) 27 | return pages / (current_time - start_time) 28 | 29 | def reset(self): 30 | self._start_time_tracker = [] 31 | self._counts_tracker = [] 32 | -------------------------------------------------------------------------------- /app/src/utilities/register/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/togethercomputer/RedPajama-Data/6d2cee9df2b0204dd2bcb00bf06b5a7b1d7432d7/app/src/utilities/register/__init__.py -------------------------------------------------------------------------------- /app/src/utilities/register/registry_utils.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from typing import Tuple, List, Type 3 | 4 | from core.quality_signals.base import RPSBase 5 | 6 | __all__ = [ 7 | "get_callables_from_module", "signal_schema" 8 | ] 9 | 10 | _SIG_PREF = RPSBase.RPS_PREFIX 11 | 12 | 13 | def get_callables_from_module(module: object) -> List[Type[RPSBase]]: 14 | r""" Returns a list of signal class references that are defined in the 15 | module. 16 | 17 | Args: 18 | module: The module to search for signal classes, obtained via 19 | `sys.modules[__name__]`. 20 | 21 | Returns: 22 | A list of signal class references. 23 | """ 24 | 25 | def _sig_func_predicate(mem: object): 26 | return inspect.isclass(mem) and mem.__name__.startswith(_SIG_PREF) 27 | 28 | return [cls for _, cls in inspect.getmembers(module, _sig_func_predicate)] 29 | 30 | 31 | def signal_schema(module: object) -> List[Tuple[str, Type]]: 32 | r""" Returns a list of signal names and their data types, defining the 33 | schema for signals. 
""" 34 | return list(map( 35 | lambda cls: (cls.__name__.lower(), cls.DATA_TYPE), 36 | get_callables_from_module(module=module) 37 | )) 38 | -------------------------------------------------------------------------------- /app/src/utilities/text/__init__.py: -------------------------------------------------------------------------------- 1 | from .ngrams import form_ngrams 2 | from .normalization import normalize 3 | from .util import generate_paragraphs 4 | -------------------------------------------------------------------------------- /app/src/utilities/text/ngrams.py: -------------------------------------------------------------------------------- 1 | def form_ngrams(sequence, n): 2 | history = [] 3 | # build the first ngram, yielding only when we have a full ngram 4 | while n > 1: 5 | try: 6 | next_item = next(sequence) 7 | except StopIteration: 8 | # no more data, terminate the generator 9 | return 10 | history.append(next_item) 11 | n -= 1 12 | 13 | # yield each ngram we have, then add the next item and repeat 14 | for item in sequence: 15 | history.append(item) 16 | yield tuple(history) 17 | del history[0] 18 | -------------------------------------------------------------------------------- /app/src/utilities/text/normalization.py: -------------------------------------------------------------------------------- 1 | import re 2 | import string 3 | import unicodedata 4 | 5 | TRANSLATION_TABLE_PUNCTUATION = str.maketrans("", "", string.punctuation) 6 | 7 | 8 | def normalize( 9 | text: str, 10 | remove_punct: bool = True, 11 | lowercase: bool = True, 12 | nfd_unicode: bool = True, 13 | white_space: bool = True 14 | ) -> str: 15 | """ Normalize the text by lowercasing and removing punctuation. """ 16 | # remove punctuation 17 | if remove_punct: 18 | text = text.translate(TRANSLATION_TABLE_PUNCTUATION) 19 | 20 | # lowercase 21 | if lowercase: 22 | text = text.lower() 23 | 24 | if white_space: 25 | text = text.strip() 26 | text = re.sub(r"\s+", " ", text) 27 | 28 | # NFD unicode normalization 29 | if nfd_unicode: 30 | text = unicodedata.normalize("NFD", text) 31 | 32 | return text 33 | -------------------------------------------------------------------------------- /app/src/utilities/text/util.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | def generate_paragraphs(text: str, remove_empty: bool = True): 5 | for match in re.finditer(r"([^\n]*\n|[^\n]+$)", text): 6 | text_slice = text[match.start():match.end()] 7 | 8 | if remove_empty and not text_slice.strip(): 9 | continue 10 | 11 | yield text_slice 12 | -------------------------------------------------------------------------------- /configs/rp_v2.0.conf: -------------------------------------------------------------------------------- 1 | # run parameters 2 | DATA_ROOT="" 3 | ARTIFACTS_ID="rpv2" 4 | INPUT_BASE_URI="file://${DATA_ROOT}/" 5 | OUTPUT_BASE_URI="file://${DATA_ROOT}/rpv2" 6 | MAX_DOCS=-1 7 | LANGUAGES=("en" "fr" "es" "it" "de") 8 | 9 | # filename keep filters 10 | FILENAME_KEEP_PATTERNS=( 11 | ".*/[a-z]{2}_middle\.json\.gz" 12 | ".*/[a-z]{2}_head\.json\.gz" 13 | ) 14 | 15 | # General parameters used across steps 16 | S3_ENDPOINT_URL="" 17 | S3_BUCKET="" 18 | S3_CCNET_PREFIX="/rs_cc_net" 19 | S3_PROFILE="" 20 | 21 | # Docker 22 | DOCKER_S3_ENDPOINT_URL="" 23 | DOCKER_MNT_DIR="/mnt/data" 24 | DOCKER_REPO="" 25 | 26 | # Dedupe 27 | MINHASH_NGRAM_SIZE="13" 28 | MINHASH_NUM_PERMUTATIONS="128" 29 | MINHASH_SIMILARITIES=(1.0 0.9 0.8 0.7) 30 | 31 | # DSIR 32 | 
DSIR_NUM_SAMPLES=500000 33 | DSIR_FEATURE_DIM=10000 34 | 35 | # Classifiers 36 | CLASSIFIERS_NUM_SAMPLES=75000 37 | 38 | # sampling for books artifacts 39 | MAX_SAMPLES_PER_BOOK=1000 40 | MAX_PARAGRAPHS_PER_BOOK_SAMPLE=250 41 | 42 | # Others 43 | INPUTS_PER_PROCESS=20 # the number of files processed by one process at a time 44 | 45 | # domain blacklist categories 46 | DOMAIN_BLACKLIST_CATEGORIES=( 47 | "adult" 48 | "agressive" 49 | "agressif" 50 | "arjel" 51 | "chat" 52 | "dating" 53 | "ddos" 54 | "filehosting" 55 | "gambling" 56 | "porn" 57 | "mixed_adult" 58 | "phishing" 59 | "violence" 60 | ) 61 | 62 | # CC snapshot ids to process 63 | CC_SNAPSHOT_IDS=( 64 | "2014-15" 65 | "2014-23" 66 | "2014-35" 67 | "2014-41" 68 | "2014-42" 69 | "2014-49" 70 | "2014-52" 71 | "2015-14" 72 | "2015-22" 73 | "2015-27" 74 | "2015-32" 75 | "2015-35" 76 | "2015-40" 77 | "2015-48" 78 | "2016-07" 79 | "2016-18" 80 | "2016-22" 81 | "2016-26" 82 | "2016-30" 83 | "2016-36" 84 | "2016-40" 85 | "2016-44" 86 | "2016-50" 87 | "2017-04" 88 | "2017-09" 89 | "2017-17" 90 | "2017-22" 91 | "2017-26" 92 | "2017-30" 93 | "2017-34" 94 | "2017-39" 95 | "2017-43" 96 | "2017-47" 97 | "2017-51" 98 | "2018-05" 99 | "2018-09" 100 | "2018-13" 101 | "2018-17" 102 | "2018-22" 103 | "2018-26" 104 | "2018-30" 105 | "2018-34" 106 | "2018-39" 107 | "2018-43" 108 | "2018-47" 109 | "2018-51" 110 | "2019-04" 111 | "2019-09" 112 | "2019-13" 113 | "2019-18" 114 | "2019-22" 115 | "2019-26" 116 | "2019-30" 117 | "2019-35" 118 | "2019-39" 119 | "2019-43" 120 | "2019-47" 121 | "2019-51" 122 | "2020-05" 123 | "2020-10" 124 | "2020-16" 125 | "2020-24" 126 | "2020-29" 127 | "2020-34" 128 | "2020-40" 129 | "2020-45" 130 | "2020-50" 131 | "2021-04" 132 | "2021-10" 133 | "2021-17" 134 | "2021-21" 135 | "2021-25" 136 | "2021-31" 137 | "2021-39" 138 | "2021-43" 139 | "2021-49" 140 | "2022-05" 141 | "2022-21" 142 | "2022-27" 143 | "2022-33" 144 | "2022-40" 145 | "2022-49" 146 | "2023-06" 147 | "2023-14" 148 | ) -------------------------------------------------------------------------------- /docs/rpv2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/togethercomputer/RedPajama-Data/6d2cee9df2b0204dd2bcb00bf06b5a7b1d7432d7/docs/rpv2.png -------------------------------------------------------------------------------- /scripts/apptainer_run_lsh.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | trap cleanup_on_error ERR SIGINT SIGTERM 5 | 6 | cleanup_on_error() { 7 | echo "Error: $0:$LINENO: command \`$BASH_COMMAND\` failed with exit code $?" 8 | exit 1 9 | } 10 | 11 | while [[ $# -gt 0 ]]; do 12 | key="$1" 13 | case $key in 14 | --config) 15 | CONFIG_FILE="$2" 16 | shift 17 | shift 18 | ;; 19 | --input_base_uri) 20 | INPUT_BASE_URI="$2" 21 | shift 22 | shift 23 | ;; 24 | --output_dir) 25 | OUTPUT_DIR="$2" 26 | shift 27 | shift 28 | ;; 29 | --similarity) 30 | SIMILARITY="$2" 31 | shift 32 | shift 33 | ;; 34 | --listings) 35 | LISTINGS="$2" 36 | shift 37 | shift 38 | ;; 39 | --max_docs) 40 | MAX_DOCS="$2" 41 | shift 42 | shift 43 | ;; 44 | *) 45 | echo "Invalid option: -$OPTARG" >&2 46 | ;; 47 | esac 48 | done 49 | 50 | # make environment variables available to downstream scripts 51 | set -a 52 | # shellcheck source=configs/base.conf 53 | . 
"$CONFIG_FILE" 54 | set +a 55 | 56 | # run pipeline 57 | apptainer run --memory 480g "${DOCKER_REPO}" \ 58 | python3 src/run_lsh.py \ 59 | --listings "${LISTINGS}" \ 60 | --input_base_uri "${INPUT_BASE_URI}" \ 61 | --output_dir "${OUTPUT_DIR}" \ 62 | --similarity "${SIMILARITY}" \ 63 | --num_perm "${MINHASH_NUM_PERMUTATIONS}" \ 64 | --max_docs ${MAX_DOCS} 65 | -------------------------------------------------------------------------------- /scripts/apptainer_run_quality_signals.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | trap cleanup_on_error ERR SIGINT SIGTERM 5 | 6 | cleanup_on_error() { 7 | echo "Error: $0:$LINENO: command \`$BASH_COMMAND\` failed with exit code $?" 8 | exit 1 9 | } 10 | 11 | help() { 12 | echo "Usage: apptainer_run_quality_signals.sh [ -c | --config ] [ -d | --dump_id ]" 13 | exit 2 14 | } 15 | 16 | while [[ $# -gt 0 ]]; do 17 | key="$1" 18 | case $key in 19 | -c | --config) 20 | CONFIG_FILE="$2" 21 | shift 2 22 | ;; 23 | -d | --dump_id) 24 | DUMP_ID="$2" 25 | shift 2 26 | ;; 27 | -l | --listings) 28 | LISTINGS="$2" 29 | shift 2 30 | ;; 31 | -h | --help) 32 | help 33 | ;; 34 | --) 35 | shift 36 | break 37 | ;; 38 | *) 39 | echo "Invalid option: -$1" 40 | help 41 | ;; 42 | esac 43 | done 44 | 45 | # make environment variables available to downstream scripts 46 | set -a 47 | # shellcheck source=configs/base.conf 48 | . "$CONFIG_FILE" 49 | set +a 50 | 51 | if [ -z "${MAX_DOCS}" ]; then 52 | MAX_DOCS=-1 53 | fi 54 | 55 | ARTIFACTS_ARCHIVE="${DATA_ROOT%/}/artifacts-${ARTIFACTS_ID}.tar.gz" 56 | 57 | if [ ! -d "${DATA_ROOT%/}/artifacts-${ARTIFACTS_ID}" ]; then 58 | # download artifacts from bucket 59 | echo "Downloading artifacts from ${INPUT_BASE_URI%/}/artifacts-${ARTIFACTS_ID}.tar.gz" 60 | s5cmd --profile "$S3_PROFILE" --endpoint-url "$S3_ENDPOINT_URL" \ 61 | cp "${S3_BUCKET%/}/artifacts/artifacts-${ARTIFACTS_ID}.tar.gz" "${ARTIFACTS_ARCHIVE}" 62 | 63 | # extract artifacts 64 | mkdir -p "${DATA_ROOT%/}/artifacts-${ARTIFACTS_ID}" 65 | echo "Extracting artifacts to ${DATA_ROOT%/}/artifacts-${ARTIFACTS_ID}" 66 | tar -xzf "${ARTIFACTS_ARCHIVE}" -C "${DATA_ROOT%/}/artifacts-${ARTIFACTS_ID}" 67 | rm "${ARTIFACTS_ARCHIVE}" 68 | else 69 | echo "Artifacts already exist at ${DATA_ROOT%/}/artifacts-${ARTIFACTS_ID}; skipping download." 
70 | fi 71 | 72 | 73 | # run pipeline 74 | ARTIFACTS_DIR="${DATA_ROOT%/}/artifacts-${ARTIFACTS_ID}" 75 | 76 | if [ -z "${LISTINGS}" ]; then 77 | LISTINGS="${ARTIFACTS_DIR%/}/listings/listings-${DUMP_ID}.txt" 78 | fi 79 | 80 | apptainer cache clean -f 81 | apptainer run \ 82 | --env AWS_ACCESS_KEY_ID="$AWS_ACCESS_KEY_ID" --env AWS_SECRET_ACCESS_KEY="$AWS_SECRET_ACCESS_KEY" \ 83 | "docker://docker.io/${DOCKER_REPO}:amd64" \ 84 | python3 /usr/app/src/pipeline.py \ 85 | --input "${LISTINGS}" \ 86 | --input_base_uri "${INPUT_BASE_URI}" \ 87 | --output_base_uri "${OUTPUT_BASE_URI}" \ 88 | --cc_snapshot_id "${DUMP_ID}" \ 89 | --artifacts_dir "${ARTIFACTS_DIR}" \ 90 | --dsir_buckets "${DSIR_FEATURE_DIM}" \ 91 | --max_docs "${MAX_DOCS}" \ 92 | --inputs_per_process "${INPUTS_PER_PROCESS}" \ 93 | --langs "${LANGUAGES[@]}" \ 94 | --endpoint_url "${S3_ENDPOINT_URL}" \ 95 | --minhash_ngram_size "${MINHASH_NGRAM_SIZE}" \ 96 | --minhash_num_permutations "${MINHASH_NUM_PERMUTATIONS}" \ 97 | --minhash_similarities "${MINHASH_SIMILARITIES[@]}" \ 98 | --filename_keep_patterns "${FILENAME_KEEP_PATTERNS[@]}" 99 | -------------------------------------------------------------------------------- /scripts/run_prep_artifacts.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | trap cleanup_on_error ERR SIGINT SIGTERM 5 | 6 | function cleanup_on_error() { 7 | echo "Error: $0:$LINENO: command \`$BASH_COMMAND\` failed with exit code $?" 8 | exit 1 9 | } 10 | 11 | while [[ $# -gt 0 ]]; do 12 | key="$1" 13 | case $key in 14 | --config) 15 | CONFIG_FILE="$2" 16 | shift 17 | shift 18 | ;; 19 | --listings) 20 | RAW_LISTINGS_FILE="$2" 21 | shift 22 | shift 23 | ;; 24 | --max_workers) 25 | MAX_WORKERS="$2" 26 | shift 27 | shift 28 | ;; 29 | *) 30 | echo "Invalid option: $1" >&2 31 | ;; 32 | esac 33 | done 34 | 35 | set -a 36 | # shellcheck source=configs/base.conf 37 | . "$CONFIG_FILE" 38 | set +a 39 | 40 | # create a random run id 41 | RUN_ID=$(openssl rand -hex 4) 42 | echo "Created run id: $RUN_ID" 43 | 44 | ARTIFACTS_DIR="${DATA_ROOT%/}/artifacts-${RUN_ID}" 45 | LISTINGS_DIR="${ARTIFACTS_DIR%/}/listings" 46 | LISTINGS_FILE="${LISTINGS_DIR%/}/listings.txt" 47 | mkdir -p "${LISTINGS_DIR%/}" 48 | 49 | # write id to file if it doesn't exist 50 | if [ ! -f "${ARTIFACTS_DIR%/}/_RUN_ID" ]; then 51 | echo "Writing run id to file ${ARTIFACTS_DIR%/}/_RUN_ID" 52 | echo "$RUN_ID" >"${ARTIFACTS_DIR%/}/_RUN_ID" 53 | fi 54 | 55 | # fetch listings from s3 bucket if the listings file does not exist 56 | if [ ! -f "${RAW_LISTINGS_FILE}" ]; then 57 | echo "__FETCH_LISTINGS_START__ @ $(date) Fetching listings from s3 bucket..." 58 | s5cmd --profile "$S3_PROFILE" --endpoint-url "$S3_ENDPOINT_URL" \ 59 | ls "${S3_BUCKET%/}${S3_CCNET_PREFIX%/}/*" | 60 | grep "\.json\.gz$" | awk '{print $NF}' >"${LISTINGS_FILE}" 61 | echo "__FETCH_LISTINGS_END__ @ $(date) Done fetching listings from s3 bucket."
62 | else 63 | cp "${RAW_LISTINGS_FILE}" "${LISTINGS_FILE}" 64 | echo "Copied listings file from ${RAW_LISTINGS_FILE} to ${LISTINGS_FILE}" 65 | fi 66 | 67 | # create a listings file for each snapshot id that appears in the full listings 68 | for snapshot_id in "${CC_SNAPSHOT_IDS[@]}"; do 69 | if grep "${snapshot_id}" "${LISTINGS_DIR%/}/listings.txt" >/dev/null 2>&1; then 70 | grep "${snapshot_id}" "${LISTINGS_DIR%/}/listings.txt" >"${LISTINGS_DIR%/}/listings-${snapshot_id}.txt" 71 | echo "__SNAPSHOT_LISTINGS_SUCCESS__ $snapshot_id" 72 | else 73 | echo "__SNAPSHOT_LISTINGS_FAIL__ $snapshot_id" 74 | fi 75 | done 76 | 77 | num_listings=$(wc -l <"${ARTIFACTS_DIR%/}/listings/listings.txt") 78 | echo "Total number of listings: $num_listings" 79 | 80 | # copy config to artifacts dir 81 | cp "$CONFIG_FILE" "${ARTIFACTS_DIR%/}/config.conf" 82 | 83 | # Reset artifacts dir on docker mounted volume 84 | ARTIFACTS_DIR="${DOCKER_MNT_DIR%/}/artifacts-${RUN_ID}" 85 | for lang in "${LANGUAGES[@]}"; do 86 | echo "__LANG_PREP_START__ ${lang} @ $(date)" 87 | docker run --env AWS_ACCESS_KEY_ID="$AWS_ACCESS_KEY_ID" --env AWS_SECRET_ACCESS_KEY="$AWS_SECRET_ACCESS_KEY" \ 88 | -v "${DATA_ROOT%/}":"${DOCKER_MNT_DIR%/}" -t "${DOCKER_REPO}" \ 89 | python3 src/prep_artifacts.py \ 90 | --artifacts_dir "${ARTIFACTS_DIR%/}" \ 91 | --cc_input "${ARTIFACTS_DIR%/}/listings/listings.txt" \ 92 | --cc_input_base_uri "${S3_BUCKET%/}${S3_CCNET_PREFIX%/}" \ 93 | --cache_dir "${DOCKER_MNT_DIR%/}/.hf_cache" \ 94 | --lang "${lang}" \ 95 | --max_workers "${MAX_WORKERS}" \ 96 | --endpoint_url "$DOCKER_S3_ENDPOINT_URL" \ 97 | --dsir_num_samples "${DSIR_NUM_SAMPLES}" \ 98 | --dsir_feature_dim "${DSIR_FEATURE_DIM}" \ 99 | --classifiers_num_samples "${CLASSIFIERS_NUM_SAMPLES}" \ 100 | --max_paragraphs_per_book_sample "${MAX_PARAGRAPHS_PER_BOOK_SAMPLE}" \ 101 | --max_samples_per_book "${MAX_SAMPLES_PER_BOOK}" 102 | echo "__LANG_PREP_END__ ${lang} @ $(date)" 103 | done 104 | 105 | echo "__UPDATE_CONTENTLISTS_START__ @ $(date)" 106 | docker run -v "${DATA_ROOT%/}":"${DOCKER_MNT_DIR%/}" -t "${DOCKER_REPO}" \ 107 | python3 src/artifacts/update_resources.py \ 108 | --langs "${LANGUAGES[@]}" \ 109 | --artifacts_dir "${ARTIFACTS_DIR%/}" \ 110 | --block_categories "${DOMAIN_BLACKLIST_CATEGORIES[@]}" 111 | 112 | echo "__UPDATE_CONTENTLISTS_END__ @ $(date)" 113 | 114 | # package artifacts 115 | echo "__PACKAGE_ARTIFACTS_START__ @ $(date)" 116 | ARTIFACTS_DIR="${DATA_ROOT%/}/artifacts-${RUN_ID}" 117 | EXPORT_ARTIFACTS="${DATA_ROOT%/}/_EXPORT_artifacts-${RUN_ID}" 118 | mkdir -p "${EXPORT_ARTIFACTS%/}" 119 | 120 | # copy wikiref model to artifacts dir 121 | cp "${DATA_ROOT%/}/wikiref-models/en/en-model.bin" "${ARTIFACTS_DIR%/}/classifiers/en/wikiref.model.bin" 122 | 123 | # move artifacts to export 124 | cp -r "${ARTIFACTS_DIR%/}/dsir" "${EXPORT_ARTIFACTS%/}/" 125 | cp -r "${ARTIFACTS_DIR%/}/classifiers" "${EXPORT_ARTIFACTS%/}/" 126 | cp -r "${ARTIFACTS_DIR%/}/bad_words" "${EXPORT_ARTIFACTS%/}/" 127 | cp -r "${ARTIFACTS_DIR%/}/bad_urls" "${EXPORT_ARTIFACTS%/}/" 128 | cp -r "${ARTIFACTS_DIR%/}/listings" "${EXPORT_ARTIFACTS%/}/" 129 | cp -r "${ARTIFACTS_DIR%/}/_RUN_ID" "${EXPORT_ARTIFACTS%/}/" 130 | cp -r "${ARTIFACTS_DIR%/}/logs" "${EXPORT_ARTIFACTS%/}/" 131 | 132 | # package artifacts 133 | tar -czf "${EXPORT_ARTIFACTS%/}.tar.gz" -C "${EXPORT_ARTIFACTS%/}" . 134 | echo "__PACKAGE_ARTIFACTS_END__ @ $(date)" 135 | --------------------------------------------------------------------------------
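
Example (not part of the repository): a minimal usage sketch for the ParquetBatchWriter and the text utilities shown above, assuming pyarrow is installed and that the writer class lives in app/src/utilities/io/writer.py. The import paths below are assumptions based on the repository layout (with app/src on the Python path); adjust them to the actual package structure.

import pyarrow as pa

from utilities.io.writer import ParquetBatchWriter   # assumed import path
from utilities.text import form_ngrams, normalize    # assumed import path

# define the parquet schema for the output rows
schema = pa.schema([("doc_id", pa.string()), ("text", pa.string())])
writer = ParquetBatchWriter("sample.parquet", schema=schema)

docs = [("doc-0", "Hello, World!\nSecond line."), ("doc-1", "More text here.")]
for doc_id, text in docs:
    # lowercase, strip punctuation, collapse whitespace, NFD-normalize
    clean = normalize(text)
    # form_ngrams consumes an iterator (it calls next() on its argument),
    # so wrap the token list in iter() before passing it in
    trigrams = list(form_ngrams(iter(clean.split()), 3))
    print(doc_id, trigrams)
    writer.update_batch({"doc_id": doc_id, "text": clean})

writer.write_batch()  # flush the buffered rows as one record batch
writer.close()        # close() writes any remaining buffered rows before closing

In a longer-running job, write_batch() would typically be called every few thousand documents to bound memory use; close() takes care of whatever is still buffered, since it flushes the batch when the first schema column is non-empty.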