├── YiZhao_technical_report.pdf ├── 7_DataAnalysis ├── resources │ ├── HIT.jfif │ ├── simsun.ttf │ └── hit_stopwords.txt ├── eval_pipeline.py ├── corpus_eval_visulization.py └── corpus_evaluator.py ├── 2_toxic_filter ├── sensitive_words │ └── violence.txt └── 2_toxic_filter.py ├── 5_text_dedup ├── clean_helpers │ ├── __init__.py │ ├── concatenation.py │ ├── utils.py │ └── deduplication.py └── 5_clean.py ├── requirements.txt ├── 6_text_dedup └── text_dedup │ ├── __init__.py │ ├── utils │ ├── __init__.py │ ├── preprocess.py │ ├── union_find.py │ ├── tokenization.py │ ├── analysis.py │ ├── timer.py │ └── add_args.py │ └── minhash.py ├── 4_perplexity_filter └── kenlm │ ├── run.py │ └── model.py ├── 3_rule_filter.py ├── README.md └── 1_pii.py /YiZhao_technical_report.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HITsz-TMG/YiZhao/HEAD/YiZhao_technical_report.pdf -------------------------------------------------------------------------------- /7_DataAnalysis/resources/HIT.jfif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HITsz-TMG/YiZhao/HEAD/7_DataAnalysis/resources/HIT.jfif -------------------------------------------------------------------------------- /7_DataAnalysis/resources/simsun.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HITsz-TMG/YiZhao/HEAD/7_DataAnalysis/resources/simsun.ttf -------------------------------------------------------------------------------- /2_toxic_filter/sensitive_words/violence.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HITsz-TMG/YiZhao/HEAD/2_toxic_filter/sensitive_words/violence.txt -------------------------------------------------------------------------------- /5_text_dedup/clean_helpers/__init__.py: -------------------------------------------------------------------------------- 1 | from .deduplication import build_dedup_template, build_dedup_document 2 | from .concatenation import concatenate_lm_fr_ester -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | regex==2024.9.11 2 | datasets==3.0.0 3 | chardet==5.2.0 4 | ftfy==6.2.3 5 | langdetect==1.0.9 6 | opencc==1.1.9 7 | kenlm==0.2.0 8 | sentencepiece==0.2.0 9 | jsonlines==2.0.0 10 | torch==2.2.2 11 | scipy==1.12.0 12 | rich==13.7.1 13 | tiktoken==0.7.0 14 | openai==1.46.1 15 | matplotlib==3.9.2 16 | seaborn==0.13.2 17 | jieba==0.42.1 18 | wordcloud==1.9.3 19 | pandas==2.2.2 20 | numpy==1.26.4 -------------------------------------------------------------------------------- /6_text_dedup/text_dedup/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Date : 2021-06-05 12:48:33 4 | # @Author : Chenghao Mou (mouchenghao@gmail.com) 5 | 6 | """Text deduplication simplified.""" 7 | 8 | import logging 9 | 10 | from rich.logging import RichHandler 11 | 12 | logger = logging.getLogger("text_dedup") 13 | logger.setLevel(logging.INFO) 14 | logger.addHandler(RichHandler(rich_tracebacks=True)) 15 | logger.propagate = False 16 | -------------------------------------------------------------------------------- /5_text_dedup/clean_helpers/concatenation.py: 
-------------------------------------------------------------------------------- 1 | from itertools import groupby 2 | 3 | from datasets import Dataset 4 | 5 | from clean_helpers.utils import parse_meta 6 | 7 | 8 | def concatenate_lm_fr_ester(ds: Dataset, num_proc: int, batch_size: int) -> Dataset: 9 | dataset_in_memory = [ 10 | (*parse_meta(row["meta"])["id"].split("_id_"), row["text"]) for row in ds 11 | ] 12 | dataset_in_memory.sort() 13 | new_texts = [] 14 | new_metas = [] 15 | for doc_id, segments in groupby(dataset_in_memory, key=lambda x: x[0]): 16 | sorted_segment = sorted( 17 | [elt[1:] for elt in segments], 18 | key=lambda x: int(x[0]) 19 | ) 20 | new_texts.append("\n".join([elt[1] for elt in sorted_segment])) 21 | new_metas.append({"id": doc_id}) 22 | 23 | new_ds = Dataset.from_dict({"text": new_texts, "meta": new_metas}) 24 | return new_ds 25 | -------------------------------------------------------------------------------- /6_text_dedup/text_dedup/utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Date : 2022-12-26 15:42:09 4 | # @Author : Chenghao Mou (mouchenghao@gmail.com) 5 | 6 | from utils.add_args import add_bloom_filter_args 7 | from utils.add_args import add_exact_hash_args 8 | from utils.add_args import add_io_args 9 | from utils.add_args import add_meta_args 10 | from utils.add_args import add_minhash_args 11 | from utils.add_args import add_sa_args 12 | from utils.add_args import add_simhash_args 13 | from utils.timer import Timer 14 | from utils.tokenization import ngrams 15 | from utils.union_find import UnionFind 16 | 17 | __all__ = [ 18 | "add_bloom_filter_args", 19 | "add_exact_hash_args", 20 | "add_io_args", 21 | "add_meta_args", 22 | "add_minhash_args", 23 | "add_sa_args", 24 | "add_simhash_args", 25 | "Timer", 26 | "ngrams", 27 | "UnionFind", 28 | ] 29 | -------------------------------------------------------------------------------- /5_text_dedup/clean_helpers/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import Dict 3 | 4 | 5 | def parse_meta(meta) -> Dict: 6 | if isinstance(meta, str): 7 | meta = eval(meta) 8 | return meta 9 | 10 | 11 | normalise_dataset_name_regex = re.compile( 12 | r"^(?:/gpfswork/rech/six/uty16tp/dataset/tokenization/)?(bigscience-catalogue-lm-data/[^/]+)(?:/data)?$" 13 | ) 14 | 15 | 16 | language_regex = re.compile( 17 | r"^(?:/gpfswork/rech/six/uty16tp/dataset/tokenization/)?bigscience-catalogue-lm-data/lm_([^_]+)_.*(?:/data)?$" 18 | ) 19 | def get_language(dataset_name: str): 20 | lang_candidate = language_regex.match(dataset_name).group(1) 21 | 22 | # Normalise chinese languages, so that we only consider simplified and traditional chinese as the two chinese languages 23 | if lang_candidate in ["zh", "zhs", "zh-cn"]: 24 | lang_candidate = "zhs" 25 | elif lang_candidate in ["zht", "zh-tw"]: 26 | lang_candidate = "zht" 27 | else: 28 | assert lang_candidate[:2] != "zh" 29 | 30 | return lang_candidate 31 | -------------------------------------------------------------------------------- /6_text_dedup/text_dedup/utils/preprocess.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Date : 2023-05-06 19:39:27 4 | # @Author : Chenghao Mou (mouchenghao@gmail.com) 5 | 6 | import regex as re 7 | 8 | DIGIT_RE = re.compile(r"\d") 9 | PUNCT_OR_NON_PRINTING_CHARS_RE = 
re.compile(r"[\p{P}\p{C}\p{S}]+") 10 | 11 | def normalize(line: str) -> str: 12 | """ 13 | Normalize a line of text. Source: https://github.com/facebookresearch/cc_net/blob/bda555bd1cf1ee2e0b925363e62a61cd46c8b60d/cc_net/text_normalizer.py#L180 14 | 15 | Parameters 16 | ---------- 17 | line : str 18 | The line of text to normalize. 19 | 20 | Returns 21 | ------- 22 | str 23 | The normalized line of text. 24 | 25 | Examples 26 | -------- 27 | >>> normalize("Hello, world!") 28 | 'hello world' 29 | >>> normalize("Hello, 123!\\n\\t\\b") 30 | 'hello 000' 31 | """ 32 | line = line.strip() 33 | if not line: 34 | return line 35 | line = line.lower() 36 | line = DIGIT_RE.sub("0", line) 37 | line = PUNCT_OR_NON_PRINTING_CHARS_RE.sub("", line) 38 | return line 39 | -------------------------------------------------------------------------------- /6_text_dedup/text_dedup/utils/union_find.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Date : 2022-12-26 15:37:44 4 | # @Author : Chenghao Mou (mouchenghao@gmail.com) 5 | 6 | 7 | class UnionFind: 8 | """ 9 | A data structure for maintaining disjoint sets. This helps build connected components for given duplicate pairs. 10 | 11 | Examples 12 | -------- 13 | >>> uf = UnionFind() 14 | >>> uf.union(1, 2) 15 | >>> uf.union(2, 3) 16 | >>> uf.union(4, 5) 17 | >>> uf.find(1) 18 | 1 19 | >>> uf.find(2) 20 | 1 21 | >>> uf.find(3) 22 | 1 23 | >>> uf.find(4) 24 | 4 25 | >>> uf.find(5) 26 | 4 27 | """ 28 | 29 | def __init__(self): 30 | self.parent = {} 31 | 32 | def find(self, x): 33 | if x not in self.parent: 34 | self.parent[x] = x 35 | return x 36 | 37 | if self.parent[x] != x: 38 | self.parent[x] = self.find(self.parent[x]) 39 | 40 | return self.parent[x] 41 | 42 | def union(self, x, y): 43 | px = self.find(x) 44 | py = self.find(y) 45 | self.parent[px] = self.parent[py] = min(px, py) 46 | -------------------------------------------------------------------------------- /6_text_dedup/text_dedup/utils/tokenization.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Date : 2022-12-26 15:59:42 4 | # @Author : Chenghao Mou (mouchenghao@gmail.com) 5 | from itertools import tee 6 | from typing import List 7 | from typing import Text 8 | 9 | 10 | def ngrams(sequence: List[Text], n: int, min_length: int = 5): 11 | """ 12 | Return the ngrams generated from a sequence of items, as an iterator. 13 | 14 | This is a modified version of nltk.util.ngrams. 15 | 16 | Parameters 17 | ---------- 18 | sequence : List[Text] 19 | The sequence of items. 20 | n : int 21 | The length of each ngram. 22 | min_length : int, optional 23 | The minimum length of each ngram, by default 5 24 | 25 | Returns 26 | ------- 27 | iterator 28 | The ngrams. 
29 | 30 | Examples 31 | -------- 32 | >>> list(ngrams(["a", "b", "c", "d"], 2, min_length=1)) 33 | [('a', 'b'), ('b', 'c'), ('c', 'd')] 34 | >>> list(ngrams(["a", "b", "c", "d"], 2, min_length=5)) 35 | [] 36 | >>> list(ngrams(["a", "b"], 3, min_length=1)) 37 | [('a', 'b')] 38 | """ 39 | if len(sequence) < min_length: 40 | return [] 41 | if len(sequence) < n: 42 | return [tuple(sequence)] 43 | iterables = tee(iter(sequence), n) 44 | for i, sub_iterable in enumerate(iterables): 45 | for _ in range(i): 46 | next(sub_iterable, None) 47 | return zip(*iterables) 48 | -------------------------------------------------------------------------------- /6_text_dedup/text_dedup/utils/analysis.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Date : 2023-01-02 15:18:55 4 | # @Author : Chenghao Mou (mouchenghao@gmail.com) 5 | from typing import List 6 | 7 | from text_dedup.utils.tokenization import ngrams 8 | 9 | 10 | def jaccard_similarity( 11 | doc1: str | List[str], 12 | doc2: str | List[str], 13 | ngram_size: int = 8, 14 | min_length: int = 0, 15 | ) -> float: 16 | """Compute the Jaccard similarity between two documents. 17 | 18 | Parameters 19 | ---------- 20 | doc1 : str or List[str] 21 | The first document. 22 | doc2 : str or List[str] 23 | The second document. 24 | ngram_size : int, optional 25 | The size of n-grams, by default 8 26 | min_length : int, optional 27 | The minimum length of each n-gram, by default 0 28 | 29 | Returns 30 | ------- 31 | float 32 | The Jaccard similarity. 33 | 34 | Examples 35 | -------- 36 | >>> jaccard_similarity("hello world", "hello world") 37 | 1.0 38 | >>> jaccard_similarity("hello world", "hello world!") 39 | 0.8 40 | >>> jaccard_similarity("hello world".split(), "hello world!".split(), ngram_size=1) 41 | 0.3333333333333333 42 | """ 43 | words1 = set(" ".join(ng) for ng in ngrams(list(doc1), ngram_size, min_length=min_length)) 44 | words2 = set(" ".join(ng) for ng in ngrams(list(doc2), ngram_size, min_length=min_length)) 45 | return len(words1 & words2) / max(1, len(words1 | words2)) 46 | -------------------------------------------------------------------------------- /7_DataAnalysis/eval_pipeline.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from corpus_evaluator import corpus_quality_measure_fn 4 | from corpus_eval_visulization import scores_visualization 5 | 6 | 7 | if __name__ == "__main__": 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("--random_seed", type=int, default=1234) 10 | 11 | parser.add_argument("--data_path", type=str) 12 | parser.add_argument("--data_num", type=int, default=10) 13 | parser.add_argument("--text_column", type=str) 14 | parser.add_argument("--tiktoken_cache", type=str) 15 | 16 | parser.add_argument("--eval_path", type=str) 17 | 18 | parser.add_argument("--figure_dir", type=str) 19 | 20 | parser.add_argument("--model", type=str, default="gpt-3.5-turbo-1106") 21 | parser.add_argument("--api_key", type=str) 22 | parser.add_argument("--organization", type=str) 23 | parser.add_argument("--num_proc", type=int, default=1) 24 | args = parser.parse_args() 25 | 26 | # args.eval_path = args.figure_dir + "/result.jsonl" 27 | 28 | tiktoken_cache_dir = args.tiktoken_cache 29 | os.environ["TIKTOKEN_CACHE_DIR"] = tiktoken_cache_dir 30 | 31 | corpus = corpus_quality_measure_fn( 32 | data_path=args.data_path, 33 | eval_path=args.eval_path, 34 | 
data_num=args.data_num, 35 | text_column=args.text_column, 36 | model=args.model, 37 | api_key=args.api_key, 38 | organization=args.organization, 39 | num_proc=args.num_proc,) 40 | 41 | scores_visualization(corpus, args.text_column, args.figure_dir) 42 | 43 | -------------------------------------------------------------------------------- /4_perplexity_filter/kenlm/run.py: -------------------------------------------------------------------------------- 1 | from model import KenlmModel 2 | import json 3 | import jsonlines 4 | import argparse 5 | from tqdm import tqdm 6 | 7 | def save_jsonl(data, output_path): 8 | with open(output_path, 'w', encoding='utf-8') as output_file: 9 | for item in data: 10 | output_file.write(json.dumps(item, ensure_ascii=False) + "\n") 11 | 12 | def read_jsonl(input_path): 13 | output_data = [] 14 | with open(input_path, 'r+', encoding='utf-8') as f: 15 | for item in jsonlines.Reader(f): 16 | output_data.append(item) 17 | return output_data 18 | 19 | def perplexity_filter(input_path, output_path): 20 | input_data = read_jsonl(input_path) 21 | filtered_data = [] 22 | 23 | for tmp in tqdm(input_data): 24 | score = model.get_perplexity(tmp[args.text_column]) 25 | if score <= 2095: 26 | filtered_data.append(tmp) 27 | 28 | save_jsonl(filtered_data, output_path) 29 | 30 | 31 | 32 | if __name__=='__main__': 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument( 35 | "--input_path", 36 | type=str, 37 | help="Path to input file(jsonl).", 38 | ) 39 | parser.add_argument( 40 | "--output_path", 41 | type=str, 42 | help="Path to output file(jsonl).", 43 | ) 44 | parser.add_argument('--text_column', type=str) 45 | parser.add_argument('--language', type=str, help="zh or en") 46 | args = parser.parse_args() 47 | 48 | # model taken from https://huggingface.co/edugp/kenlm 49 | model = KenlmModel.from_pretrained("kenlm/wikipedia", args.language) 50 | perplexity_filter(args.input_path, args.output_path) 51 | print('Perplexity Filter Done!') 52 | 53 | -------------------------------------------------------------------------------- /6_text_dedup/text_dedup/utils/timer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Date : 2022-12-26 15:45:46 4 | # @Author : Chenghao Mou (mouchenghao@gmail.com) 5 | import time 6 | 7 | 8 | class TimerContext: 9 | def __init__(self, timer: "Timer", name: str): 10 | self.timer = timer 11 | self.name = name 12 | self.start_time = None 13 | 14 | def __enter__(self): 15 | self.start_time = time.time() 16 | 17 | def __exit__(self, exc_type, exc_val, exc_tb): 18 | if any([exc_type, exc_val, exc_tb]): 19 | raise exc_val 20 | self.timer.elapsed_times[self.name] = time.time() - self.start_time 21 | 22 | 23 | class Timer: 24 | """ 25 | A simple timer that tracks the elapsed time of each context. 26 | 27 | Examples 28 | -------- 29 | >>> t = Timer() 30 | >>> with t("test"): 31 | ... time.sleep(1) 32 | >>> assert int(t.elapsed_times.get("test", 0)) >= 1, "The elapsed time should be 1 second." 33 | """ 34 | 35 | def __init__(self): 36 | self.elapsed_times = {} 37 | 38 | def __call__(self, name: str) -> TimerContext: 39 | """ 40 | Create a context with the given name. 41 | 42 | Parameters 43 | ---------- 44 | name: str 45 | The name of the context. 46 | 47 | Returns 48 | ------- 49 | TimerContext 50 | The context. 51 | 52 | Examples 53 | -------- 54 | >>> t = Timer() 55 | >>> with t("test"): 56 | ... 
time.sleep(1) 57 | >>> assert int(t.elapsed_times.get("test", 0)) == 1, "The elapsed time should be 1 second." 58 | >>> with t("test2"): 59 | ... time.sleep(2) 60 | >>> assert int(t.elapsed_times.get("test2", 0)) == 2, "The elapsed time should be 2 seconds." 61 | """ 62 | return TimerContext(self, name) 63 | -------------------------------------------------------------------------------- /2_toxic_filter/2_toxic_filter.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import chardet 4 | import argparse 5 | from pathlib import Path 6 | 7 | def load_jsonl(path): 8 | with open(path, 'r', encoding='UTF-8') as f: 9 | return [json.loads(l) for l in f] 10 | 11 | class CorpusFilter: 12 | def __init__(self, directory_path): 13 | self.sensitive_keywords = self.load_sensitive_keywords(directory_path) 14 | 15 | def detect_encoding(self, file_path): 16 | with open(file_path, 'rb') as file: 17 | raw_data = file.read(5000) 18 | result = chardet.detect(raw_data) 19 | encoding = result['encoding'] 20 | return encoding 21 | 22 | def load_sensitive_keywords(self, directory_path): 23 | # Load sensitive keywords from all .txt files in the specified directory 24 | sensitive_keywords = set() 25 | for filename in os.listdir(directory_path): 26 | if filename.endswith('.txt'): 27 | file_path = os.path.join(directory_path, filename) 28 | encoding = self.detect_encoding(file_path) 29 | with open(file_path, 'r', encoding=encoding) as file: 30 | for line in file: 31 | keyword = line.strip().rstrip(',') 32 | if keyword: 33 | sensitive_keywords.add(keyword) 34 | return list(sensitive_keywords) 35 | 36 | def is_sensitive(self, text): 37 | for keyword in self.sensitive_keywords: 38 | if keyword in text: 39 | return True 40 | return False 41 | 42 | def filter_corpus(self, input_file_path, output_file_path): 43 | with open(input_file_path, 'r', encoding='utf-8') as input_file, \ 44 | open(output_file_path, 'w', encoding='utf-8') as output_file: 45 | for line in input_file: 46 | try: 47 | data = json.loads(line) 48 | text = data.get(args.text_column, '') 49 | if not self.is_sensitive(text): 50 | output_file.write(json.dumps(data, ensure_ascii=False) + '\n') 51 | except json.JSONDecodeError: 52 | continue # Ignore lines with parsing errors 53 | 54 | 55 | 56 | if __name__ == '__main__': 57 | parser = argparse.ArgumentParser() 58 | # The default input and output are jsonl files 59 | parser.add_argument('--input_path', type=str) 60 | parser.add_argument('--output_path', type=str) 61 | parser.add_argument('--text_column', type=str) 62 | args = parser.parse_args() 63 | 64 | directory_path = Path(__file__).parent / "sensitive_words" 65 | filter = CorpusFilter(directory_path) 66 | data = load_jsonl(args.input_path) 67 | filter.filter_corpus(args.input_path, args.output_path) 68 | 69 | -------------------------------------------------------------------------------- /7_DataAnalysis/corpus_eval_visulization.py: -------------------------------------------------------------------------------- 1 | import os 2 | import platform 3 | from pathlib import Path 4 | from collections import defaultdict 5 | 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | import seaborn as sns 9 | import pandas as pd 10 | from pylab import * 11 | import jieba 12 | from wordcloud import WordCloud, ImageColorGenerator, STOPWORDS 13 | from PIL import Image 14 | 15 | 16 | sns.set_palette("hls") 17 | 18 | from matplotlib.font_manager import FontProperties 19 | system = platform.system() 20 
| 21 | font = FontProperties(fname='fonts/opentype/noto/NotoSerifCJK-Black.ttc') 22 | 23 | 24 | def get_all_scores(corpus): 25 | scores_dict = defaultdict(list) 26 | 27 | for obj in corpus: 28 | quality = obj["quality"] 29 | for aspect, result in quality.items(): 30 | if result["score"] >= 0: 31 | scores_dict[aspect].append(result["score"]) 32 | 33 | print("{:6s} | {:5s} | {:5s} | {:5s} | {:5s} | {} ".format("", "Mean", "Std", "Min", "Max", "Count")) 34 | for aspect, score_list in scores_dict.items(): 35 | mean_score = np.mean(score_list) 36 | std_score = np.std(score_list) 37 | min_screo = min(score_list) 38 | max_score = max(score_list) 39 | print("{:5s} | {:5.2f} | {:5.2f} | {:5d} | {:5d} | {}".format(aspect, mean_score, std_score, min_screo, max_score, len(score_list))) 40 | 41 | return scores_dict 42 | 43 | 44 | def get_wordcloud(corpus, text_column, figure_dir): 45 | text_list = [obj[text_column] for obj in corpus] 46 | text = "\n".join(text_list) 47 | 48 | wordlist = jieba.cut(text) 49 | wordlist = [w for w in wordlist if len(w) > 1] 50 | space_list = ' '.join(wordlist) 51 | 52 | backgroud = np.array(Image.open(Path(__file__).parent / "resources/HIT.jfif")) 53 | 54 | with open(Path(__file__).parent / "resources/hit_stopwords.txt", "r", encoding="utf-8") as f: 55 | stopwords = [w.rstrip() for w in f.readlines()] 56 | 57 | wc = WordCloud(width=1400, height=2200, 58 | background_color='white', 59 | mode='RGB', 60 | mask=backgroud, 61 | max_words=500, 62 | stopwords=STOPWORDS.update(stopwords), 63 | max_font_size=150, 64 | relative_scaling=0.6, 65 | random_state=50, 66 | scale=2, 67 | font_path=str(Path(__file__).parent / "resources/simsun.ttf"), 68 | ).generate(space_list) 69 | 70 | image_color = ImageColorGenerator(backgroud) 71 | wc.recolor(color_func=image_color) 72 | 73 | plt.imshow(wc) 74 | plt.axis('off') 75 | plt.show() 76 | wc.to_file(os.path.join(figure_dir, "wordcloud.png")) 77 | 78 | 79 | def get_plot(scores_dict: dict, figure_dir: str): 80 | sns.set_palette("hls") 81 | fig, axs = plt.subplots(1, 5, figsize=(25, 5)) 82 | 83 | color_list = ["#FF55BB", "#00DFA2", "#FFD3A3", "#0079FF", "#F6FA70"] 84 | for i, (aspect, score_list) in enumerate(scores_dict.items()): 85 | 86 | sns.histplot(score_list, bins=10, color=color_list[i], ax=axs[i]) 87 | sns.kdeplot(score_list, color="seagreen", lw=3, ax=axs[i]) 88 | # sns.distplot(score_list, bins=10, kde_kws={"color": "seagreen", "lw": 3}, hist_kws={"color": color_list[i]}, ax=axs[i]) 89 | axs[i].set_title(aspect, fontproperties=font) 90 | 91 | plt.savefig(os.path.join(figure_dir, "quality_hist.png")) 92 | plt.show() 93 | 94 | 95 | def scores_visualization(corpus, text_column, figure_dir): 96 | get_wordcloud(corpus, text_column, figure_dir) 97 | scores_dict = get_all_scores(corpus) 98 | get_plot(scores_dict, figure_dir) -------------------------------------------------------------------------------- /3_rule_filter.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | import ftfy 4 | import regex 5 | from langdetect import detect 6 | from tqdm import tqdm 7 | import opencc 8 | 9 | def load_jsonl(path): 10 | with open(path, 'r', encoding='UTF-8') as f: 11 | return [json.loads(l) for l in f] 12 | 13 | class RuleFilter: 14 | def __init__(self): 15 | self.OPENCC_CONVERTER = opencc.OpenCC('t2s.json') 16 | self.punctuation_unicode = { 17 | ',': ',', 18 | '。': '.', 19 | '、': ',', 20 | '„': '"', 21 | '”': '"', 22 | '“': '"', 23 | '«': '"', 24 | '»': '"', 25 | '1': '"', 26 | 
'」': '"', 27 | '「': '"', 28 | '《': '"', 29 | '》': '"', 30 | '´': "'", 31 | '∶': ':', 32 | ':': ':', 33 | '?': '?', 34 | '!': '!', 35 | '(': '(', 36 | ')': ')', 37 | ';': ';', 38 | '–': '-', 39 | '—': ' - ', 40 | '.': '. ', 41 | '~': '~', 42 | '’': "'", 43 | '…': '...', 44 | '━': '-', 45 | '〈': '<', 46 | '〉': '>', 47 | '【': '[', 48 | '】': ']', 49 | '%': '%', 50 | '►': '-', 51 | } 52 | self.various_whitespaces = { 53 | ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', 54 | ' ', ' ', ' ', ' ', '', '', '', '', '', '' 55 | } 56 | 57 | def handle(self, text): 58 | # unicode 59 | text = ftfy.fix_text(text, normalization="NFC") 60 | # language filter 61 | if detect(text) != args.language: 62 | return None 63 | 64 | # Standardization of Punctuation 65 | text = ''.join([ 66 | self.punctuation_unicode.get(c, c) for c in text 67 | ]) 68 | # Standardization of Whitespace 69 | text = ''.join([ 70 | char if char not in self.various_whitespaces else ' ' for char in text 71 | ]) 72 | 73 | # Replace all matched consecutive punctuation with a single punctuation 74 | pattern = r'(\p{P})\1+' 75 | text = regex.sub(pattern, r'\1', text) 76 | text = text.strip() 77 | 78 | # Filter out texts with too high a punctuation ratio and too short a text length 79 | punctuation_count = len(regex.findall(r'\p{P}', text)) 80 | total_chars = len(text) 81 | punctuation_ratio = punctuation_count / total_chars 82 | if punctuation_ratio > args.punctuation_ratio_threshold or len(text) < args.text_length_threshold: 83 | return None 84 | 85 | 86 | # Convert Traditional Chinese Characters to Simplified Chinese 87 | return self.OPENCC_CONVERTER.convert(text) 88 | 89 | def filter(self, input_file_path, output_file_path): 90 | with open(input_file_path, 'r', encoding='utf-8') as input_file, \ 91 | open(output_file_path, 'w', encoding='utf-8') as output_file: 92 | for line in input_file: 93 | try: 94 | data = json.loads(line) 95 | text = data.get(args.text_column, '') 96 | result = self.handle(text) 97 | if result: 98 | data[args.text_column] = result 99 | output_file.write(json.dumps(data, ensure_ascii=False) + '\n') 100 | except json.JSONDecodeError: 101 | continue # Ignore lines with parsing errors 102 | 103 | 104 | 105 | if __name__ == '__main__': 106 | parser = argparse.ArgumentParser() 107 | # The default input and output are jsonl files 108 | parser.add_argument('--input_path', type=str) 109 | parser.add_argument('--output_path', type=str) 110 | parser.add_argument('--text_column', type=str) 111 | parser.add_argument('--language', type=str) 112 | parser.add_argument('--punctuation_ratio_threshold', type=float, default=0.5) 113 | parser.add_argument('--text_length_threshold', type=int, default=128) 114 | args = parser.parse_args() 115 | 116 | filter = RuleFilter() 117 | filter.filter(args.input_path, args.output_path) 118 | -------------------------------------------------------------------------------- /4_perplexity_filter/kenlm/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import unicodedata 4 | from typing import Dict 5 | 6 | import kenlm 7 | import sentencepiece 8 | from huggingface_hub import cached_download, hf_hub_url 9 | 10 | 11 | class SentencePiece: 12 | def __init__( 13 | self, 14 | model: str, 15 | ): 16 | super().__init__() 17 | self.sp = sentencepiece.SentencePieceProcessor() 18 | self.sp.load(str(model)) 19 | 20 | def do(self, text: dict) -> dict: 21 | tokenized = self.sp.encode_as_pieces(text) 22 | return " ".join(tokenized) 
23 | 24 | 25 | class KenlmModel: 26 | digit_re: re.Pattern = re.compile(r"\d") 27 | unicode_punct: Dict[str, str] = { 28 | ",": ",", 29 | "。": ".", 30 | "、": ",", 31 | "„": '"', 32 | "”": '"', 33 | "“": '"', 34 | "«": '"', 35 | "»": '"', 36 | "1": '"', 37 | "」": '"', 38 | "「": '"', 39 | "《": '"', 40 | "》": '"', 41 | "´": "'", 42 | "∶": ":", 43 | ":": ":", 44 | "?": "?", 45 | "!": "!", 46 | "(": "(", 47 | ")": ")", 48 | ";": ";", 49 | "–": "-", 50 | "—": " - ", 51 | ".": ". ", 52 | "~": "~", 53 | "’": "'", 54 | "…": "...", 55 | "━": "-", 56 | "〈": "<", 57 | "〉": ">", 58 | "【": "[", 59 | "】": "]", 60 | "%": "%", 61 | "►": "-", 62 | } 63 | unicode_punct_re = re.compile(f"[{''.join(unicode_punct.keys())}]") 64 | non_printing_chars_re = re.compile( 65 | f"[{''.join(map(chr, list(range(0,32)) + list(range(127,160))))}]" 66 | ) 67 | kenlm_model_dir = None 68 | sentence_piece_model_dir = None 69 | 70 | def __init__( 71 | self, 72 | model_dataset: str, 73 | language: str, 74 | lower_case: bool = False, 75 | remove_accents: bool = False, 76 | normalize_numbers: bool = True, 77 | punctuation: int = 1, 78 | ): 79 | self.model = kenlm.Model(os.path.join(model_dataset, f"{language}.arpa.bin")) 80 | self.tokenizer = SentencePiece(os.path.join(model_dataset, f"{language}.sp.model")) 81 | self.accent = remove_accents 82 | self.case = lower_case 83 | self.numbers = normalize_numbers 84 | self.punct = punctuation 85 | 86 | @classmethod 87 | def from_pretrained( 88 | cls, 89 | model_dataset: str, 90 | language: str, 91 | ): 92 | return cls( 93 | model_dataset, 94 | language, 95 | False, 96 | False, 97 | True, 98 | 1, 99 | ) 100 | 101 | def pp(self, log_score, length): 102 | return 10.0 ** (-log_score / length) 103 | 104 | def get_perplexity(self, doc: str, normalize_cc_net: bool = True): 105 | if normalize_cc_net: 106 | doc = self.normalize( 107 | doc, 108 | accent=self.accent, 109 | case=self.case, 110 | numbers=self.numbers, 111 | punct=self.punct, 112 | ) 113 | # Tokenize (after normalizing): See https://github.com/facebookresearch/cc_net/blob/bda555bd1cf1ee2e0b925363e62a61cd46c8b60d/cc_net/mine.py#L352 for full pipeline 114 | doc = self.tokenizer.do(doc) 115 | doc_log_score, doc_length = 0, 0 116 | for line in doc.split("\n"): 117 | log_score = self.model.score(line) 118 | length = len(line.split()) + 1 119 | doc_log_score += log_score 120 | doc_length += length 121 | return round(self.pp(doc_log_score, doc_length), 1) 122 | 123 | def normalize( 124 | self, 125 | line: str, 126 | accent: bool = True, 127 | case: bool = True, 128 | numbers: bool = True, 129 | punct: int = 1, 130 | ) -> str: 131 | line = line.strip() 132 | if not line: 133 | return line 134 | if case: 135 | line = line.lower() 136 | if accent: 137 | line = self.strip_accents(line) 138 | if numbers: 139 | line = self.digit_re.sub("0", line) 140 | if punct == 1: 141 | line = self.replace_unicode_punct(line) 142 | elif punct == 2: 143 | line = self.remove_unicode_punct(line) 144 | line = self.remove_non_printing_char(line) 145 | return line 146 | 147 | def strip_accents(self, line: str) -> str: 148 | """Strips accents from a piece of text.""" 149 | nfd = unicodedata.normalize("NFD", line) 150 | output = [c for c in nfd if unicodedata.category(c) != "Mn"] 151 | if len(output) == line: 152 | return line 153 | return "".join(output) 154 | 155 | def replace_unicode_punct(self, text: str) -> str: 156 | return "".join(self.unicode_punct.get(c, c) for c in text) 157 | 158 | def remove_unicode_punct(self, text: str) -> str: 159 | """More aggressive 
version of replace_unicode_punct but also faster.""" 160 | return self.unicode_punct_re.sub("", text) 161 | 162 | def remove_non_printing_char(self, text: str) -> str: 163 | return self.non_printing_chars_re.sub("", text) 164 | -------------------------------------------------------------------------------- /7_DataAnalysis/corpus_evaluator.py: -------------------------------------------------------------------------------- 1 | import re 2 | import random 3 | from collections import OrderedDict 4 | import json 5 | import tiktoken 6 | import openai 7 | from openai import OpenAI 8 | from datasets import Dataset, load_dataset 9 | 10 | # add http proxy 11 | # import os 12 | # os.environ["http_proxy"] = "http://127.0.0.1:10809" 13 | # os.environ["https_proxy"] = "http://127.0.0.1:10809" 14 | 15 | PROMPT = """你是一个语料评价专家,负责对单条语料(通常是一段自然语言文本)的质量进行打分以用于大语言模型的预训练 16 | 你的评价标准是: 17 | 语言质量(0-10分):考察语料的语法、拼写、词汇是否正确,语言表达是否流畅。语言质量高的语料利于模型学习语言规则,可以得高分。得分依据:语法和拼写正确(2分),词汇丰富(2分),表达流畅(2分),长难句或生僻词出现(2分),语言总体复杂(2分)。 18 | 19 | 信息量(0-10分):考察语料所包含的知识量和概念量。信息量大的语料有利于模型学习丰富知识,可以得高分。得分依据:包含专业知识或生僻概念(3分),篇幅较长或讨论多个话题(3分),详尽叙述某一话题(2分),提供新的信息或见解(2分)。 20 | 21 | 新颖性(0-10分):考察语料中的新奇词汇、新信息或新思想对模型理解范围的扩展作用。新颖性高的语料可以得高分。得分依据:包含新词或新概念(3分),提供新信息或新见解(3分),采用新角度或新形式表达观点(2分),创造新的词或短语(2分)。 22 | 23 | 连贯性(0-10分): 主题明确,观点连贯,论证严谨,构成完整论述(3分);主题基本清晰,且论证严谨。(3分) 各部分同属同一话题,构成连贯整体(4分)。 24 | 25 | 纯净度(0-10分):考察语料含有无关信息如广告、营销、垃圾信息的数量,含此类信息少而大部分内容都与主题相关的语料可以得高分。得分依据:主要内容表达完整(3分),垃圾信息含量少(3分),完全没有垃圾信息(4分) 26 | 27 | 通过以上评价标准,你将对下面的语料进行打分: 28 | 【语料开始】 29 | 30 | {corpus} 31 | 32 | 【语料结束】 33 | 34 | 请先分条给出评价理由,再给出对应分数并格式化输出。 35 | 示例: 36 | 【语言质量】:语法和拼写基本正确,词汇较丰富,表达流畅,出现生僻词如“幽灵枪”和长句,语言较复杂。【分数】8 37 | 【信息量】:涉及专业领域知识如各类枪支、美国控枪法案等,讨论多个话题如美国枪支文化与政策、美国枪支暴力现状等,详尽论述美国枪支状况,提供大量数据与信息。【分数】9 38 | 【新颖性】:出现新词“幽灵枪”和新概念如“极端枪支文化”,从政治经济角度揭示美国枪支问题新原因,以全新的角度解析美国枪支文化。【分数】8 39 | 【连贯性】:文中各部分紧密衔接,从美国枪支政策演变到枪支问题分析,再到政治经济因素剖析,行文逻辑清晰,段落结构明确。【分数】9 40 | 【纯净度】:文中的主要内容表达完整,大部分文本都与主题相关,但是结尾含有推广引流信息,不过垃圾信息含量较少。【分数】7 41 | 42 | 输出:""" 43 | 44 | all_aspects = ["语言质量", "信息量", "新颖性", "连贯性", "纯净度"] 45 | 46 | tokenizer = tiktoken.get_encoding('cl100k_base') 47 | 48 | 49 | def read_data(data_path: str, data_num: int) -> Dataset: 50 | dataset = load_dataset("json", data_files=[data_path], split="train", keep_in_memory=True) 51 | 52 | if data_num is not None: 53 | data_num = min(data_num, len(dataset)) 54 | random_indices = random.sample(range(len(dataset)), data_num) 55 | 56 | return dataset.select(random_indices) 57 | 58 | 59 | def cut_corpus(text, max_len=1000): 60 | text_tokens = tokenizer.encode(str(text).strip()) 61 | if len(text_tokens) > max_len: 62 | text_readable = False 63 | text_tokens = text_tokens[:max_len] 64 | while not text_readable and len(text_tokens) > 1: 65 | try: 66 | text = tokenizer.decode(text_tokens) 67 | text_readable = True 68 | except: 69 | text_tokens = text_tokens[:-1] 70 | return text 71 | 72 | 73 | def call_openai_func(instruction: str, model: str = "gpt-3.5-turbo-1106", api_key: str = None, organization: str = None) -> str: 74 | 75 | openai.api_key = api_key 76 | openai.organization = organization 77 | 78 | client = OpenAI(api_key=api_key, organization=organization) 79 | 80 | completion = client.chat.completions.create( 81 | model=model, 82 | messages=[ 83 | {"role": "system", 84 | "content": "You are a helpful assistant."}, 85 | {"role": "user", "content": instruction}, 86 | ], 87 | temperature=0.2, 88 | max_tokens=512, 89 | ) 90 | return completion.choices[0].message.content 91 | 92 | 93 | 94 | def extract_result(text: str) -> dict: 95 | pattern = 
r'【(.*?)】:(.*?)【分数】(\d+)\n' 96 | 97 | matches = re.findall(pattern, text+"\n") 98 | 99 | result = OrderedDict({aspect: {"reason": "", "score": -1} for aspect in all_aspects}) 100 | assert len(matches) == len(all_aspects) 101 | for match in matches: 102 | aspect = match[0] 103 | reason = match[1] 104 | score = match[2] 105 | assert aspect in all_aspects 106 | result[aspect] = {"reason": reason, "score": int(score)} 107 | 108 | return result 109 | 110 | def save_jsonl(data, output_path): 111 | with open(output_path, 'w', encoding='utf-8') as output_file: 112 | for item in data: 113 | output_file.write(json.dumps(item, ensure_ascii=False) + "\n") 114 | 115 | def corpus_quality_measure_fn( 116 | data_path: str, 117 | eval_path: str = None, 118 | data_num: int = None, 119 | text_column: str = "text", 120 | model: str = "gpt-3.5-turbo-1106", 121 | api_key: str = None, 122 | organization: str = None, 123 | num_proc: int = 1,): 124 | 125 | def eval_single_item(obj): 126 | text = obj[text_column] 127 | instruction = PROMPT.format(corpus=cut_corpus(text)) 128 | 129 | try: 130 | response = call_openai_func(instruction, model, api_key, organization) 131 | result = extract_result(response) 132 | except Exception as e: 133 | print("Error") 134 | print(e) 135 | result = OrderedDict({aspect: {"reason": "", "score": -1} for aspect in all_aspects}) 136 | 137 | obj["quality"] = result 138 | return obj 139 | 140 | corpus = read_data(data_path, data_num) 141 | corpus = corpus.map(eval_single_item, num_proc=num_proc) 142 | 143 | if eval_path is not None: 144 | save_jsonl(corpus, eval_path) 145 | # corpus.to_json(eval_path, batch_size=128, force_ascii=False) 146 | 147 | return corpus 148 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 📦 YiZhao: A 2TB Open Financial Dataset 2 | 3 |
4 | 🤗 Hugging Face | 5 | 🤖 ModelScope | 6 | 🪄 YiZhao-12B-Chat | 7 | 📑 Technical Report 8 |
9 | 10 | Data and tools for generating and inspecting **YiZhao**, a safe, high-quality, open-sourced bilingual financial corpus (Chinese, English) released by Harbin Institute of Technology (Shenzhen) and China Merchants Bank Artificial Intelligence Laboratory. 11 | 12 | ## 🌟 Environment 13 | Our recommended Python version is **3.11.4**. 14 | ``` 15 | pip install -r requirements.txt 16 | ``` 17 | 18 | ## 🧩 Data Preprocessing 19 | 20 | ### 1. Remove personal information 21 | This step completes the removal of personal information such as IP addresses, emails, and phone numbers. 22 | #### Example usage 23 | ``` 24 | python 1_pii.py \ 25 | --input_path input.jsonl \ 26 | --output_path output.jsonl \ 27 | --text_column text \ 28 | --num_proc 4 \ 29 | --batch_size 100 30 | ``` 31 | 32 | ### 2. Sensitive Words 33 | To avoid the inclusion of toxic content in the training data, one approach is to filter out texts that contain specific sensitive keywords. You need to store the ***txt*** files containing sensitive words in `2_toxic_filter/sensitive_words`. 34 | #### Example usage 35 | ``` 36 | python 2_toxic_filter/2_toxic_filter.py \ 37 | --input_path input.jsonl \ 38 | --output_path output.jsonl \ 39 | --text_column text \ 40 | ``` 41 | 42 | ### 3. Rule Filtering 43 | This step completes multiple rule-based data filtering. 44 | - Language Filtering: Retain only text data in a specific language (***zh-cn*** or ***en***). 45 | - Punctuation and whitespace consistency processing: Unify Chinese and English punctuation within the text, and standardize different types of whitespace characters as well. 46 | - Deduplication of consecutive punctuation: Replace all matched consecutive punctuation marks with a single punctuation mark. 47 | - Punctuation Ratio Filtering: Filter out texts with too high a punctuation ratio. 48 | - Data Length Filtering: Filter out text data that is too short. 49 | #### Example usage 50 | ``` 51 | python 3_rule_filter.py \ 52 | --input_path input.jsonl \ 53 | --output_path output.jsonl \ 54 | --text_column text \ 55 | --language zh-cn \ 56 | --punctuation_ratio_threshold 0.5 \ 57 | --text_length_threshold 128 \ 58 | ``` 59 | 60 | ### 4. Perplexity Filtering 61 | You need to first download the model from the [KenLM repository](https://huggingface.co/edugp/kenlm), and then modify the corresponding model path in the following line in `4_perplexity_filter/kenlm/run.py`. 62 | ```python 63 | model = KenlmModel.from_pretrained("kenlm/wikipedia", args.language) #language = zh or en 64 | ``` 65 | #### Example usage 66 | ``` 67 | python 4_perplexity_filter/kenlm/run.py \ 68 | --input_path input.jsonl \ 69 | --output_path output.jsonl \ 70 | --text_column text \ 71 | --language zh \ 72 | ``` 73 | 74 | ### 5. Exact Deduplication 75 | Deduplicate identical text entries in the dataset. 76 | #### Example usage 77 | ``` 78 | python 5_text_dedup/5_clean.py \ 79 | --input_path input.jsonl \ 80 | --output_path output.jsonl \ 81 | --text_column text \ 82 | --cache cache_dir \ 83 | --num_proc 2 \ 84 | --batch_size 100 85 | ``` 86 | 87 | ### 6. Fuzzy Deduplication 88 | Deduplicate similar texts in the dataset. 89 | #### Example usage 90 | ``` 91 | python 6_text_dedup/text_dedup/minhash.py \ 92 | --input_path input.jsonl \ 93 | --output_path output.jsonl \ 94 | --column text \ 95 | --cache_dir cache_dir \ 96 | --threshold 0.8 \ 97 | --false_positive_weight 0.5 \ 98 | --false_negative_weight 0.5 \ 99 | ``` 100 | 101 | ### 7. 
Financial relevance filtering and security risk filtering 102 | Using a financial relevance classifier (🤗[fin-model-zh-v0.1](https://huggingface.co/HIT-TMG/fin-model-zh-v0.1) and [fin-model-en-v0.1](https://huggingface.co/HIT-TMG/fin-model-en-v0.1)) and a security risk identification classifier (🤗[risk-model-zh-v0.1](https://huggingface.co/HIT-TMG/risk-model-zh-v0.1) and [risk-model-en-v0.1](https://huggingface.co/HIT-TMG/risk-model-en-v0.1)), we filter the corpus down to safe, high-quality financial text (a minimal usage sketch is given after the Data Evaluation criteria below). 103 | 104 | 105 | 106 | ## ⚡️ Data Evaluation 107 | We evaluate each piece of data from the following aspects: 108 | - **Language Quality (0-10 points)**: This examines whether the data is grammatically correct, spelled correctly, uses appropriate vocabulary, and whether the expression is fluent. High language quality helps the model learn language rules, resulting in a higher score. ***Scoring criteria***: correct grammar and spelling (2 points), rich vocabulary (2 points), fluent expression (2 points), use of complex sentences or rare words (2 points), and overall language complexity (2 points). 109 | 110 | - **Information Content (0-10 points)**: This measures the amount of knowledge and concepts contained in the data. Data with high information content helps the model learn rich knowledge, leading to a higher score. ***Scoring criteria***: includes specialized knowledge or obscure concepts (3 points), longer length or discussion of multiple topics (3 points), detailed discussion of a single topic (2 points), and providing new information or insights (2 points). 111 | 112 | - **Novelty (0-10 points)**: This evaluates the extent to which new vocabulary, information, or ideas in the data expand the model's understanding. Data with high novelty can receive higher scores. ***Scoring criteria***: includes new words or concepts (3 points), provides new information or insights (3 points), presents ideas from new perspectives or in new forms (2 points), and creates new words or phrases (2 points). 113 | 114 | - **Coherence (0-10 points)**: This assesses whether the data has a clear theme, coherent arguments, and rigorous reasoning. ***Scoring criteria***: a clear theme, coherent arguments, and rigorous reasoning forming a complete discussion (3 points); a mostly clear theme with rigorous reasoning (3 points); all parts belonging to the same topic, forming a coherent whole (4 points). 115 | 116 | - **Purity (0-10 points)**: This evaluates the amount of irrelevant information, such as ads, marketing, or spam, in the data. Data with little such content, where most of the text relates to the topic, can score higher. ***Scoring criteria***: the main content is fully expressed (3 points), low spam content (3 points), and no spam content at all (4 points). 
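Unlike the earlier steps, this README gives no example command for the classifier-based filtering in step 7. The snippet below is only a minimal, illustrative sketch of how such a filter could be wired up with the Hugging Face `transformers` library; it assumes the models load as standard text-classification models, and the positive label name (`LABEL_1`) and the 0.5 thresholds are placeholders that must be checked against each model card.

```python
# Illustrative sketch only. Assumptions to verify against the model cards:
#   - the models load as standard text-classification models;
#   - "LABEL_1" as the positive label and the 0.5 thresholds are placeholders.
import json
from transformers import pipeline

fin_clf = pipeline("text-classification", model="HIT-TMG/fin-model-zh-v0.1")
risk_clf = pipeline("text-classification", model="HIT-TMG/risk-model-zh-v0.1")

def keep(text: str) -> bool:
    # Keep a record only if it is classified as financial and not as a security risk.
    fin = fin_clf(text, truncation=True)[0]    # {"label": ..., "score": ...}
    risk = risk_clf(text, truncation=True)[0]
    is_financial = fin["label"] == "LABEL_1" and fin["score"] >= 0.5
    is_risky = risk["label"] == "LABEL_1" and risk["score"] >= 0.5
    return is_financial and not is_risky

with open("input.jsonl", encoding="utf-8") as src, \
     open("output.jsonl", "w", encoding="utf-8") as dst:
    for line in src:
        record = json.loads(line)
        if keep(record.get("text", "")):
            dst.write(json.dumps(record, ensure_ascii=False) + "\n")
```

For the English subset, the same pattern applies with fin-model-en-v0.1 and risk-model-en-v0.1.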
117 | 118 | #### Example usage 119 | ``` 120 | python 7_DataAnalysis/eval_pipeline.py \ 121 | --data_path input.jsonl \ 122 | --eval_path output.jsonl \ 123 | --text_column text \ 124 | --tiktoken_cache cache_dir \ 125 | --figure_dir figure_dir \ 126 | --model gpt-3.5-turbo-1106 \ 127 | --api_key xxxx \ 128 | --organization xxxx \ 129 | --num_proc 1 \ 130 | ``` 131 | 132 | -------------------------------------------------------------------------------- /5_text_dedup/clean_helpers/deduplication.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from functools import partial 3 | from typing import List, Set, Tuple, Dict, Callable, Optional 4 | import hashlib 5 | import re 6 | import string 7 | import urllib 8 | 9 | from datasets import Dataset 10 | 11 | 12 | # ======== DEDUPLICATION FUNCTIONS =================== 13 | from clean_helpers.utils import parse_meta 14 | 15 | 16 | def build_dedup_template(min_template_line_size: int, min_template_line_occurence: int): 17 | def dedup_template(ds: Dataset, num_proc: int, batch_size: int) -> Dataset: 18 | """Computes and remove templates lines""" 19 | # Compute the hash of each lines 20 | split_into_lines_and_hashes = ds.map( 21 | split_text_to_lines_and_hash, 22 | num_proc=num_proc, 23 | batched=True, 24 | batch_size=batch_size, 25 | remove_columns=ds.column_names 26 | ) 27 | lines_and_hashes = split_into_lines_and_hashes.remove_columns( 28 | set(split_into_lines_and_hashes.column_names) - {"lines", "hashes"} 29 | ) 30 | 31 | # Find template lines 32 | count_lines_occurence = defaultdict(lambda: 0) 33 | for row in lines_and_hashes: 34 | filtered_lines_and_hashes = [ 35 | (line, hash_) 36 | for line, hash_ in zip(row["lines"], row["hashes"]) 37 | if len(line) >= min_template_line_size 38 | ] 39 | for _, hash_ in filtered_lines_and_hashes: 40 | count_lines_occurence[hash_] += 1 41 | 42 | template_line_hashes = {k for k, v in count_lines_occurence.items() if v >= min_template_line_occurence} 43 | del count_lines_occurence 44 | 45 | # Clean dataset 46 | return split_into_lines_and_hashes.map( 47 | build_remove_template_lines(template_line_hashes), 48 | num_proc=num_proc, 49 | batched=True, 50 | batch_size=batch_size, 51 | remove_columns=split_into_lines_and_hashes.column_names 52 | ) 53 | 54 | return dedup_template 55 | 56 | 57 | def build_dedup_document(batch_normalizer: Callable[[Dict], List[str]]): 58 | def dedup_document(ds: Dataset, num_proc: int, batch_size: int) -> Dataset: 59 | hashed_documents = ds.map( 60 | lambda batch: {**batch, "hash": get_hash(batch_normalizer(batch))}, 61 | num_proc=num_proc, 62 | batched=True, 63 | batch_size=batch_size 64 | ) 65 | 66 | hashes = set() 67 | 68 | return hashed_documents.map( 69 | partial(delete_text_from_duplicates, hashes=hashes), 70 | num_proc=1, # VERY IMPORTANT: hashes will be updated, and is not thread safe. 
71 | batched=True, 72 | batch_size=batch_size, 73 | remove_columns=hashed_documents.column_names 74 | ) 75 | 76 | return dedup_document 77 | 78 | 79 | # =========== HELPERS =============== 80 | 81 | def get_hash(texts: List[str]) -> List[str]: 82 | """Get hash of content field.""" 83 | return [hashlib.md5(text.strip().encode("utf-8")).hexdigest() for text in texts] 84 | 85 | def split_text_in_lines(text: str) -> List[str]: 86 | return [line.strip() for line in text.split("\n")] 87 | 88 | def split_text_to_lines_and_hash(batch: Dict[str, List]): 89 | lines_per_texts = [split_text_in_lines(text) for text in batch["text"]] 90 | return { 91 | **{k: v for k, v in batch.items() if k != "text"}, 92 | "lines": lines_per_texts, 93 | "hashes": [get_hash(lines) for lines in lines_per_texts] 94 | } 95 | 96 | 97 | def clean_text(lines_and_hashes: List[Tuple[str, int]], template_line_hashes: Set[str]): 98 | return "\n".join([line for line, hash_ in lines_and_hashes if hash_ not in template_line_hashes]) 99 | 100 | 101 | def build_remove_template_lines(template_line_hashes: Set[str]): 102 | def remove_template_lines(batch: Dict[str, List]): 103 | cleaned_texts = [ 104 | clean_text( 105 | list(zip(lines, hashes)), 106 | template_line_hashes 107 | ) 108 | for lines, hashes in zip(batch["lines"], batch["hashes"]) 109 | ] 110 | return { 111 | **{ 112 | key: value 113 | for key, value in batch.items() 114 | if key not in ["lines", "hashes"] 115 | }, 116 | "text": [cleaned_text for cleaned_text in cleaned_texts] 117 | } 118 | 119 | return remove_template_lines 120 | 121 | 122 | def is_new_hash(hash_: str, hashes: Set[str]) -> bool: 123 | """Check if current hash is still in set of unique hashes and remove if true.""" 124 | if hash_ in hashes: 125 | return False 126 | else: 127 | hashes.add(hash_) 128 | return True 129 | 130 | def delete_text_from_duplicates(batch: Dict[str, List], hashes: Set[str]) -> Dict[str, List]: 131 | return { 132 | **{k: v for k, v in batch.items() if k != "hash"}, 133 | "text": [text if is_new_hash(hash_, hashes) else "" for text, hash_ in zip(batch["text"], batch["hash"])] 134 | } 135 | 136 | def url_with_only_some_query_param(url: str, query_param_map: Optional[dict] = None) -> str: 137 | url_parse = urllib.parse.urlparse(url) 138 | query = url_parse.query 139 | 140 | url_query_params = urllib.parse.parse_qsl(query) 141 | 142 | if query_param_map is None: 143 | url_query_params_new = {} 144 | else: 145 | url_query_params_new = [(query_param_map[old_key], old_value) for (old_key, old_value) in url_query_params if old_key in query_param_map] 146 | 147 | url_new_query = urllib.parse.urlencode(url_query_params_new, encoding="utf-8") 148 | url_parse = url_parse._replace(query=url_new_query) 149 | new_url = urllib.parse.urlunparse(url_parse) 150 | return new_url 151 | 152 | # =========== BATCH NORMALISER =============== 153 | 154 | 155 | # this only keeps letter characters 156 | remove_non_character_regex = re.compile(f'\s+|\d+|[{re.escape(string.punctuation)}]') 157 | def document_batch_normalizer(batch: Dict) -> List[str]: 158 | return [remove_non_character_regex.sub('', text) for text in batch["text"]] 159 | 160 | 161 | def strict_url_batch_normalizer(batch: Dict) -> List[str]: 162 | return [parse_meta(meta)["url"] for meta in batch["meta"]] 163 | 164 | 165 | url_host_and_path_regex = re.compile(r"^(.[^?]*)") 166 | def url_host_and_path_batch_normalizer(batch: Dict) -> List[str]: 167 | return [url_host_and_path_regex.match(parse_meta(meta)["url"]).group(1) for meta in batch["meta"]] 168 
| 169 | lm_es_pseudocrawl_filtered_341_es_cointelegraph_com_regex = re.compile(r"^((?:(?!/amp)/?(?:[^?/]*))+)(?:/amp)?") 170 | def url_lm_es_pseudocrawl_filtered_341_es_cointelegraph_com(batch: Dict) -> List[str]: 171 | return [lm_es_pseudocrawl_filtered_341_es_cointelegraph_com_regex.match(parse_meta(meta)["url"]).group(1) for meta in batch["meta"]] 172 | 173 | def url_lm_en_pseudocrawl_filtered_619_www_qut_edu_au(batch: Dict) -> List[str]: 174 | return [url_with_only_some_query_param(parse_meta(meta)["url"], {"id": "id", "news-id": "id"}) for meta in batch["meta"]] -------------------------------------------------------------------------------- /6_text_dedup/text_dedup/utils/add_args.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Date : 2022-11-05 09:16:34 4 | # @Author : Chenghao Mou (mouchenghao@gmail.com) 5 | import argparse 6 | 7 | 8 | def add_io_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: # pragma: no cover 9 | """ 10 | Add input/output arguments to parser. 11 | 12 | Parameters 13 | ---------- 14 | parser : argparse.ArgumentParser 15 | Parser to add arguments to. 16 | 17 | Returns 18 | ------- 19 | parser : argparse.ArgumentParser 20 | Parser with added arguments. 21 | """ 22 | parser.add_argument("--input_path", type=str, help="`path` in load_dataset", required=False), 23 | parser.add_argument("--name", type=str, help="`name` in load_dataset"), 24 | parser.add_argument("--data_dir", type=str, help="`data_dir` in load_dataset"), 25 | parser.add_argument("--data_files", type=str, help="`data_files` in load_dataset"), 26 | parser.add_argument("--split", type=str, help="`split` in load_dataset"), 27 | parser.add_argument("--cache_dir", type=str, help="`cache_dir` in load_dataset", default=".cache"), 28 | parser.add_argument("--revision", type=str, help="`revision` in load_dataset"), 29 | parser.add_argument( 30 | "--use_auth_token", action=argparse.BooleanOptionalAction, help="`use_auth_token` in load_dataset" 31 | ), 32 | parser.add_argument("--local", action=argparse.BooleanOptionalAction, help="Use local dataset", default=False), 33 | parser.add_argument("--output_path", type=str, help="Path to deduplicated dataset output", required=False), 34 | parser.add_argument( 35 | "--debug", action=argparse.BooleanOptionalAction, help="Whether to run in debug mode", default=False 36 | ) 37 | return parser 38 | 39 | 40 | def add_meta_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: # pragma: no cover 41 | """ 42 | Add meta arguments to parser. 43 | 44 | Parameters 45 | ---------- 46 | parser : argparse.ArgumentParser 47 | Parser to add arguments to. 48 | 49 | Returns 50 | ------- 51 | parser : argparse.ArgumentParser 52 | Parser with added arguments. 53 | """ 54 | parser.add_argument( 55 | "--column", 56 | type=str, 57 | help="""Text column to use for deduplication. Concatenate desired columns beforehand if needed.""", 58 | required=False, 59 | ), 60 | parser.add_argument( 61 | "--batch_size", 62 | type=int, 63 | help="""Batch size to use for dataset iteration. Mainly for memory efficiency.""", 64 | default=1000000, 65 | ), 66 | return parser 67 | 68 | 69 | def add_minhash_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: # pragma: no cover 70 | """ 71 | Add MinHash arguments to parser. 72 | 73 | Parameters 74 | ---------- 75 | parser : argparse.ArgumentParser 76 | Parser to add arguments to. 
77 | 78 | Returns 79 | ------- 80 | parser : argparse.ArgumentParser 81 | Parser with added arguments. 82 | """ 83 | parser.add_argument( 84 | "--ngram", 85 | type=int, 86 | default=5, 87 | help="Ngram size to use in MinHash.", 88 | ) 89 | parser.add_argument( 90 | "--min_length", 91 | type=int, 92 | default=5, 93 | help="Minimum number of tokens to use in MinHash. Shorter documents will be filtered out.", 94 | ) 95 | parser.add_argument("--seed", type=int, default=42, help="Seed to use in MinHash") 96 | parser.add_argument("--num_perm", type=int, default=256, help="Number of permutations to use in MinHash") 97 | parser.add_argument( 98 | "--threshold", type=float, default=0.7, help="Jaccard similarity threshold to use in MinHashLSH" 99 | ) 100 | parser.add_argument( 101 | "--b", 102 | type=int, 103 | default=None, 104 | help="Number of bands", 105 | ) 106 | parser.add_argument( 107 | "--r", 108 | type=int, 109 | default=None, 110 | help="Number of rows per band", 111 | ) 112 | 113 | return parser 114 | 115 | 116 | def add_simhash_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: # pragma: no cover 117 | """ 118 | Add SimHash arguments to parser. 119 | 120 | Parameters 121 | ---------- 122 | parser : argparse.ArgumentParser 123 | Parser to add arguments to. 124 | 125 | Returns 126 | ------- 127 | parser : argparse.ArgumentParser 128 | Parser with added arguments. 129 | """ 130 | parser.add_argument( 131 | "--ngram", 132 | type=int, 133 | default=3, 134 | help="""Ngram size to use in SimHash.""", 135 | ) 136 | parser.add_argument("--f", type=int, default=64, help="Simhash bit size"), 137 | parser.add_argument("--bit_diff", type=int, default=3, help="Bit difference to use in SimHash"), 138 | parser.add_argument( 139 | "--num_bucket", type=int, default=4, help="Number of buckets to use in SimHash, must be larger than bit_diff" 140 | ), 141 | return parser 142 | 143 | 144 | def add_sa_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: # pragma: no cover 145 | """ 146 | Add Suffix Array arguments to parser. 147 | 148 | Parameters 149 | ---------- 150 | parser : argparse.ArgumentParser 151 | Parser to add arguments to. 152 | 153 | Returns 154 | ------- 155 | parser : argparse.ArgumentParser 156 | Parser with added arguments. 157 | """ 158 | parser.add_argument( 159 | "--k", type=int, default=100, help="Minimum byte length of a duplicate substring in Suffix Array Deduplication" 160 | ), 161 | parser.add_argument( 162 | "--strategy", 163 | type=str, 164 | default="overlapping", 165 | help="Strategy when there are overlapping duplicate substrings", 166 | choices=["overlapping", "longest"], 167 | ) 168 | parser.add_argument( 169 | "--google_repo_path", type=str, help="Path to google-research-deduplication codebase", required=True 170 | ), 171 | return parser 172 | 173 | 174 | def add_bloom_filter_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: # pragma: no cover 175 | """ 176 | Add Bloom Filter arguments to parser. 177 | 178 | Parameters 179 | ---------- 180 | parser : argparse.ArgumentParser 181 | Parser to add arguments to. 182 | 183 | Returns 184 | ------- 185 | parser : argparse.ArgumentParser 186 | Parser with added arguments. 
187 | """ 188 | parser.add_argument("--error_rate", type=float, default=1e-6, help="Error rate to use in BloomFilter"), 189 | parser.add_argument("--hash_func", type=str, default="md5", help="Hash function to use in BloomFilter"), 190 | parser.add_argument("--initial_capacity", type=int, default=100, help="Initial capacity of BloomFilter"), 191 | return parser 192 | 193 | 194 | def add_exact_hash_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: # pragma: no cover 195 | """ 196 | Add Exact Hash arguments to parser. 197 | 198 | Parameters 199 | ---------- 200 | parser : argparse.ArgumentParser 201 | Parser to add arguments to. 202 | 203 | Returns 204 | ------- 205 | parser : argparse.ArgumentParser 206 | Parser with added arguments. 207 | """ 208 | parser.add_argument("--hash_func", type=str, default="md5", help="Hash function to use in ExactHash"), 209 | return parser 210 | 211 | def add_own_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: 212 | parser.add_argument("--output_duped", type=str, help="duped path to store"), 213 | parser.add_argument( 214 | "--false_positive_weight", 215 | type=float, 216 | default=0.5, 217 | help="false_positive_weight", 218 | ), 219 | parser.add_argument( 220 | "--false_negative_weight", 221 | type=float, 222 | default=0.5, 223 | help="false_negative_weight", 224 | ), 225 | parser.add_argument("--dataset_name", type=str, help="dataset_name",default="text_dedup.jsonl"), 226 | # parser.add_argument("--output_duped", type=str, help="duped path to store"), 227 | return parser -------------------------------------------------------------------------------- /5_text_dedup/5_clean.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | import random 5 | from functools import partial 6 | 7 | import torch 8 | from datasets import Dataset, load_dataset, load_from_disk, concatenate_datasets 9 | from pathlib import Path 10 | from typing import Tuple, Optional, List, Dict 11 | from datasets.utils.logging import set_verbosity_info 12 | from numpy.random import default_rng 13 | 14 | 15 | from clean_helpers import build_dedup_template, build_dedup_document, concatenate_lm_fr_ester 16 | from clean_helpers.deduplication import document_batch_normalizer, url_host_and_path_batch_normalizer, \ 17 | url_lm_es_pseudocrawl_filtered_341_es_cointelegraph_com, url_lm_en_pseudocrawl_filtered_619_www_qut_edu_au 18 | 19 | 20 | 21 | set_verbosity_info() 22 | logger = logging.getLogger(__name__) 23 | torch.set_num_threads(1) 24 | 25 | # Deduplication functions and boolean to save a sample of the modifications: function(ds: Dataset, num_proc: int, batch_size: int) -> Dataset 26 | DEDUPS = { 27 | "dedup_template_soft": (build_dedup_template( 28 | min_template_line_size=15, 29 | min_template_line_occurence=10, 30 | ), True), 31 | "dedup_pseudocrawl_newspapers": (build_dedup_template( 32 | min_template_line_size=0, 33 | min_template_line_occurence=2, 34 | ), True), 35 | "dedup_document": (build_dedup_document(document_batch_normalizer), True), 36 | "dedup_document_on_url": (build_dedup_document(url_host_and_path_batch_normalizer), True), 37 | "dedup_document_on_url_lm_es_pseudocrawl-filtered_341_es_cointelegraph_com": (build_dedup_document( 38 | url_lm_es_pseudocrawl_filtered_341_es_cointelegraph_com 39 | ), True), 40 | "dedup_document_on_url_lm_en_pseudocrawl_filtered_619_www_qut_edu_au": (build_dedup_document( 41 | url_lm_en_pseudocrawl_filtered_619_www_qut_edu_au 42 | ), 
-------------------------------------------------------------------------------- /5_text_dedup/5_clean.py: --------------------------------------------------------------------------------
1 | import argparse 2 | import json 3 | import logging 4 | import random 5 | from functools import partial 6 |
7 | import torch 8 | from datasets import Dataset, load_dataset, load_from_disk, concatenate_datasets 9 | from pathlib import Path 10 | from typing import Tuple, Optional, List, Dict 11 | from datasets.utils.logging import set_verbosity_info 12 | from numpy.random import default_rng 13 | 14 |
15 | from clean_helpers import build_dedup_template, build_dedup_document, concatenate_lm_fr_ester
16 | from clean_helpers.deduplication import document_batch_normalizer, url_host_and_path_batch_normalizer, \ 17 | url_lm_es_pseudocrawl_filtered_341_es_cointelegraph_com, url_lm_en_pseudocrawl_filtered_619_www_qut_edu_au
18 | 19 | 20 |
21 | set_verbosity_info() 22 | logger = logging.getLogger(__name__) 23 | torch.set_num_threads(1) 24 |
25 | # Deduplication functions, each paired with a boolean saying whether to save a sample of the modifications. Signature: function(ds: Dataset, num_proc: int, batch_size: int) -> Dataset
26 | DEDUPS = { 27 | "dedup_template_soft": (build_dedup_template( 28 | min_template_line_size=15, 29 | min_template_line_occurence=10, 30 | ), True),
31 | "dedup_pseudocrawl_newspapers": (build_dedup_template( 32 | min_template_line_size=0, 33 | min_template_line_occurence=2, 34 | ), True),
35 | "dedup_document": (build_dedup_document(document_batch_normalizer), True),
36 | "dedup_document_on_url": (build_dedup_document(url_host_and_path_batch_normalizer), True),
37 | "dedup_document_on_url_lm_es_pseudocrawl-filtered_341_es_cointelegraph_com": (build_dedup_document( 38 | url_lm_es_pseudocrawl_filtered_341_es_cointelegraph_com 39 | ), True),
40 | "dedup_document_on_url_lm_en_pseudocrawl_filtered_619_www_qut_edu_au": (build_dedup_document( 41 | url_lm_en_pseudocrawl_filtered_619_www_qut_edu_au 42 | ), True),
43 | "concatenate_lm_fr_ester": (concatenate_lm_fr_ester, False) 44 | } 45 | 46 |
47 | DEDUPS_KEYS = set(DEDUPS.keys()) 48 |
49 | def get_size_per_example(texts: List[str]) -> Dict: 50 | size_values = [len(text.encode()) for text in texts] 51 | examples = {"bytes_len": size_values} 52 | return examples 53 |
54 | def quick_size_estimation( 55 | ds: Dataset, 56 | num_proc: int, 57 | batch_size: int, 58 | content_key: str = "text" 59 | ) -> float: 60 | if len(ds) == 0: 61 | return 0
62 | rng = default_rng(1991) 63 | subset_size = min(10000, len(ds)) 64 | indices = rng.choice(len(ds), size=subset_size, replace=False, shuffle=False) 65 | partial_ds = ds.select(indices) 66 | ratio = float(len(ds)) / float(subset_size) 67 |
68 | partial_ds = partial_ds.map( 69 | get_size_per_example, 70 | batched=True, 71 | num_proc=num_proc, 72 | batch_size=batch_size, 73 | input_columns=[content_key], 74 | remove_columns=partial_ds.column_names, 75 | )
76 | len_bytes = sum(partial_ds["bytes_len"]) 77 | return len_bytes * ratio
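# Note: quick_size_estimation extrapolates. It measures UTF-8 bytes on a random sample of at most
# 10,000 rows and scales by len(ds) / subset_size, so the GB figures reported by log_stats are
# estimates rather than exact sizes. A minimal, dependency-free illustration of the same idea
# (the helper and toy data below are hypothetical, not part of the pipeline):
def _toy_size_estimate(texts, sample_size=2):
    sample = texts[:sample_size]
    sampled_bytes = sum(len(t.encode()) for t in sample)
    return sampled_bytes * (len(texts) / len(sample))
# _toy_size_estimate(["aa", "bb", "cccc"]) returns 6.0 while the true total is 8 bytes --
# close enough for the progress logging it is used for here.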
78 | 79 | 80 | 81 |
82 | def filter_diff_text(examples, in_text_col, out_text_col): 83 | return [text_in != text_out for text_in, text_out in zip(examples[in_text_col], examples[out_text_col])] 84 |
85 | def get_args(): 86 | parser = argparse.ArgumentParser()
87 | parser.add_argument("--input_path", type=str, required=True, help="Path of the dataset we load.")
88 | parser.add_argument("--output_path", type=Path, required=True, 89 | help="Path where we save the resulting dataset after modifications.")
90 | parser.add_argument('--text_column', type=str)
91 | parser.add_argument("--cache", type=str, required=True, help="Cache path.")
92 | parser.add_argument("--checks_save_path", type=Path, default=None, 93 | help="Path where we save samples we've removed or changed throughout the modifications.")
94 | parser.add_argument("--num_proc", type=int, default=1)
95 | parser.add_argument("--batch_size", type=int, default=100)
96 | parser.add_argument("--load_arrow_file", action="store_true", 97 | help="Option to indicate how to load the original dataset. By default we use `load_dataset`. " 98 | "If the flag is used, we use `load_from_disk`.")
99 | parser.add_argument("--sampling_size_map_checks", type=int, default=None, 100 | help="Optional argument. Checked samples, i.e. samples we've changed throughout the " 101 | "modifications, are saved either in whole or as a subset. If set to None, everything " 102 | "is saved, otherwise a subset of this size is saved.")
103 | parser.add_argument("--sampling_size_filter_checks", type=int, default=None, 104 | help="Optional argument. Checked samples, i.e. samples we've removed throughout the " 105 | "modifications, are saved either in whole or as a subset. If set to None, everything " 106 | "is saved, otherwise a subset of this size is saved.")
107 | parser.add_argument("--from_scratch", action="store_true", help="Resave all datasets on disk.")
108 | parser.add_argument("--save_to_json", default=True, help="Save output dataset in json format.")
109 | return parser.parse_args() 110 |
111 | def log_stats(title: str, original_ds: Dataset, after_transformation_ds: Dataset, operation_type: str, args):
112 | original_length = len(original_ds) 113 | after_transformation_length = len(after_transformation_ds)
114 | original_bytes = quick_size_estimation(original_ds, batch_size=args.batch_size, num_proc=args.num_proc, content_key=args.text_column)
115 | after_transformation_bytes = quick_size_estimation(after_transformation_ds, batch_size=args.batch_size, num_proc=args.num_proc, content_key=args.text_column)
116 | logger.info(title)
117 | logger.info(f" Initial number of samples: {original_length} samples")
118 | logger.info(f" {operation_type} samples: {original_length - after_transformation_length} samples")
119 | logger.info(f" {operation_type} percentage: {(original_length - after_transformation_length) / original_length * 100:.2f} %")
120 | logger.info(f" Final number of samples: {after_transformation_length} samples")
121 | logger.info(f" Initial size: {original_bytes * 1e-9:.4f} GB")
122 | logger.info(f" {operation_type} bytes: {(original_bytes - after_transformation_bytes) * 1e-9:.4f} GB")
123 | logger.info(f" {operation_type} percentage in bytes: {(original_bytes - after_transformation_bytes) / original_bytes * 100:.2f} %")
124 | logger.info(f" Final size: {after_transformation_bytes * 1e-9:.4f} GB")
125 | 126 | 127 |
128 | def get_modified_documents( 129 | ds: Dataset, 130 | mapped_ds: Dataset, 131 | num_proc: int, 132 | batch_size: int, 133 | sampling_size: Optional[int], 134 | text_column, 135 | ) -> Dataset:
136 | remove_columns = set(ds.column_names) 137 | remove_columns.remove(text_column)
138 | ds = ds.remove_columns(remove_columns)
139 | ds = ds.rename_column(text_column, "old_text")
140 |
141 | assert len(mapped_ds) == len(ds), "Mapping functions are batched, but they should not alter the size of the batch."
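# The diff below relies on row alignment: the mapped dataset and the original are concatenated
# column-wise (axis=1) and only the rows whose text actually changed are kept, so the mapping must
# preserve both the number of rows and their order -- hence the assert above. The flatten_indices()
# calls materialise any pending index mapping so the two Arrow tables can be concatenated side by side.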
142 | mapped_diff_ds = concatenate_datasets([mapped_ds.flatten_indices(), ds.flatten_indices()], axis=1).filter( 143 | partial(filter_diff_text, in_text_col="old_text", out_text_col=text_column), 144 | batched=True, 145 | num_proc=num_proc, 146 | batch_size=batch_size 147 | )
148 |
149 | logger.info("Examples of modified samples:")
150 | idx_samples = random.sample(range(len(mapped_diff_ds)), min(len(mapped_diff_ds), 10))
151 | for idx in idx_samples: 152 | logger.info(f" Example n°{idx}:\n{json.dumps(mapped_diff_ds[idx], indent=2)}")
153 |
154 | if sampling_size is not None: 155 | idx_samples = random.sample(range(len(mapped_diff_ds)), min(len(mapped_diff_ds), sampling_size)) 156 | mapped_diff_ds = mapped_diff_ds.select(idx_samples)
157 | 158 | return mapped_diff_ds 159 | 160 |
161 | def apply_function(function_name: str, ds: Dataset, args) -> Tuple[Dataset, Optional[Dataset]]:
162 | logger.info(f"Applying: {function_name}")
163 | if function_name in DEDUPS: 164 | dedup_function, dedup_check = DEDUPS[function_name]
165 | deduplicated_ds = dedup_function(ds, num_proc=args.num_proc, batch_size=args.batch_size)
166 | log_stats(f"Applied deduplication function: {function_name}", ds, deduplicated_ds, operation_type="Deduplicated", args=args)
167 |
168 | # Some deduplications do not preserve the number of samples, so alignment is lost, e.g. "dedup_document".
169 | if args.checks_save_path is not None and dedup_check:
170 | deduped_diff_ds = get_modified_documents(ds, deduplicated_ds, args.num_proc, args.batch_size, args.sampling_size_map_checks, args.text_column)
171 | return deduplicated_ds, deduped_diff_ds
172 | else: 173 | return deduplicated_ds, None
174 | else: 175 | raise NotImplementedError(f"{function_name} has not matched any existing function names. Available names:\n" 176 | f"Dedup functions: {DEDUPS_KEYS}\n" 177 | )
178 |
179 | def main(): 180 | logging.basicConfig( 181 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 182 | datefmt="%m/%d/%Y %H:%M:%S", 183 | level=logging.INFO, 184 | )
185 | args = get_args()
186 | logger.info(f"** The job is run with the following arguments: **\n{args}\n **** ")
187 |
188 | # Load dataset
189 | logger.info(f" ===== Loading {args.input_path} =====")
190 | if args.load_arrow_file: 191 | ds = load_from_disk(args.input_path)
192 | else: 193 | ds = load_dataset("json", data_files=args.input_path, split="train", cache_dir=args.cache)
194 |
195 | # Apply series of dedups
196 | logger.info(f" ===== Applying transformations =====")
197 |
198 | preprocessings = ["dedup_template_soft", "dedup_document"]
199 | for idx, preprocessing in enumerate(preprocessings):
200 | ds, ds_diff = apply_function(preprocessing, ds, args)
201 | if ds_diff is not None and len(ds_diff) != 0:
202 | saving_path = args.checks_save_path / f"{idx}_{preprocessing}_checks"
203 | if not args.from_scratch and saving_path.exists(): 204 | continue
205 | tmp_save_path = Path(saving_path.parent, f"tmp-{saving_path.name}")
206 | logger.info(f" ===== Saving examples to check after {preprocessing} =====")
207 | ds_diff.save_to_disk(tmp_save_path)
208 | tmp_save_path.rename(saving_path)
209 | 210 |
211 | # Save to disk
212 | if args.from_scratch or not args.output_path.exists():
213 | logger.info(f" ===== Saving dataset =====")
214 | logger.info(f"Saving to final dataset at {args.output_path}.")
215 | tmp_save_path = Path(args.output_path.parent, f"tmp-{args.output_path.name}")
216 | if len(ds) == 0: 217 | logger.info("Dataset was empty. 
Not saving anything.") 218 | return 219 | if args.save_to_json: 220 | ds.to_json( 221 | tmp_save_path, 222 | num_proc=args.num_proc, 223 | force_ascii=False 224 | ) 225 | else: 226 | ds.save_to_disk(tmp_save_path) 227 | tmp_save_path.rename(args.output_path) 228 | else: 229 | logging.info(f"Dataset was already saved at {args.output_path}") 230 | 231 | 232 | if __name__ == "__main__": 233 | main() 234 | -------------------------------------------------------------------------------- /1_pii.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from functools import partial 3 | from pathlib import Path 4 | import logging 5 | import random 6 | import sys 7 | import regex 8 | from datasets.utils.logging import set_verbosity_info 9 | from datasets import load_dataset, load_from_disk 10 | 11 | set_verbosity_info() 12 | logger = logging.getLogger(__name__) 13 | high_risk_tags = {'KEY', 'EMAIL', 'USER', 'IP_ADDRESS'} # , 'NUMBER', "ID"} 14 | year_patterns = [ 15 | # yyyy-yyyy or yyyy/yyyy 16 | regex.compile(r"(?:^|[\b\s@?,!;:\'\")(.\p{Han}])([1-2][0-9]{3}[\p{Pd}/][1-2][0-9]{3})(?:$|[\s@,?!;:\'\"(.\p{Han}])"), 17 | # yyyy-mm-dd or yyyy-dd-mm or yyyy/mm/dd or yyyy/dd/mm or yyyy.mm.dd or yyyy.dd.mm 18 | regex.compile(r"(?:^|[\b\s@?,!;:\'\")(.\p{Han}])([1-2][0-9]{3}[\p{Pd}/.][0-3][0-9][\p{Pd}/.][0-3][0-9])(?:$|[\s@,?!;:\'\"(.\p{Han}])"), 19 | # mm-dd-yyyy or dd-mm-yyyy or mm/dd/yyyy or dd/mm/yyyy or mm.dd.yyyy or dd.mm.yyyy or the same but with yy instead of yyyy 20 | regex.compile(r"(?:^|[\b\s@?,!;:\'\")(.\p{Han}])([0-3][0-9][\p{Pd}/.][0-3][0-9][\p{Pd}/.](?:[0-9]{2}|[1-2][0-9]{3}))(?:$|[\s@,?!;:\'\"(.\p{Han}])"), 21 | # mm-yyyy or mm/yyyy or the same but with yy 22 | regex.compile(r"(?:^|[\b\s@?,!;:\'\")(.\p{Han}])([0-3][0-9][\p{Pd}/](?:[0-9]{2}|[1-2][0-9]{3}))(?:$|[\s@,?!;:\'\"(.\p{Han}])"), 23 | # yyyy-mm or yyyy/mm 24 | regex.compile(r"(?:^|[\b\s@?,!;:\'\")(.\p{Han}])([1-2][0-9]{3}-[0-3][0-9])(?:$|[\s@,?!;:\'\"(.\p{Han}])"), 25 | ] 26 | 27 | # Patterns for high-risk character strings 28 | id_pattern = r'(?:^|[\b\s@?,!;:\'\")(.\p{Han}])([A-Za-z]*(?:[\p{Pd}]*\p{Nd}){6,})(?:$|[\b\s@?,!;:\'\")(.\p{Han}])' 29 | 30 | # https://regex101.com/r/JQkmh8/5 31 | key_pattern = r'(?:^|[\b\s@?,!:;\'\")(.\p{Han}])((?:(?:[A-Za-z]+[\p{Nd}\p{Pd}\/\+\=:_]+|[\p{Nd}\p{Pd}\/\+\=:]+[A-Za-z]+)){4,}|(?:(?:\p{Nd}{3,}|[A-Z]+\p{Nd}+[A-Z]*|\p{Nd}+[A-Z]+\p{Nd}*)[ \p{Pd}]?){3,})(?:$|[\b\s\p{Han}@?,!;:\'\")(.])' 32 | 33 | ipv4_pattern = r'(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)(?:\.(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)){3}' 34 | ipv6_pattern = r'(?:[0-9a-fA-F]{1,4}:){7,7}[0-9a-fA-F]{1,4}|(?:[0-9a-fA-F]{1,4}:){1,7}:|(?:[0-9a-fA-F]{1,4}:){1,6}:[0-9a-fA-F]{1,4}|(?:[0-9a-fA-F]{1,4}:){1,5}(?::[0-9a-fA-F]{1,4}){1,2}|(?:[0-9a-fA-F]{1,4}:){1,4}(?::[0-9a-fA-F]{1,4}){1,3}|(?:[0-9a-fA-F]{1,4}:){1,3}(?::[0-9a-fA-F]{1,4}){1,4}|(?:[0-9a-fA-F]{1,4}:){1,2}(?::[0-9a-fA-F]{1,4}){1,5}|[0-9a-fA-F]{1,4}:(?:(?::[0-9a-fA-F]{1,4}){1,6})|:(?:(?::[0-9a-fA-F]{1,4}){1,7}|:)|fe80:(?::[0-9a-fA-F]{0,4}){0,4}%[0-9a-zA-Z]{1,}|::(?:ffff(?::0{1,4}){0,1}:){0,1}(?:(?:25[0-5]|(?:2[0-4]|1{0,1}[0-9]){0,1}[0-9])\.){3,3}(?:25[0-5]|(?:2[0-4]|1{0,1}[0-9]){0,1}[0-9])|(?:[0-9a-fA-F]{1,4}:){1,4}:(?:(?:25[0-5]|(?:2[0-4]|1{0,1}[0-9]){0,1}[0-9])\.){3,3}(25[0-5]|(?:2[0-4]|1{0,1}[0-9]){0,1}[0-9])' 35 | ip_pattern = r"(?:^|[\b\s@?,!;:\'\")(.\p{Han}])(" + r"|".join([ipv4_pattern, ipv6_pattern]) + ")(?:$|[\s@,?!;:\'\"(.\p{Han}])" 36 | 37 | # https://regex101.com/r/EpA5B7/1 38 | email_pattern = r''' 39 | (?<= ^ | 
[\b\s@,?!;:)('".\p{Han}<] ) 40 | ( 41 | [^\b\s@?!;,:)('"<]+ 42 | @ 43 | [^\b\s@!?;,/]* 44 | [^\b\s@?!;,/:)('">.] 45 | \. 46 | \p{L} \w{1,} 47 | ) 48 | (?= $ | [\b\s@,?!;:)('".\p{Han}>] ) 49 | ''' 50 | 51 | # https://regex101.com/r/mOqi1s/3 52 | user_pattern = r''' 53 | (?<= ^ | [)(\s@,?!;:'"\p{Han}] ) 54 | (@ 55 | [^)(\s@,?!;:'"]{3,} 56 | ) 57 | ''' 58 | # Examples from https://regexpattern.com/phone-number/ 59 | # https://regex101.com/r/lZZ0XP/4 60 | # Also matches MLS numbers 61 | # phone_pattern = r'(?:^|[\s\'\"(\p{Han}])((?:\+\p{Nd}+[ \/.\p{Pd}]*)?(?:(?:\(\+?\p{Nd}+\))?(?:[ \/.\p{Pd}]*\p{Nd})){7,}(?:[\t\f #]*\p{Nd}+)?)(?:$|[\s@,?!;:\'\"(.\p{Han}])' 62 | 63 | id_regex = regex.compile(id_pattern, flags=regex.MULTILINE) #, re.MULTILINE) 64 | key_regex = regex.compile(key_pattern, flags=regex.MULTILINE) #, re.MULTILINE) 65 | ipv4_regex = regex.compile(ipv4_pattern) 66 | ipv6_regex = regex.compile(ipv6_pattern) 67 | ip_regex = regex.compile(ip_pattern, flags=regex.MULTILINE) #, re.MULTILINE) 68 | email_regex = regex.compile(email_pattern, flags=regex.MULTILINE|regex.VERBOSE) #, re.MULTILINE) 69 | user_regex = regex.compile(user_pattern, flags=regex.MULTILINE|regex.VERBOSE) #, re.MULTILINE) 70 | # phone_regex = regex.compile(phone_pattern, flags=regex.MULTILINE) #, re.MULTILINE) 71 | 72 | 73 | 74 | mst_regexes = {} 75 | for tag in high_risk_tags: 76 | if tag == 'ID': 77 | mst_regexes['ID'] = id_regex 78 | elif tag == 'KEY': 79 | mst_regexes['KEY'] = key_regex 80 | elif tag == 'IPv4': 81 | mst_regexes['IPv4'] = ipv4_regex 82 | elif tag == 'IPv6': 83 | mst_regexes['IPv6'] = ipv6_regex 84 | elif tag == 'IP_ADDRESS': 85 | mst_regexes['IP_ADDRESS'] = ip_regex 86 | elif tag == 'EMAIL': 87 | mst_regexes['EMAIL'] = email_regex 88 | elif tag == 'USER': 89 | mst_regexes['USER'] = user_regex 90 | # elif tag == 'NUMBER': 91 | # mst_regexes['NUMBER'] = phone_regex 92 | else: 93 | sys.stderr.write('Dont have tag regex pattern for %s =(' % tag) 94 | 95 | def ip_has_digit(matched_str): 96 | """Checks to make sure the PII span is not just :: or whatever that may 97 | accidentally be picked up by making sure there are digits.""" 98 | return any(map(str.isdigit, matched_str)) 99 | 100 | def matches_date_pattern(matched_str): 101 | # Screen out date false positives 102 | for year_regex in year_patterns: 103 | if year_regex.match(matched_str): 104 | return True 105 | return False 106 | 107 | 108 | def detect_pii(text, lang, tag_types): 109 | matches = [] 110 | for tag in tag_types: 111 | label_pattern = mst_regexes[tag] 112 | # !! regex.match happens here!! 113 | matches_tmp = label_pattern.finditer(text) 114 | for match in matches_tmp: 115 | if match.groups(): 116 | if len(match.groups()) > 1 and match.groups()[1]: 117 | sys.stderr.write("Warning: Found substring matches in the main match.") 118 | 119 | matched_str = match.groups() 120 | 121 | matched_str = matched_str[0] 122 | if matched_str: 123 | if tag in ["IP_ADDRESS"]: 124 | # Filter out false positive IPs 125 | if not ip_has_digit(matched_str): 126 | continue 127 | if tag in ["ID", "IP_ADDRESS"]: #, "NUMBER"]: 128 | # Filter out date false positives 129 | if matches_date_pattern(matched_str): 130 | continue 131 | 132 | matches += [(matched_str, match.span(), str(label_pattern), tag, lang)] 133 | return matches 134 | 135 | 136 | #@title Redaction function defined here. 137 | def redact_pii(text, matches): 138 | """Takes a match as defined in the detect_pii function and redacts it from the full string, returning a